/* ****************************************************************************** * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // Created by raver119 on 07.10.2017. // Modified by GS on 3/9/2018 // #include #include #include #include namespace sd { namespace blas { template void* transpose(int orderSource, int orderTarget, int rows, int cols, void *vsource) { auto ret = new T[rows * cols]; auto source = reinterpret_cast(vsource); // handle transpose in parallel auto func = PRAGMA_THREADS_FOR { for (auto r = start; r < stop; r++) { for (int c = 0; c < cols; c++) { int zIdx = orderTarget == CblasRowMajor ? linearIndexC(rows, cols, r, c) : linearIndexF(rows, cols, r, c); int xIdx = orderSource == CblasColMajor ? linearIndexF(rows, cols, r, c) : linearIndexC(rows, cols, r, c); ret[zIdx] = source[xIdx]; } } }; samediff::Threads::parallel_for(func, 0, rows); return ret; } template void GEMM::op(int Order, int TransA, int TransB, int M, int N, int K, double alpha, void *vA, int lda, void *vB, int ldb, double beta, void *vC, int ldc) { auto A = reinterpret_cast(vA); auto B = reinterpret_cast(vB); auto C = reinterpret_cast(vC); bool transAFlag = TransA == CblasTrans; bool transBFlag = TransB == CblasTrans; if (beta == 0.0) { Z z = 0.f; int length = M*N; if (length <= Environment::getInstance().elementwiseThreshold()) { for (int r = 0; r < length; r++) C[r] = z; } else { auto func = PRAGMA_THREADS_FOR { for (auto r = start; r < stop; r++) C[r] = z; }; samediff::Threads::parallel_for(func, 0, length); } } auto func = PRAGMA_THREADS_FOR_2D { for (auto r = start_x; r < stop_x; r += inc_x) { for (auto c = start_y; c < stop_y; c += inc_y) { int zIdx = linearIndexF(M, N, r, c); Z dot = static_cast(0.0f); if (alpha != 0.0) { int bIdx; // = linearIndexF(K, N, 0, c); int aIdx; for (int k = 0; k < K; k++) { aIdx = (transAFlag ? linearIndexC(M, K, r, k) : linearIndexF(M, K, r, k)); bIdx = (transBFlag ? linearIndexC(K, N, k, c) : linearIndexF(K, N, k, c)); dot += static_cast(alpha) * static_cast(A[aIdx]) * static_cast(B[bIdx]);//A[aIdx]sd::math::nd4j_dot(aX, bX, K) * alpha; } } if (beta != 0.0) { C[zIdx] = static_cast(dot + static_cast(beta) * C[zIdx]); } else { C[zIdx] = static_cast(dot); } } } }; samediff::Threads::parallel_for(func, 0, M, 1, 0, N, 1); } template void GEMV::op(int TRANS, int M, int N, double alpha, void * vX, int lda, void* vY, int incx, double beta, void* vZ, int incy ) { auto x = reinterpret_cast(vX); auto y = reinterpret_cast(vY); auto z = reinterpret_cast(vZ); auto aT = TRANS == CblasTrans ? reinterpret_cast(sd::blas::transpose(CblasColMajor, CblasRowMajor, M, N, reinterpret_cast(x))) : x; auto func = PRAGMA_THREADS_FOR { for (auto r = start; r < stop; r++) { int aIdx = linearIndexC(M, N, r, 0); auto aX = aT + aIdx; auto dot = sd::math::nd4j_dot(aX, y, lda) * static_cast(alpha); z[r] = beta == 0.0f ? dot : dot + static_cast(beta) * z[r]; } }; samediff::Threads::parallel_for(func, 0, M); if (TRANS == CblasTrans) delete[] aT; } //BUILD_TRIPLE_TEMPLATE(template class GEMV, , LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_TEMPLATE(template class GEMM, , LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); } }