146 lines
7.3 KiB
Plaintext
146 lines
7.3 KiB
Plaintext
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
|
|
|
//
// @author raver119@gmail.com
// @author Yurii Shyrma, created on 27.11.2018
//
|
|
|
|
#include "../aggregates.h"
|
|
|
|
namespace functions {
|
|
namespace aggregate {
|
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
/**
 * Statically-dispatched entry point: forwards the complete argument set,
 * unchanged, to OpClass::executeAggregateCuda.
 *
 * @tparam X       element type of the data and real-valued arguments
 * @tparam OpClass aggregate operation; must expose a static
 *                 executeAggregateCuda(...) accepting these ten arguments
 *
 * Each pointer parameter is paired with a count (numXxx) that is passed along
 * with it; this function itself never dereferences any of them.
 */
template <typename X>
template <typename OpClass>
__device__ void AggregatedFunction<X>::execCuda(X **arguments, int numArguments,
                                                Nd4jLong **shapeArguments, int numShapeArguments,
                                                int *indexArguments, int numIndexArguments,
                                                int **intArrays, int numIntArrays,
                                                X *realArguments, int numRealArguments) {
    OpClass::executeAggregateCuda(arguments, numArguments,
                                  shapeArguments, numShapeArguments,
                                  indexArguments, numIndexArguments,
                                  intArrays, numIntArrays,
                                  realArguments, numRealArguments);
}
|
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
/**
 * Runtime-dispatched entry point: picks the aggregate operation by number and
 * invokes the templated execCuda overload with the same arguments.
 *
 * @param opNum  numeric identifier of the aggregate operation to run
 *
 * The remaining parameters are passed through unchanged (see PARAMS below).
 *
 * NOTE(review): the selection itself is done by the project macro
 * DISPATCH_BY_OPNUM_T over the AGGREGATE_OPS list - presumably a switch on
 * opNum; behaviour for an unknown opNum is defined by that macro, not here.
 */
template <typename X>
__device__ void AggregatedFunction<X>::execCuda(int opNum,
                                                X **arguments, int numArguments,
                                                Nd4jLong **shapeArguments, int numShapeArguments,
                                                int *indexArguments, int numIndexArguments,
                                                int **intArrays, int numIntArrays,
                                                X *realArguments, int numRealArguments) {
    DISPATCH_BY_OPNUM_T(execCuda,
                        PARAMS(arguments, numArguments,
                               shapeArguments, numShapeArguments,
                               indexArguments, numIndexArguments,
                               intArrays, numIntArrays,
                               realArguments, numRealArguments),
                        AGGREGATE_OPS);
}
|
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
/**
 * Kernel entry point for a single aggregate operation.
 *
 * The data arguments arrive type-erased (void** / void*) and are
 * reinterpreted as X** / X* before being handed, together with every other
 * parameter, to AggregatedFunction<X>::execCuda(opNum, ...).
 *
 * @tparam X              element type the type-erased pointers are viewed as
 * @param opNum           numeric identifier of the aggregate operation
 * @param varguments      type-erased array of data pointers (viewed as X**)
 * @param vrealArguments  type-erased real-valued arguments (viewed as X*)
 *
 * Launch geometry is chosen entirely by the caller; this wrapper does no
 * indexing of its own.
 */
template <typename X>
__global__ static void execAggregateKernel(int opNum,
                                           void **varguments, int numArguments,
                                           Nd4jLong **shapeArguments, int numShapeArguments,
                                           int *indexArguments, int numIndexArguments,
                                           int **intArrays, int numIntArrays,
                                           void *vrealArguments, int numRealArguments) {
    // Recover the typed views of the type-erased inputs.
    X **typedArguments = reinterpret_cast<X **>(varguments);
    X *typedReals = reinterpret_cast<X *>(vrealArguments);

    functions::aggregate::AggregatedFunction<X>::execCuda(opNum,
                                                          typedArguments, numArguments,
                                                          shapeArguments, numShapeArguments,
                                                          indexArguments, numIndexArguments,
                                                          intArrays, numIntArrays,
                                                          typedReals, numRealArguments);
}
|
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
/**
 * Host-side launcher for execAggregateKernel<X>.
 *
 * @param launchDims  launch configuration packed into a dim3:
 *                    x = grid size, y = block size, z = dynamic shared-memory bytes
 * @param stream      pointer to the CUDA stream the kernel is enqueued on
 *                    (dereferenced here, so it must be non-null)
 * @param opNum       numeric identifier of the aggregate operation
 *
 * All remaining parameters are forwarded to the kernel untouched. After the
 * launch, nd4j::DebugHelper::checkErrorCode is invoked with the stream and a
 * failure message (its exact checking behaviour is defined by DebugHelper).
 */
template <typename X>
__host__ void AggregatedFunction<X>::aggregateKernelGeneric(dim3& launchDims, cudaStream_t *stream,
                                                            int opNum,
                                                            void **arguments, int numArguments,
                                                            Nd4jLong **shapeArguments, int numShapeArguments,
                                                            int *indexArguments, int numIndexArguments,
                                                            int **intArrays, int numIntArrays,
                                                            void *realArguments, int numRealArguments) {
    // Unpack the launch configuration carried in the dim3.
    const auto gridSize = launchDims.x;
    const auto blockSize = launchDims.y;
    const auto sharedMemBytes = launchDims.z;

    execAggregateKernel<X><<<gridSize, blockSize, sharedMemBytes, *stream>>>(
        opNum,
        arguments, numArguments,
        shapeArguments, numShapeArguments,
        indexArguments, numIndexArguments,
        intArrays, numIntArrays,
        realArguments, numRealArguments);

    nd4j::DebugHelper::checkErrorCode(stream, "aggregateKernelGeneric(...) failed");
}
|
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
/**
 * Device-side driver that executes a batch of aggregate operations.
 *
 * Aggregates are distributed over thread blocks: block b handles aggregates
 * b, b + gridDim.x, b + 2 * gridDim.x, ... For each aggregate the block
 * publishes the per-aggregate argument pointers into shared memory and then
 * all threads of the block call execCuda(opNum, ...) with them.
 *
 * @param opNum            numeric identifier of the aggregate operation
 * @param numAggregates    number of aggregates packed behind ptrToArguments
 * @param maxArgs          \
 * @param maxShapes         |
 * @param maxIntArrays      | layout limits handed to nd4j::PointersHelper,
 * @param maxIntArraySize   | which decodes the packed buffer
 * @param maxIdx            |
 * @param maxReals         /
 * @param ptrToArguments   packed argument buffer for the whole batch
 *
 * Restriction: only the first 32 int-array pointers of an aggregate are
 * published (the shared table has 32 slots), so maxIntArrays values above 32
 * are not fully supported.
 *
 * Must be called by every thread of the block: the body contains
 * block-wide barriers.
 */
template <typename X>
__device__ void AggregatedFunction<X>::aggregateBatch(int opNum, int numAggregates,
                                                      int maxArgs, int maxShapes,
                                                      int maxIntArrays, int maxIntArraySize,
                                                      int maxIdx, int maxReals,
                                                      void *ptrToArguments) {

    nd4j::PointersHelper<X> helper(ptrToArguments, numAggregates, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals);

    // TODO: we probably should lift this restriction
    __shared__ int *intArrays[32];

    __shared__ X **arguments;
    __shared__ Nd4jLong **shapes;
    __shared__ int *idxArg;
    __shared__ X *realArg;

    // The trip count depends only on blockIdx.x / gridDim.x, so it is uniform
    // across the block and the barriers inside the loop are reached by all threads.
    for (int r = blockIdx.x; r < numAggregates; r += gridDim.x) {
        // a single thread publishes the scalar pointers for aggregate r
        if (threadIdx.x == 0) {
            arguments = helper.getArguments(r);
            shapes = helper.getShapeArguments(r);
            idxArg = helper.getIndexArguments(r);
            realArg = helper.getRealArguments(r);
        }

        // we fill intArrays param in parallel within block
        if (threadIdx.x < 32 && threadIdx.x < maxIntArrays) {
            intArrays[threadIdx.x] = helper.getIntArrayArguments(r, threadIdx.x);
        }

        // make the published pointers visible to the whole block
        __syncthreads();

        functions::aggregate::AggregatedFunction<X>::execCuda(opNum, arguments, helper.getNumArguments(r), shapes, helper.getNumShapeArguments(r), idxArg, helper.getNumIndexArguments(r), intArrays, helper.getNumIntArrayArguments(r), realArg, helper.getNumRealArguments(r));

        // Every thread must be done with the shared pointers of aggregate r
        // before any thread overwrites them for the next aggregate; without
        // this barrier a fast thread could republish while slower threads
        // are still reading (write-after-read race on shared memory).
        __syncthreads();
    }
}
|
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
/**
 * Kernel entry point for batched aggregate execution.
 *
 * A pure trampoline: every parameter is handed, in order and unmodified, to
 * AggregatedFunction<X>::aggregateBatch, which performs the actual work.
 *
 * @tparam X               element type of the aggregates
 * @param opNum            numeric identifier of the aggregate operation
 * @param numAggregates    number of aggregates in the batch
 * @param ptrToArguments   packed argument buffer for the whole batch
 *
 * The max* parameters are layout limits that are forwarded as-is.
 */
template <typename X>
__global__ static void execAggregateBatch(int opNum, int numAggregates,
                                          int maxArgs, int maxShapes,
                                          int maxIntArrays, int maxIntArraySize,
                                          int maxIdx, int maxReals,
                                          void *ptrToArguments) {
    functions::aggregate::AggregatedFunction<X>::aggregateBatch(opNum, numAggregates,
                                                                maxArgs, maxShapes,
                                                                maxIntArrays, maxIntArraySize,
                                                                maxIdx, maxReals,
                                                                ptrToArguments);
}
|
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
/**
 * Host-side launcher for execAggregateBatch<X>.
 *
 * @param launchDims  launch configuration packed into a dim3:
 *                    x = grid size, y = block size, z = dynamic shared-memory bytes
 * @param stream      pointer to the CUDA stream the kernel is enqueued on
 *                    (dereferenced here, so it must be non-null)
 * @param opNum            numeric identifier of the aggregate operation
 * @param numAggregates    number of aggregates in the batch
 * @param ptrToArguments   packed argument buffer for the whole batch
 *
 * The max* parameters are forwarded to the kernel untouched. After the
 * launch, nd4j::DebugHelper::checkErrorCode is invoked with the stream and a
 * failure message (its exact checking behaviour is defined by DebugHelper).
 */
template <typename X>
__host__ void AggregatedFunction<X>::aggregateBatchKernelGeneric(dim3& launchDims, cudaStream_t *stream,
                                                                 int opNum, int numAggregates,
                                                                 int maxArgs, int maxShapes,
                                                                 int maxIntArrays, int maxIntArraySize,
                                                                 int maxIdx, int maxReals,
                                                                 void *ptrToArguments) {
    // Unpack the launch configuration carried in the dim3.
    const auto gridSize = launchDims.x;
    const auto blockSize = launchDims.y;
    const auto sharedMemBytes = launchDims.z;

    execAggregateBatch<X><<<gridSize, blockSize, sharedMemBytes, *stream>>>(
        opNum, numAggregates,
        maxArgs, maxShapes,
        maxIntArrays, maxIntArraySize,
        maxIdx, maxReals,
        ptrToArguments);

    nd4j::DebugHelper::checkErrorCode(stream, "aggregateBatchKernel(...) failed");
}
|
|
|
|
|
|
|
|
|
|
|
|
// Explicit instantiation of AggregatedFunction via the project macro.
// NOTE(review): presumably emits one `template class AggregatedFunction<T>`
// per type listed in FLOAT_TYPES - confirm against the macro definition.
BUILD_SINGLE_TEMPLATE(template class AggregatedFunction, , FLOAT_TYPES);
|
|
}
|
|
}
|