/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

// NOTE(review): in the text under review every angle-bracketed span is absent:
// the six #include directives carry no header name, the `template` headers
// carry no parameter list, the static_cast expressions carry no target type and
// the kernel launch carries no <<<...>>> configuration. Those spots are left
// exactly as found rather than guessed at; restore them from version control
// before building.
#include
#include
#include
#include
#include
#include

using namespace simdOps;

/**
 * Kernel entry point for the strict transform.
 *
 * Forwards its arguments to TransformStrict::transformCuda. xRank and zRank
 * are accepted but are not passed on.
 */
template
__global__ void transformStrictSimple(const void *x, const Nd4jLong *xShapeInfo, int xRank,
                                      void *params,
                                      void *z, const Nd4jLong *zShapeInfo, int zRank,
                                      int *allocationPointer,
                                      void *reductionPointer,
                                      const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
    functions::transform::TransformStrict::template transformCuda(x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
}

namespace functions {
    namespace transform {

        /**
         * Host-side entry point.
         *
         * Hands opNum and the argument pack to DISPATCH_BY_OPNUM_T with
         * intermediateShaped as the target and TRANSFORM_STRICT_OPS as the op
         * list (presumably selecting the op implementation by opNum - the
         * macro is defined outside this file), then invokes DEBUG_KERNEL.
         */
        template
        _CUDA_H void TransformStrict::executeTransformShaped(dim3 launchDims, cudaStream_t *stream,
                                                             const int opNum,
                                                             const void *x, const Nd4jLong *xShape, int xRank,
                                                             void *extraParams,
                                                             void *z, const Nd4jLong *zShape, int zRank,
                                                             int *allocationPointer, void *reductionPointer,
                                                             const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
            DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_STRICT_OPS);

            DEBUG_KERNEL(stream, opNum);
        }

        /**
         * Device-side body of the strict transform.
         *
         * If OpType::requiresSpecial is set, the whole job is delegated to
         * OpType::execSpecialCuda. Otherwise every thread walks the index
         * range [0, length) starting at its global thread id and advancing by
         * the total number of launched threads, writing
         * OpType::op(x[...], params) into z.
         *
         * Three addressing modes are used:
         *  - both buffers report a positive element-wise stride and 'c'
         *    order: offsets are i * stride;
         *  - x and z are the same buffer: one shape::getIndexOffset lookup is
         *    shared by the read and the write;
         *  - otherwise: separate shape::getIndexOffset lookups for x and z.
         */
        template
        template
        __device__ void TransformStrict::transformCuda(const void *vx, const Nd4jLong *xShapeInfo,
                                                       void *vparams,
                                                       void *vz, const Nd4jLong *zShapeInfo,
                                                       int *allocationPointer, void *vreductionPointer,
                                                       const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
            auto x = static_cast(vx);
            auto z = static_cast(vz);
            auto params = static_cast(vparams);
            auto reductionPointer = static_cast(vreductionPointer);

            if (OpType::requiresSpecial) {
                OpType::execSpecialCuda(x, xShapeInfo, z, zShapeInfo, params, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
                return;
            } else {
                __shared__ Nd4jLong xEws;
                __shared__ Nd4jLong zEws;
                __shared__ char xOrder;
                __shared__ char zOrder;
                __shared__ Nd4jLong length;

                // Thread 0 of each block reads the shape metadata once; the
                // barrier below publishes it to the rest of the block.
                if (threadIdx.x == 0) {
                    xEws = shape::elementWiseStride(xShapeInfo);
                    zEws = shape::elementWiseStride(zShapeInfo);
                    xOrder = shape::order(xShapeInfo);
                    zOrder = shape::order(zShapeInfo);
                    length = shape::length(xShapeInfo);
                }
                __syncthreads();

                // All index arithmetic is done in Nd4jLong so that it matches
                // the type of `length`. The previous code kept the fast-path
                // index and the stride in `int`, which overflows once the
                // buffer holds more than INT_MAX elements.
                const Nd4jLong tid = Nd4jLong(blockIdx.x) * blockDim.x + threadIdx.x;
                const Nd4jLong totalThreads = Nd4jLong(gridDim.x) * blockDim.x;

                if (xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
                    for (Nd4jLong i = tid; i < length; i += totalThreads)
                        z[i * zEws] = OpType::op(x[i * xEws], params);
                } else {
                    if (vx == vz) {
                        // Same buffer for input and output: one offset serves both.
                        for (Nd4jLong i = tid; i < length; i += totalThreads) {
                            auto xOffset = shape::getIndexOffset(i, xShapeInfo);
                            z[xOffset] = OpType::op(x[xOffset], params);
                        }
                    } else {
                        for (Nd4jLong i = tid; i < length; i += totalThreads) {
                            auto xOffset = shape::getIndexOffset(i, xShapeInfo);
                            auto zOffset = shape::getIndexOffset(i, zShapeInfo);
                            z[zOffset] = OpType::op(x[xOffset], params);
                        }
                    }
                }
            }
        }

        /**
         * Launches transformStrictSimple for one concrete op type and then
         * passes the stream to sd::DebugHelper::checkErrorCode together with
         * the message "transformStrict(...) failed".
         */
        template
        template
        _CUDA_H void TransformStrict::intermediateShaped(dim3 launchDims, cudaStream_t *stream,
                                                         const void *x, const Nd4jLong *xShape, int xRank,
                                                         void *extraParams,
                                                         void *z, const Nd4jLong *zShape, int zRank,
                                                         int *allocationPointer, void *reductionPointer,
                                                         const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
            // NOTE(review): the launch configuration between <<< and >>> is
            // missing from the reviewed text; presumably it is built from
            // launchDims and *stream - confirm against the repository.
            transformStrictSimple<<>>(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);

            sd::DebugHelper::checkErrorCode(stream, "transformStrict(...) failed");
        }

        BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES);
    }
}