Alex Black 1170827c18 Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)

* Modified strided_slice op to properly work with empty-like shapes.

* Fixed test for reduce_mean with empty-like input.

* [WIP] Last merge (#15)

* correct logsoftmax loss (#2)

* Small SameDiff listener fix (#4)

* Various fixes (#6)

* #7839 Fix for asXMatrix and tests

* #7866 EmbeddingSequenceLayer dtype fix + test

* #7856 SameDiff save/load stream methods

* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration

* EvaluationBinary 3d/4d

* More evaluation 3d/4d tests

* #7847 Evaluation empty checks

* Small test fix

* #7848 Fix median edge case

* Improve DL4J samediff layer tests

* [WIP] FastText wrapper implemented (#8)

* FastText implemented

* Some fixes

* Fix shapes for wordsNearest

* Validation of input vectors

* Fixes

* Fixed test

* Thread tagged

* Some tweaks

* setContextClassLoader for DeallocatorServiceThread

* Numpy format tests (#1)

* Various fixes (#11)

* #7852 SameDiff gather fix

* #7892 SameDiff placeholder to constant conversion

* #7890 validate input rank for MLN/CG init methods

* Fix broken permute shape calculation

* Permute and gather fixes

* Tests

* #7850 LogSumExp fix + test

* Handful of test fixes

* Empty arrays with non-scalar shapes (#10)

* minor rearrangements for lambdas

* empty tensors with non-scalar shapes

* numpy empty tensors with non-scalar shapes

* few more empty tweaks

* Small fixes

* conv3d signature update

* micro fix in batchnorm mkldnn

* Import fixes

* Fix

* MKL-DNN update

* Small fill fix

* fill with empty input + test

* Fixes

* Small error improvement

* Fix

* one special test

* couple of fixes for lstm

* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone

* Fixes

* FP16

* Unsigned

* BFloat16

* Fill op - empty tweaks

* - couple of fixes for empty arrays construction
- stack updated

* strided slice fix

* one transform test

* provide method for reducing shapeInfo in case of input array is empty

* Fixed reduceAlongDimensions to use empty input properly.

* couple of broadcast tests

* couple of tests broadcast tests + tweak to make them pass

* add check of non-empty to methods producing sub-arrays

* Fixed reshapeC with zeros in shape.

* complete empty check in reduce_... legacy ops

* Concat and cumsum/prod

* Tweak to empty shape inference on import

* add empty check to the rest of reduce legacy ops

* one more test

* correct typo in evalReduceShapeInfoEmpty

* Added tests for reduce_* ops to tests with zero shapes.

* few more tests for empty reductions

* Fixed strided_slice op with empty case and tests.

* one more empty reduction test

* Fixed strided_slice test.

* add empty check to NDArray::reshapei

* infOrMax

* empty min/max with infinity tests

* made unstack working correctly with empty arrays

* few IndexReduce tests + tweaks for empty shapes

* add test for empty concat

* few tests fixed

* Validation fix for reductions on empty shapes

* Reverse fix

* Reduction shape calc fixes

* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs

* Range fix

* - NDArray constructor updated for scalars/empty arrays
- few tests fixed

* More fixes

* Empty creator fixes

* concat fix

* concat fix

* TF import tests: allow 'both all NaN' and 'both all inf' to pass

* Slice, zero fraction, and reshape fixes

* transpose, gather

* Zero fraction

* scalar cast fix

* Empty reduction axis support

* few more tests fixed

* Fixed input checks conforming with TF for concat op and tests.

* few tests fixed

* matmul scalar shape fix

* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.

* broadcast bool fix

* few more tests

* few more tests

* correct evalReduceShapeInfoEmpty

* argmax/argmin + tests

* one more empty edge case + one more test

* argmax/argmin/realdiv_bp tweaks

* empty reshape test + fix

* Helper fixes

* Small fixes

* Gather test fix

* Gather test fix

* Small fixes

* reduce scalar zero values

* scalar mean workaround

* Remove debug code

* along dim mean workaround

* one more test

* - equalsTo() tweak for empty arrays
- one more test

* broadcast tweaks

* [WIP] Fixing outstanding issues for NLP (#9)

* Avoid using not-inited objects

* Test fixed.

* Redundant method avoided for models like FastText

* KMeans++ implementation

* KMeans++ implementation

* Disable parallel execution

* KMeans++

* Tests

* Dev branch merge (#16)

* SameDiff: convertDataType and gradient check util improvements (#12)

* GradCheck util improvements

* StopGradient constructor + test

* SameDiff: Add datatype conversion

* Javadoc and add DataType.isNumerical()

* Small fix

* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)

* TFGraphTestAllHelper: check intermediates in execution order

* Add missing debug listener

* [WIP] lstmBlock fix + other changes (#13)

- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite

* Small test fix

* CheckNumerics op wrapper

* Fix some issues on master (#17)

* Fix DataVec test issue

* Fix issue with dl4j SameDiff output layer

* Dtype fix for lambda layers

* #7912 BertIterator dtype fix (use float32 not global default)

* [WIP] Next set of CUDA stuff (#7)

New CUDA implementations and improvements

* bad file

* Dev branch master merge (#23)

* SameDiff: convertDataType and gradient check util improvements (#12)

* GradCheck util improvements

* StopGradient constructor + test

* SameDiff: Add datatype conversion

* Javadoc and add DataType.isNumerical()

* Small fix

* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)

* TFGraphTestAllHelper: check intermediates in execution order

* Add missing debug listener

* [WIP] lstmBlock fix + other changes (#13)

- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite

* Small test fix

* CheckNumerics op wrapper

* Compatibility of deserialization (#18)

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* SameDiff: add activation gradient checking support for debugging (#19)

* SameDiff gradient checker: first pass on activation gradient checks

* Fixes + tests for activation gradient checking

* Javadoc

* [WIP] Some nd4j data type corrections (#20)

* Adjust data type

* Set correct Data type.

* Size of proper data type.

* fix averaged cpu load (#22)

* SameDiff ops, TF import and fixes (#24)

* CheckNumerics tests + fixes + misc fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fake quant

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* FakeQuantWithMinMaxArgs

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* CheckNumerics fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Javadoc

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Exception tweak

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for out of scope stack allocated var use

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Ignores

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Ignore for known failing test (already logged issue)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Merge upstream to fork (#25)

* Add thousand-separator commas to TotalParams (#7915)

* Add thousand-separator commas to TotalParams

The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.

* Add thousand-separator commas to MultiLayerNetwork

Corresponding change to MultiLayerNetwork

Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>

* Update contributing and issue/PR templates (#7934)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix link to AdaDelta paper (#7942)

Fix link to AdaDelta paper hosted on matthewzeiler.com

Signed-off-by: Jxtps

* Fixes, and ignores for known/logged failing issues (#7943)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* SameDiff + DL4J/SameDiff: Multiple fixes (#28)

* #7919 HDF5 attribute buffer length fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7909 Arbiter constructor exception ux improvements

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7925 RNN output layer length checks

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7939 Add listener for validating inputs are not incorrectly modified

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7939 Integrate NonInplaceValidationListener into tests

* #7844 DL4J SameDiff fixes for variable minibatch size

* DL4J SameDiff fixes - ensure gradient for input placeholder is available

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tweaks to ExternalErrorsFunction - use placeholders, make more robust

* Another fix

* More fixes

* More SameDiff/DL4J fixes

* Scope out scalar array creation in BaseScalarOp

* Remove debug code

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* [WIP] Final dev branch merge (#29)

* SameDiff: convertDataType and gradient check util improvements (#12)

* GradCheck util improvements

* StopGradient constructor + test

* SameDiff: Add datatype conversion

* Javadoc and add DataType.isNumerical()

* Small fix

* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)

* TFGraphTestAllHelper: check intermediates in execution order

* Add missing debug listener

* [WIP] lstmBlock fix + other changes (#13)

- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite

* Small test fix

* CheckNumerics op wrapper

* Compatibility of deserialization (#18)

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* SameDiff: add activation gradient checking support for debugging (#19)

* SameDiff gradient checker: first pass on activation gradient checks

* Fixes + tests for activation gradient checking

* Javadoc

* [WIP] Some nd4j data type corrections (#20)

* Adjust data type

* Set correct Data type.

* Size of proper data type.

* fix averaged cpu load (#22)

* [WIP] Multiple dataset iterators (#27)

* Splitting dataset into arbitrary number

* Fixes

* Multiple split of iterator

* Test

* Test

* Some fixes

* signature change

* one more tweak

Signed-off-by: raver119 <raver119@gmail.com>

* one more test for sequential use of DataSetIteratorSplitter

Signed-off-by: raver119 <raver119@gmail.com>

* Fixes

* Fixes

* one more test for Alexander

Signed-off-by: raver119 <raver119@gmail.com>

* Some fixes

* Some fixes

* one more test for Alexander

Signed-off-by: raver119 <raver119@gmail.com>

* minor test fix

Signed-off-by: raver119 <raver119@gmail.com>

* Some fixes

* Some fixes

* couple of assertions tweaked

Signed-off-by: raver119 <raver119@gmail.com>

* MDS splitter test :/

Signed-off-by: raver119 <raver119@gmail.com>

* Minor refactoring

* Multi dataset

* Some fixes

* More tests

* Small number of test fixes/improvements (failures on CI) (#31)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* [WIP] More CUDA stuff (#26)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* LRN BP CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* less memory

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed bug with crop_and_resize op helper.

* get rid of unnecessary index-calculation function

Signed-off-by: Yurii <yurii@skymind.io>

* Fixed sort with nth_element cuda-based helper.

* Refactored nth_element.

* Refactored nth_element op and tests.

* Modified usage of dim array with sortTad routine.

* Refactored main routine of helper for non_max_image_suppression op.

* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.

* fix vol2col cuda kernel

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* topK concept

Signed-off-by: raver119 <raver119@gmail.com>

* unsorted topK with scanWidth of 1

Signed-off-by: raver119 <raver119@gmail.com>

* correct vol2col tests

* sorted/unsorted topK

Signed-off-by: raver119 <raver119@gmail.com>

* implementation and fixing col2im/col2vol

* Corrected usage flags with input/output with reverse op.

* dup is const now

Signed-off-by: raver119 <raver119@gmail.com>

* percentile op

Signed-off-by: raver119 <raver119@gmail.com>

* group tests for maxpool2d

Signed-off-by: Yurii <yurii@skymind.io>

* special test for george

Signed-off-by: raver119 <raver119@gmail.com>

* less threads for sortTad

Signed-off-by: raver119 <raver119@gmail.com>

* provide conv2d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* remove author in sort tad kernel code

Signed-off-by: Yurii <yurii@skymind.io>

* provide depthwise_conv2d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* - max_pooling_with_argmax
- null check for special use

Signed-off-by: raver119 <raver119@gmail.com>

* dts cuda

Signed-off-by: raver119 <raver119@gmail.com>

* provide sconv2d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* std cuda

Signed-off-by: raver119 <raver119@gmail.com>

* Refactored non_max_suppression op to conform TF implementation.

* Improved suppression helper.

* provide pooling3d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* minor lstm rearrangements

Signed-off-by: raver119 <raver119@gmail.com>

* more of minor lstm rearrangements

Signed-off-by: raver119 <raver119@gmail.com>

* (bi)dynamic_rnn

Signed-off-by: raver119 <raver119@gmail.com>

* templates init order

Signed-off-by: raver119 <raver119@gmail.com>

* Refactored non_max_suppression op.

* Added cuda kernel for non_max_suppression.

* CPU sort by key/value

Signed-off-by: raver119 <raver119@gmail.com>

* CPU sort TAD by key/value

Signed-off-by: raver119 <raver119@gmail.com>

* CPU sort TAD by key/value tests

Signed-off-by: raver119 <raver119@gmail.com>

* Eliminate compiler error with cuda implementation.

* - repaired gradCheck in cuda
- provide conv2d_bp for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* missed signature

Signed-off-by: raver119 <raver119@gmail.com>

* provide depthwise_conv2d_bp for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* Implementation of lup helper with cuda kernel. Initial commit.

* further work on backprops for convolutions

Signed-off-by: Yurii <yurii@skymind.io>

* CUDA linear sort by key/val

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA tad sort by key/val

Signed-off-by: raver119 <raver119@gmail.com>

* start providing of backprop for pooling2d/3d

Signed-off-by: Yurii <yurii@skymind.io>

* Added atomicAdd for bool datatype.

* dynamic partition concept

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic partition concept

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic partition scalar CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* important comment

Signed-off-by: raver119 <raver119@gmail.com>

* fix pooling2d/3d backprop helpers

Signed-off-by: Yurii <yurii@skymind.io>

* Added non-linear test with dynamic_partition.

* Improved test for dynamic_partition.

* dynamic_partition TAD concept

Signed-off-by: raver119 <raver119@gmail.com>

* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix

Signed-off-by: raver119 <raver119@gmail.com>

* - rewrite cpu code for upsampling2d/3d
- write cuda code for upsampling2d/3d

Signed-off-by: Yurii <yurii@skymind.io>

* dynamic_stitch CUDA vector case

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic_stitch CUDA TAD case concept

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic_stitch CUDA TAD case impl

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests for dynamic_stitch 3D-4D cases.

* minor tests tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed type check for dynamic stitch.

* min/max bp

Signed-off-by: raver119 <raver119@gmail.com>

* rewrite code for upsampling2d/3d cpu

Signed-off-by: Yurii <yurii@skymind.io>

* reduce min/max/norm_max bp

Signed-off-by: raver119 <raver119@gmail.com>

* lup implementation. Additional enhancements.

* provide code for upsampling2d/3d backprop

Signed-off-by: Yurii <yurii@skymind.io>

* weightedCrossEntropyWithLogits

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed template math atomicMul for 64bit ints.

* Refactored dynamic_partition_bp op.

* inverseBroadcast fix

Signed-off-by: raver119 <raver119@gmail.com>

* DynamicPartitionBP test datatype fixed.

* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA

Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 18:37:04 +03:00

177 lines
6.8 KiB
Plaintext

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <exceptions/cuda_exception.h>
#include <cublas_v2.h>
#include <specials_cuda.h>
#include <op_boilerplate.h>
#include <types/float16.h>
#include <ops/declarable/helpers/batched_gemm.h>
#include <PointersManager.h>
namespace nd4j {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////////
// Batched GEMM on the GPU: (bS x M x K) x (bS x K x N) = (bS x M x N)
//
// Multiplies every pair vA[i], vB[i] into vC[i] with a single cuBLAS batched call.
//
// @param vA, vB   input matrices, one per batch item; the batch size is vA.size()
// @param vC       output matrices; vB and vC must hold at least vA.size() entries
// @param alphas   scaling factor for A*B; only element 0 is read
// @param betas    scaling factor for C; only element 0 is read
// @param transA, transB  the value 112 selects CUBLAS_OP_T, anything else CUBLAS_OP_N
//                 (112 presumably mirrors CBLAS' CblasTrans - TODO confirm against callers)
// @param M, N, K  matrix dimensions passed to cuBLAS
// @param lda, ldb, ldc   leading dimensions passed to cuBLAS unchanged
//
// @throws cuda_exception      if a cuBLAS call or the stream synchronization fails
// @throws std::runtime_error  if the data type combination is not supported, or if
//                             vB / vC are shorter than vA
//
// An empty batch (vA.size() == 0) is a no-op.
void bgemm(const std::vector<NDArray*>& vA, const std::vector<NDArray*>& vB, std::vector<NDArray*>& vC, const NDArray* alphas, const NDArray* betas, int transA, int transB, int M, int N, int K, const int lda, const int ldb, const int ldc) {

    const auto bS = vA.size();      // batch size

    // nothing to multiply; this also protects the vA[0] access further below
    if (bS == 0)
        return;

    // every batch index is used for all three vectors, so short vectors would be read out of bounds
    if (vB.size() < bS || vC.size() < bS)
        throw std::runtime_error("batched gemm cuda: vB and vC must contain at least as many arrays as vA !");

    // cuBLAS takes the batch count as int; convert once and use it for all loops
    const int batchCount = static_cast<int>(bS);

    // Owner of all temporary arrays created below. The destructor guarantees that they are
    // released even when one of the throw statements in this function is taken.
    struct TemporaryArrays {
        std::vector<NDArray*> arrays;

        // takes ownership of the given array and hands the same pointer back
        NDArray* keep(NDArray* array) {
            try {
                arrays.push_back(array);
            }
            catch (...) {
                delete array;
                throw;
            }
            return array;
        }

        // deletes in reverse order of creation
        void release() {
            for (auto it = arrays.rbegin(); it != arrays.rend(); ++it)
                delete *it;
            arrays.clear();
        }

        ~TemporaryArrays() { release(); }
    } temporaries;

    std::vector<NDArray*> pA(bS), pB(bS), pC(bS);

    for (int i = 0; i < batchCount; ++i) {

        // arrays whose ews() is not 1 are replaced by an 'f'-ordered duplicate
        pA[i] = vA[i]->ews() != 1 ? temporaries.keep(vA[i]->dup('f')) : vA[i];
        pB[i] = vB[i]->ews() != 1 ? temporaries.keep(vB[i]->dup('f')) : vB[i];
        pC[i] = vC[i]->ews() != 1 ? temporaries.keep(vC[i]->dup('f')) : vC[i];

        if (pC[i]->ordering() != 'f') {

            // output is not 'f'-ordered: swap the operands and transpose all three arrays,
            // i.e. evaluate C^T = B^T * A^T instead
            auto temp = pA[i];
            pA[i] = temporaries.keep(new NDArray(pB[i]->permute({1,0})));
            pB[i] = temporaries.keep(new NDArray(temp ->permute({1,0})));
            pC[i] = temporaries.keep(new NDArray(pC[i]->permute({1,0})));

            // NOTE(review): M, K, N are overwritten per batch item here while lda/ldb/ldc
            // and transA/transB stay as passed in - confirm this is intended for mixed orderings
            M = pA[i]->sizeAt(0);
            K = pA[i]->sizeAt(1);
            N = pB[i]->sizeAt(1);
        }

        NDArray::prepareSpecialUse ({pC[i]}, {pA[i], pB[i]});
        NDArray::registerSpecialUse({pC[i]}, {pA[i], pB[i]});
    }

    NDArray::prepareSpecialUse ({}, {alphas, betas});
    NDArray::registerSpecialUse({}, {alphas, betas});

    // collect the special buffers of all batch items
    std::vector<void*> pAbuffs(bS), pBbuffs(bS), pCbuffs(bS);
    for (int i = 0; i < batchCount; ++i) {
        pAbuffs[i] = pA[i]->getSpecialBuffer();
        pBbuffs[i] = pB[i]->getSpecialBuffer();
        pCbuffs[i] = pC[i]->getSpecialBuffer();
    }

    nd4j::LaunchContext* context = vA[0]->getContext();
    PointersManager manager(context, "helpers::bgemm cuda");

    // the pointer tables are replicated through the manager and then handed to cuBLAS
    const void** aBuffers = reinterpret_cast<const void**>(manager.replicatePointer(pAbuffs.data(), bS * sizeof(void*)));
    const void** bBuffers = reinterpret_cast<const void**>(manager.replicatePointer(pBbuffs.data(), bS * sizeof(void*)));
          void** cBuffers = reinterpret_cast<void**>(manager.replicatePointer(pCbuffs.data(), bS * sizeof(void*)));

    const cublasOperation_t transAblas = transA == 112 ? CUBLAS_OP_T : CUBLAS_OP_N;
    const cublasOperation_t transBblas = transB == 112 ? CUBLAS_OP_T : CUBLAS_OP_N;

    // the data types of the first batch item select the cuBLAS routine for the whole batch
    const auto aType = pA[0]->dataType();
    const auto bType = pB[0]->dataType();
    const auto cType = pC[0]->dataType();

    auto handle = reinterpret_cast<cublasHandle_t*>(context->getCublasHandle());
    auto stream = context->getCudaStream();

    auto status = cublasSetStream_v2(*handle, *stream);
    if (status != CUBLAS_STATUS_SUCCESS)
        throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status);

    const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC);

    // choose appropriate cuda gemm api depending on data types
    if(ABC && aType == DataType::DOUBLE) {
        double alpha = alphas->e<double>(0);
        double beta  = betas->e<double>(0);
        status = cublasDgemmBatched(*handle, transAblas, transBblas, M, N, K, &alpha, (const double**)aBuffers, lda, (const double**)bBuffers, ldb, &beta, (double**)cBuffers, ldc, batchCount);
    }
    else if(ABC && aType == DataType::FLOAT32) {
        float alpha = alphas->e<float>(0);
        float beta  = betas->e<float>(0);
        status = cublasSgemmBatched(*handle, transAblas, transBblas, M, N, K, &alpha, (const float**)aBuffers, lda, (const float**)bBuffers, ldb, &beta, (float**)cBuffers, ldc, batchCount);
    }
    else if(ABC && aType == DataType::HALF) {
        __half alpha = alphas->e<float>(0);
        __half beta  = betas->e<float>(0);
        status = cublasHgemmBatched(*handle, transAblas, transBblas, M, N, K, &alpha, (const __half**)aBuffers, lda, (const __half**)bBuffers, ldb, &beta, (__half**)cBuffers, ldc, batchCount);
    }
    else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32) {
        float alpha = alphas->e<float>(0);
        float beta  = betas->e<float>(0);
        status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, &alpha, aBuffers, CUDA_R_8I, lda, bBuffers, CUDA_R_8I, ldb, &beta, cBuffers, CUDA_R_32F, ldc, batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
    }
    else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32) {
        float alpha = alphas->e<float>(0);
        float beta  = betas->e<float>(0);
        status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, &alpha, aBuffers, CUDA_R_16F, lda, bBuffers, CUDA_R_16F, ldb, &beta, cBuffers, CUDA_R_32F, ldc, batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
    }
    else
        throw std::runtime_error("batched gemm cuda: this mode is not implemented yet !");

    if (status != CUBLAS_STATUS_SUCCESS)
        throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status);

    // wait for the batched call to finish before results are copied back
    auto cudaResult = cudaStreamSynchronize(*stream);
    if (cudaResult != 0)
        throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult);

    // outputs that were duplicated above get the computed values assigned back
    for (int i = 0; i < batchCount; ++i)
        if(vC[i]->ews() != 1)
            vC[i]->assign(pC[i]);

    // success path: free the temporaries here, before the PointersManager goes out of scope
    temporaries.release();
}
}
}
}