/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <op_boilerplate.h>
#include <loops/transform_any.h>
#include <types/types.h>
#include <loops/legacy_ops.h>
#include <helpers/DebugHelper.h>
#include <specials_cuda.h>

using namespace simdOps;

template <typename X, typename Z, typename OpType>
__global__ void transformAnySimple(void *x, Nd4jLong *xShapeInfo, int xRank,
                                   void *params,
                                   void *z, Nd4jLong *zShapeInfo, int zRank,
                                   int *allocationPointer, void *reductionPointer,
                                   Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    functions::transform::TransformAny<X, Z>::template transformCuda<OpType>(x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
}

namespace functions {
    namespace transform {

        template <typename X, typename Z>
        _CUDA_H void TransformAny<X, Z>::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShape, int xRank, void *extraParams, void *z, Nd4jLong *zShape, int zRank, int *allocationPointer, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
            DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_ANY_OPS);

            DEBUG_KERNEL(stream, opNum);
        }

        template <typename X, typename Z>
        template <typename OpType>
        __device__ void TransformAny<X, Z>::transformCuda(void *vx, Nd4jLong *xShapeInfo,
                                                          void *vparams,
                                                          void *vz, Nd4jLong *zShapeInfo,
                                                          int *allocationPointer, void *vreductionPointer,
                                                          Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            auto x = reinterpret_cast<X *>(vx);
            auto z = reinterpret_cast<Z *>(vz);
            auto params = reinterpret_cast<X *>(vparams);
            auto reductionPointer = reinterpret_cast<Z *>(vreductionPointer);

            // shape metadata is identical for every thread, so thread 0 caches it in shared memory
            __shared__ Nd4jLong xEws;
            __shared__ Nd4jLong zEws;
            __shared__ char xOrder;
            __shared__ char zOrder;
            __shared__ Nd4jLong length;

            if (threadIdx.x == 0) {
                xEws = shape::elementWiseStride(xShapeInfo);
                zEws = shape::elementWiseStride(zShapeInfo);
                xOrder = shape::order(xShapeInfo);
                zOrder = shape::order(zShapeInfo);
                length = shape::length(xShapeInfo);
            }
            __syncthreads();

            auto tid = blockIdx.x * blockDim.x + threadIdx.x;
            int totalThreads = gridDim.x * blockDim.x;

            // fast path: both buffers have a positive element-wise stride and the same ordering,
            // so linear indices can be mapped to offsets with a single multiply
            if (xEws > 0 && zEws > 0 && xOrder == zOrder) {
                for (Nd4jLong i = tid; i < length; i += totalThreads)
                    z[i * zEws] = OpType::op(x[i * xEws], params);
            }
            else {
                // in-place transform: resolve each offset once and reuse it for read and write
                if (vx == vz) {
                    for (Nd4jLong i = tid; i < length; i += totalThreads) {
                        auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
                        z[xOffset] = OpType::op(x[xOffset], params);
                    }
                }
                else {
                    for (Nd4jLong i = tid; i < length; i += totalThreads) {
                        auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
                        auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
                        z[zOffset] = OpType::op(x[xOffset], params);
                    }
                }
            }
        }
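
        // Illustrative sketch (not part of the original source): the fast path in
        // transformCuda above is the standard CUDA grid-stride loop. Stripped of the
        // shape machinery, the same pattern for a hypothetical unary op on float
        // buffers looks like this; the kernel name and the doubling op are
        // assumptions for illustration only.
#if 0
        __global__ void gridStrideSketch(const float *x, float *z, Nd4jLong length) {
            auto tid  = blockIdx.x * blockDim.x + threadIdx.x;           // global thread index
            auto step = static_cast<Nd4jLong>(gridDim.x) * blockDim.x;   // grid-wide stride
            for (Nd4jLong i = tid; i < length; i += step)
                z[i] = 2.0f * x[i];                                      // stand-in for OpType::op
        }
#endif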
        template <typename X, typename Z>
        template <typename OpType>
        _CUDA_H void TransformAny<X, Z>::intermediateShaped(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShape, int xRank, void *extraParams, void *z, Nd4jLong *zShape, int zRank, int *allocationPointer, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            transformAnySimple<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
            nd4j::DebugHelper::checkErrorCode(stream, "transformAny(...) failed");
        }

        BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES);
    }
}
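
// Host-side usage sketch (not part of the original source): executeTransformShaped
// dispatches on opNum over TRANSFORM_ANY_OPS and launches transformAnySimple with
// launchDims interpreted as {grid, block, shared memory}, as intermediateShaped
// shows above. The opNum value, ranks, and device buffers below are assumptions
// for illustration only.
#if 0
void exampleLaunch(cudaStream_t *stream,
                   void *dX, Nd4jLong *dXShapeInfo,
                   void *dZ, Nd4jLong *dZShapeInfo) {
    dim3 launchDims(512, 512, 2048);   // grid size, block size, shared-memory bytes
    int opNum = 0;                     // hypothetical index into TRANSFORM_ANY_OPS
    functions::transform::TransformAny<float, float>::executeTransformShaped(
            launchDims, stream, opNum,
            dX, dXShapeInfo, 2,        // x buffer, shape info, rank
            nullptr,                   // extraParams
            dZ, dZShapeInfo, 2,        // z buffer, shape info, rank
            nullptr, nullptr,          // allocationPointer, reductionPointer
            nullptr, nullptr);         // tadShapeInfo, tadOffsets
}
#endif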