cavis/libnd4j/include/ops/impl/gemm.cpp

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver119 on 07.10.2017.
// Modified by GS <sgazeos@gmail.com> on 3/9/2018
//
#include <ops/gemm.h>
#include <types/types.h>
#include <system/Environment.h>
#include <execution/Threads.h>
namespace sd {
namespace blas {
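// Repacks a rows x cols matrix from one memory order into the other. The result
// is allocated here with new T[] and ownership passes to the caller, who must
// release it with delete[] (as GEMV below does for its transposed copy).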
template <typename T>
void* transpose(int orderSource, int orderTarget, int rows, int cols, void *vsource) {
    auto ret = new T[rows * cols];
    auto source = reinterpret_cast<T *>(vsource);

    // handle transpose in parallel
    auto func = PRAGMA_THREADS_FOR {
        for (auto r = start; r < stop; r++) {
            for (int c = 0; c < cols; c++) {
                int zIdx = orderTarget == CblasRowMajor ? linearIndexC(rows, cols, r, c) : linearIndexF(rows, cols, r, c);
                int xIdx = orderSource == CblasColMajor ? linearIndexF(rows, cols, r, c) : linearIndexC(rows, cols, r, c);

                ret[zIdx] = source[xIdx];
            }
        }
    };
    sd::Threads::parallel_for(func, 0, rows);

    return ret;
}
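
// Straightforward parallel fallback for C = alpha * op(A) * op(B) + beta * C.
// The Order, lda, ldb and ldc arguments are accepted for CBLAS signature
// compatibility but are not consulted below: A, B and C are indexed as packed
// column-major (Fortran-order) buffers, and CblasTrans on A or B simply
// switches the corresponding reads to row-major indexing.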
template <typename X, typename Y, typename Z>
void GEMM<X, Y, Z>::op(int Order, int TransA, int TransB,
                       int M, int N, int K,
                       double alpha,
                       void *vA, int lda,
                       void *vB, int ldb,
                       double beta,
                       void *vC, int ldc) {
    auto A = reinterpret_cast<X *>(vA);
    auto B = reinterpret_cast<Y *>(vB);
    auto C = reinterpret_cast<Z *>(vC);

    bool transAFlag = TransA == CblasTrans;
    bool transBFlag = TransB == CblasTrans;

    // beta == 0 means the old contents of C must not be read, so clear it up front
    if (beta == 0.0) {
        Z z = static_cast<Z>(0.0f);
        int length = M * N;
        if (length <= Environment::getInstance()->elementwiseThreshold()) {
            for (int r = 0; r < length; r++)
                C[r] = z;
        } else {
            auto func = PRAGMA_THREADS_FOR {
                for (auto r = start; r < stop; r++)
                    C[r] = z;
            };
            sd::Threads::parallel_for(func, 0, length);
        }
    }

    // one (r, c) cell of C per 2D loop iteration
    auto func = PRAGMA_THREADS_FOR_2D {
        for (auto r = start_x; r < stop_x; r += inc_x) {
            for (auto c = start_y; c < stop_y; c += inc_y) {
                int zIdx = linearIndexF(M, N, r, c);

                Z dot = static_cast<Z>(0.0f);

                if (alpha != 0.0) {
                    int aIdx, bIdx;

                    for (int k = 0; k < K; k++) {
                        aIdx = transAFlag ? linearIndexC(M, K, r, k) : linearIndexF(M, K, r, k);
                        bIdx = transBFlag ? linearIndexC(K, N, k, c) : linearIndexF(K, N, k, c);

                        dot += static_cast<Z>(alpha) * static_cast<Z>(A[aIdx]) * static_cast<Z>(B[bIdx]);
                    }
                }

                if (beta != 0.0) {
                    C[zIdx] = static_cast<Z>(dot + static_cast<Z>(beta) * C[zIdx]);
                } else {
                    C[zIdx] = static_cast<Z>(dot);
                }
            }
        }
    };
    sd::Threads::parallel_for(func, 0, M, 1, 0, N, 1);
}
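
// Parallel matrix-vector product, one row-wise dot product per output element.
// For TRANS == CblasTrans the column-major input matrix is first repacked into
// row-major order (see transpose above); otherwise the buffer is read as
// row-major directly. incx and incy are accepted for CBLAS signature
// compatibility but unit strides are assumed, and lda is reused as the dot
// length, which effectively assumes lda == N.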
template<typename X, typename Y, typename Z>
void GEMV<X, Y, Z>::op(int TRANS, int M, int N,
                       double alpha,
                       void *vX, int lda,
                       void *vY, int incx,
                       double beta,
                       void *vZ, int incy) {
    auto x = reinterpret_cast<X *>(vX);
    auto y = reinterpret_cast<Y *>(vY);
    auto z = reinterpret_cast<Z *>(vZ);

    // for the transposed case, repack the matrix into row-major order first
    auto aT = TRANS == CblasTrans
              ? reinterpret_cast<X *>(sd::blas::transpose<X>(CblasColMajor, CblasRowMajor, M, N, reinterpret_cast<void *>(x)))
              : x;

    auto func = PRAGMA_THREADS_FOR {
        for (auto r = start; r < stop; r++) {
            int aIdx = linearIndexC(M, N, r, 0);
            auto aX = aT + aIdx;

            auto dot = sd::math::nd4j_dot<X, Y, Z>(aX, y, lda) * static_cast<Z>(alpha);
            z[r] = beta == 0.0 ? dot : dot + static_cast<Z>(beta) * z[r];
        }
    };
    sd::Threads::parallel_for(func, 0, M);

    // the row-major copy is owned by this function; release it
    if (TRANS == CblasTrans)
        delete[] aT;
}
//BUILD_TRIPLE_TEMPLATE(template class GEMV, , LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
//BUILD_TRIPLE_TEMPLATE(template class GEMM, , LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
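
// Illustrative call (a sketch, not part of this translation unit): multiplying a
// column-major 2x3 matrix A by a 3x2 matrix B into a 2x2 matrix C in single
// precision, with alpha = 1 and beta = 0, might look like
//
//     float A[6] = {1, 2, 3, 4, 5, 6};   // hypothetical packed, column-major buffers
//     float B[6] = {1, 0, 0, 0, 1, 0};
//     float C[4];
//     sd::blas::GEMM<float, float, float>::op(CblasColMajor, CblasNoTrans, CblasNoTrans,
//                                             2, 2, 3, 1.0, A, 2, B, 3, 0.0, C, 2);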
}
}