/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author George A. Shulinok <sgazeos@gmail.com>
//
// NOTE(review): include list reconstructed; confirm header paths against the build.
#include <ops/declarable/helpers/matrix_band.h>
#include <TAD.h>
#include <cuda.h>
#include <ConstantTadHelper.h>
#include <ShapeUtils.h>
#include <vector>

namespace nd4j {
namespace ops {
namespace helpers {

    /**
     * Writes the band part of every matrix (TAD) of the input into the output.
     *
     * Element (i, j) of a matrix is kept when it lies inside the band, i.e. when
     *     (lowerBand < 0 || i - j <= lowerBand) && (upperBand < 0 || j - i <= upperBand),
     * and is set to zero otherwise. A negative band keeps that whole triangle, and the
     * main diagonal is always kept. Every output element is written (either copied or
     * zeroed), so the output does not need to be pre-filled with a copy of the input.
     *
     * Work distribution: matrices are strided over blockIdx.x / gridDim.x, rows over
     * blockIdx.y / gridDim.y and columns over threadIdx.x / blockDim.x.
     *
     * @param inputBuffer            device buffer of the input array
     * @param inputShape             device shape info of the input; its last two dims give rows x cols
     * @param outputBuffer           device buffer of the output array
     * @param outputShape            unused; kept for signature compatibility
     * @param lowerBand              number of sub-diagonals to keep (negative: keep all of them)
     * @param upperBand              number of super-diagonals to keep (negative: keep all of them)
     * @param tadOnlyInputShapeInfo  shape info of a single input matrix
     * @param tadInputOffsets        per-matrix offsets into inputBuffer
     * @param tadOnlyOutputShapeInfo shape info of a single output matrix
     * @param tadOutputOffsets       per-matrix offsets into outputBuffer
     * @param numTads                number of matrices to process
     * @param inputLength            unused; kept for signature compatibility
     */
    template <typename T>
    static __global__ void matrixBandKernel(void* inputBuffer, Nd4jLong* inputShape,
                                            void* outputBuffer, Nd4jLong* outputShape,
                                            Nd4jLong lowerBand, Nd4jLong upperBand,
                                            Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong* tadInputOffsets,
                                            Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong* tadOutputOffsets,
                                            Nd4jLong numTads, Nd4jLong inputLength) {
        const T* x = static_cast<const T*>(inputBuffer);
        T* z = static_cast<T*>(outputBuffer);

        const int totalThreads = blockDim.x;
        const Nd4jLong rows = shape::sizeAt(inputShape, -2);
        const Nd4jLong cols = shape::sizeAt(inputShape, -1);

        for (Nd4jLong e = blockIdx.x; e < numTads; e += gridDim.x) {
            const auto inOffset = tadInputOffsets[e];
            const auto outOffset = tadOutputOffsets[e];

            for (Nd4jLong i = blockIdx.y; i < rows; i += gridDim.y) {
                for (Nd4jLong j = threadIdx.x; j < cols; j += totalThreads) {
                    Nd4jLong coords[2] = {i, j};
                    const Nd4jLong tadOffsetOut = shape::getOffset(tadOnlyOutputShapeInfo, coords);
                    const Nd4jLong tadOffsetIn = shape::getOffset(tadOnlyInputShapeInfo, coords);

                    // Single predicate for both triangles: the previous version skipped the
                    // write entirely when a band was 0 (or negative), leaving the diagonal
                    // and out-of-band elements with whatever the output already held.
                    const bool inBand = (lowerBand < 0 || (i - j) <= lowerBand) &&
                                        (upperBand < 0 || (j - i) <= upperBand);

                    z[outOffset + tadOffsetOut] = inBand ? x[inOffset + tadOffsetIn] : T(0);
                }
            }
        }
    }

    /**
     * Type-specific implementation of matrixBandPart: launches matrixBandKernel over all
     * matrices formed by the two innermost dimensions of input.
     */
    template <typename T>
    void matrixBandPart_(nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand) {
        // Nothing to compute for an empty array; skip building TAD packs and the launch.
        if (input->lengthOf() == 0)
            return;

        // Used below as <<<blocks, threads, dynamic shared memory bytes>>>.
        // The kernel itself declares no shared memory.
        dim3 launchDims(256, 512, 8192);
        auto stream = context->getCudaStream();

        // Each matrix is a TAD over the last two dimensions.
        std::vector<int> lastDims({input->rankOf() - 2, input->rankOf() - 1});
        auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), lastDims);
        auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), lastDims);
        const Nd4jLong numTads = packX.numberOfTads();

        // The kernel reads the input from its device buffer, so make sure it is current.
        // The output needs no upload: the kernel overwrites every element of it.
        if (!input->isActualOnDeviceSide())
            input->syncToDevice();

        matrixBandKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(
                input->getSpecialBuffer(), input->getSpecialShapeInfo(),
                output->getSpecialBuffer(), output->getSpecialShapeInfo(),
                lowerBand, upperBand,
                packX.specialShapeInfo(), packX.specialOffsets(),
                packZ.specialShapeInfo(), packZ.specialOffsets(),
                numTads, input->lengthOf());
        // NOTE(review): output is not explicitly marked as modified on the device here;
        // confirm the caller handles host/device synchronization before host-side reads.
    }

    /**
     * Keeps the band of each innermost matrix of input (up to lowerBand sub-diagonals and
     * upperBand super-diagonals; a negative value keeps the whole triangle), zeroes the rest,
     * and stores the result in output. Dispatches on input's data type (FLOAT_TYPES).
     */
    void matrixBandPart(nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand) {
        BUILD_SINGLE_SELECTOR(input->dataType(), matrixBandPart_, (context, input, output, lowerBand, upperBand), FLOAT_TYPES);
    }

    BUILD_SINGLE_TEMPLATE(template void matrixBandPart_, (nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand), FLOAT_TYPES);

}
}
}