/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
// @author Yurii Shyrma, created on 27.11.2018
//

#include "../aggregates.h"

namespace functions {
namespace aggregate {

///////////////////////////////////////////////////////////////////////
// Forwards all argument buffers to the concrete op's CUDA implementation.
template <typename X>
template <typename OpClass>
__device__ void AggregatedFunction<X>::execCuda(X **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, X *realArguments, int numRealArguments) {
    OpClass::executeAggregateCuda(arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
}

///////////////////////////////////////////////////////////////////////
// Dispatches by op number to the matching entry in AGGREGATE_OPS.
template <typename X>
__device__ void AggregatedFunction<X>::execCuda(int opNum, X **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, X *realArguments, int numRealArguments) {
    DISPATCH_BY_OPNUM_T(execCuda, PARAMS(arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments), AGGREGATE_OPS);
}

///////////////////////////////////////////////////////////////////////
// Kernel entry point for a single aggregate: casts the type-erased buffers
// back to X and delegates to the device-side dispatcher above.
template <typename X>
__global__ static void execAggregateKernel(int opNum, void **varguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, void *vrealArguments, int numRealArguments) {
    auto arguments = reinterpret_cast<X **>(varguments);
    auto realArguments = reinterpret_cast<X *>(vrealArguments);

    functions::aggregate::AggregatedFunction<X>::execCuda(opNum, arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
}

///////////////////////////////////////////////////////////////////////
// Host-side launcher: launchDims.x = grid size, launchDims.y = block size,
// launchDims.z = shared memory in bytes.
template <typename X>
__host__ void AggregatedFunction<X>::aggregateKernelGeneric(dim3& launchDims, cudaStream_t *stream, int opNum, void **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, void *realArguments, int numRealArguments) {
    execAggregateKernel<X><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(opNum, arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
    nd4j::DebugHelper::checkErrorCode(stream, "aggregateKernelGeneric(...) failed");
}
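///////////////////////////////////////////////////////////////////////
// Usage sketch (illustrative, not part of the original file): a host-side
// call into the single-aggregate launcher above, assuming the caller has
// already placed all argument buffers in device-accessible memory. The
// launch dimensions and argument variables here are hypothetical.
//
//     dim3 launchDims(64, 256, 2048);   // grid, block, shared memory bytes
//     cudaStream_t stream;
//     cudaStreamCreate(&stream);
//     functions::aggregate::AggregatedFunction<float>::aggregateKernelGeneric(
//             launchDims, &stream, opNum,
//             arguments, numArguments,
//             shapeArguments, numShapeArguments,
//             indexArguments, numIndexArguments,
//             intArrays, numIntArrays,
//             realArguments, numRealArguments);
//     cudaStreamSynchronize(stream);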
///////////////////////////////////////////////////////////////////////
// Device-side batch driver: unpacks the flat batch buffer via PointersHelper
// and processes one aggregate per block, striding over the grid.
template <typename X>
__device__ void AggregatedFunction<X>::aggregateBatch(int opNum, int numAggregates, int maxArgs, int maxShapes, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals, void *ptrToArguments) {
    nd4j::PointersHelper<X> helper(ptrToArguments, numAggregates, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals);

    // TODO: we probably should lift this restriction
    __shared__ int *intArrays[32];

    __shared__ X **arguments;
    __shared__ Nd4jLong **shapes;
    __shared__ int *idxArg;
    __shared__ X *realArg;

    for (int r = blockIdx.x; r < numAggregates; r += gridDim.x) {
        if (threadIdx.x == 0) {
            arguments = helper.getArguments(r);
            shapes = helper.getShapeArguments(r);
            idxArg = helper.getIndexArguments(r);
            realArg = helper.getRealArguments(r);
        }

        // we fill the intArrays param in parallel within the block
        if (threadIdx.x < 32 && threadIdx.x < maxIntArrays) {
            intArrays[threadIdx.x] = helper.getIntArrayArguments(r, threadIdx.x);
        }
        __syncthreads();

        functions::aggregate::AggregatedFunction<X>::execCuda(opNum, arguments, helper.getNumArguments(r), shapes, helper.getNumShapeArguments(r), idxArg, helper.getNumIndexArguments(r), intArrays, helper.getNumIntArrayArguments(r), realArg, helper.getNumRealArguments(r));
    }
}

///////////////////////////////////////////////////////////////////////
// Kernel entry point for the batched path.
template <typename X>
__global__ static void execAggregateBatch(int opNum, int numAggregates, int maxArgs, int maxShapes, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals, void *ptrToArguments) {
    functions::aggregate::AggregatedFunction<X>::aggregateBatch(opNum, numAggregates, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments);
}

///////////////////////////////////////////////////////////////////////
// Host-side launcher for the batched path.
template <typename X>
__host__ void AggregatedFunction<X>::aggregateBatchKernelGeneric(dim3& launchDims, cudaStream_t *stream, int opNum, int numAggregates, int maxArgs, int maxShapes, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals, void *ptrToArguments) {
    execAggregateBatch<X><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(opNum, numAggregates, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments);
    nd4j::DebugHelper::checkErrorCode(stream, "aggregateBatchKernel(...) failed");
}

BUILD_SINGLE_TEMPLATE(template class AggregatedFunction, , FLOAT_TYPES);

} // namespace aggregate
} // namespace functions
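///////////////////////////////////////////////////////////////////////
// Usage sketch (illustrative, not part of the original file): launching the
// batched path. ptrToArguments is expected to be the flat buffer that
// nd4j::PointersHelper unpacks on the device; the max* values bound the
// per-aggregate argument counts used to compute offsets into that buffer.
// All values below are hypothetical.
//
//     dim3 launchDims(numAggregates, 256, 2048);  // one block per aggregate
//     functions::aggregate::AggregatedFunction<float>::aggregateBatchKernelGeneric(
//             launchDims, &stream, opNum, numAggregates,
//             maxArgs, maxShapes, maxIntArrays, maxIntArraySize,
//             maxIdx, maxReals, ptrToArguments);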