/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author Yurii Shyrma (iuriish@yahoo.com) // // #include #include #include #include #include #include #include // #include // #include using namespace simdOps; namespace nd4j { namespace helpers { //////////////////////////////////////////////////////////////////////// template __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { const auto x = reinterpret_cast(vx); const auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); __shared__ int xRank, yRank, zRank; __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen if (threadIdx.x == 0) { xRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); zRank = shape::rank(zShapeInfo); zLen = shape::length(zShapeInfo); totalThreads = gridDim.x * blockDim.x; } __syncthreads(); Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; for (Nd4jLong i = tid; i < zLen; i += totalThreads) { shape::index2coords(i, zShapeInfo, zCoords); for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { if(ix >= 0) { if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) xCoords[ix--] = zCoords[iz]; else xCoords[ix--] = 0; } if(iy >= 0) { if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) yCoords[iy--] = zCoords[iz]; else yCoords[iy--] = 0; } } const auto xOffset = shape::getOffset(xShapeInfo, xCoords); const auto zOffset = shape::getOffset(zShapeInfo, zCoords); const auto yOffset = shape::getOffset(yShapeInfo, yCoords); z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr); } } //////////////////////////////////////////////////////////////////////// template template void TrueBroadcastHelper::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { trueBroadcastCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } ////////////////////////////////////////////////////////////////////////// template void TrueBroadcastHelper::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { dim3 launchDims; launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.z = 1024; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastHelper::exec"); NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr}); DISPATCH_BY_OPNUM_TTT(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_OPS)); NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr}); manager.synchronize(); } //////////////////////////////////////////////////////////////////////// template __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { const auto x = reinterpret_cast(vx); const auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); __shared__ int xRank, yRank, zRank; __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen if (threadIdx.x == 0) { xRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); zRank = shape::rank(zShapeInfo); zLen = shape::length(zShapeInfo); totalThreads = gridDim.x * blockDim.x; } __syncthreads(); Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; for (Nd4jLong i = tid; i < zLen; i += totalThreads) { shape::index2coords(i, zShapeInfo, zCoords); for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { if(ix >= 0) if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) xCoords[ix--] = zCoords[iz]; else xCoords[ix--] = 0; if(iy >= 0) if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) yCoords[iy--] = zCoords[iz]; else yCoords[iy--] = 0; } const auto xOffset = shape::getOffset(xShapeInfo, xCoords); const auto zOffset = shape::getOffset(zShapeInfo, zCoords); const auto yOffset = shape::getOffset(yShapeInfo, yCoords); z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr); } } //////////////////////////////////////////////////////////////////////// template template void TrueBroadcastBoolHelper::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { trueBroadcastBoolCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } ////////////////////////////////////////////////////////////////////////// template void TrueBroadcastBoolHelper::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { dim3 launchDims; launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.z = 1024; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper::exec"); NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr}); DISPATCH_BY_OPNUM_TT(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_BOOL_OPS)); NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr}); manager.synchronize(); } //////////////////////////////////////////////////////////////////////// template __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { const auto x = reinterpret_cast(vx); const auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); __shared__ int xRank, yRank, zRank; __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen if (threadIdx.x == 0) { xRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); zRank = shape::rank(zShapeInfo); zLen = shape::length(zShapeInfo); totalThreads = gridDim.x * blockDim.x; } __syncthreads(); Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; for (Nd4jLong i = tid; i < zLen; i += totalThreads) { shape::index2coords(i, zShapeInfo, zCoords); for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { if(ix >= 0) if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) xCoords[ix--] = zCoords[iz]; else xCoords[ix--] = 0; if(iy >= 0) if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) yCoords[iy--] = zCoords[iz]; else yCoords[iy--] = 0; } const auto xOffset = shape::getOffset(xShapeInfo, xCoords); const auto zOffset = shape::getOffset(zShapeInfo, zCoords); const auto yOffset = shape::getOffset(yShapeInfo, yCoords); z[zOffset] = OpType::op(x[xOffset], y[yOffset]); } } //////////////////////////////////////////////////////////////////////// template template void TrueBroadcastIntHelper::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { trueBroadcastIntCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } ////////////////////////////////////////////////////////////////////////// template void TrueBroadcastIntHelper::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { dim3 launchDims; launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.z = 1024; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper::exec"); NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr}); DISPATCH_BY_OPNUM_T(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_INT_OPS)); NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr}); manager.synchronize(); } BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES); BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES); } }