/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// Created by Yurii Shyrma on 02.01.2018
//

#include <ops/declarable/helpers/stack.h>
#include <helpers/ShapeUtils.h>
#include <array/ResultSet.h>
#include <cuda_exception.h>
#include <TAD.h>
#include <PointersManager.h>
#include <ConstantTadHelper.h>

namespace nd4j {
namespace ops {
namespace helpers {

    template <typename T>
    static __global__ void stackKernel(void** inputList, void** inputShapeList, int inputListLength, Nd4jLong arrLen, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* tadShape, Nd4jLong* tadOffsets) {

        T* z = reinterpret_cast<T*>(vz);

        if(tadShape == nullptr) {    // scalar case: each input holds a single element, copied with a grid-stride loop

            for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < inputListLength; i += gridDim.x * blockDim.x)
                z[shape::getIndexOffset(i, zShapeInfo, inputListLength)] = reinterpret_cast<T*>(inputList[i])[0];
        }
        else {    // general case: one block per input array, its threads copy elements into the matching output TAD

            for (int t = blockIdx.x; t < inputListLength; t += gridDim.x) {

                auto tZ = z + tadOffsets[t];
                auto tX = reinterpret_cast<T*>(inputList[t]);
                auto xShapeInfo = reinterpret_cast<Nd4jLong*>(inputShapeList[t]);

                for (int e = threadIdx.x; e < arrLen; e += blockDim.x)
                    tZ[shape::getIndexOffset(e, tadShape, arrLen)] = tX[shape::getIndexOffset(e, xShapeInfo, arrLen)];
            }
        }
    }

    ///////////////////////////////////////////////////////////////////
    template <typename T>
    static void stack_(nd4j::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {

        const bool scalarCase = inArrs[0]->isScalar();

        const int threadsPerBlock = MAX_NUM_THREADS / 2;
        const int blocksPerGrid = scalarCase ? (outArr->lengthOf() + threadsPerBlock - 1) / threadsPerBlock : inArrs.size();

        NDArray::prepareSpecialUse({outArr}, {});
        // FIXME: !!!
        for (auto v : inArrs)
            NDArray::prepareSpecialUse({}, {v});

        // collect the device buffers and shape infos of all inputs on the host
        std::vector<void*> inputList(inArrs.size());
        std::vector<Nd4jLong*> inputShapeList(inArrs.size());

        for (size_t i = 0; i < inputList.size(); ++i) {
            inputList[i] = inArrs[i]->getSpecialBuffer();
            inputShapeList[i] = inArrs[i]->getSpecialShapeInfo();
        }

        // replicate both pointer arrays to device memory
        PointersManager manager(context, "helpers::stack");
        auto dInBuffers = (void**) manager.replicatePointer(inputList.data(), inputList.size() * sizeof(Nd4jLong*));
        auto dInShapeInfo = (void**) manager.replicatePointer(inputShapeList.data(), inputShapeList.size() * sizeof(Nd4jLong*));

        if(scalarCase) {
            stackKernel<T><<<blocksPerGrid, threadsPerBlock, 1024, *context->getCudaStream()>>>((void**)dInBuffers, (void**)dInShapeInfo, inputList.size(), inArrs[0]->lengthOf(), outArr->specialBuffer(), outArr->getSpecialShapeInfo(), nullptr, nullptr);
        }
        else {
            std::vector<int> axis = ShapeUtils::evalDimsToExclude(outArr->rankOf(), {dim});
            auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(outArr->getShapeInfo(), axis);

            stackKernel<T><<<blocksPerGrid, threadsPerBlock, 1024, *context->getCudaStream()>>>((void**)dInBuffers, (void**)dInShapeInfo, inputList.size(), inArrs[0]->lengthOf(), outArr->specialBuffer(), nullptr, packZ.specialShapeInfo(), packZ.specialOffsets());
        }

        manager.synchronize();

        NDArray::registerSpecialUse({outArr}, {});
        // FIXME: !!!
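        // each input is registered individually as a read dependency below,
        // mirroring the per-array prepareSpecialUse calls issued before the launch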
        for (auto v : inArrs)
            NDArray::registerSpecialUse({}, {v});
    }

    void stack(nd4j::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
        BUILD_SINGLE_SELECTOR(outArr->dataType(), stack_, (context, inArrs, outArr, dim), LIBND4J_TYPES);
    }

    BUILD_SINGLE_TEMPLATE(template void stack_, (nd4j::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim), LIBND4J_TYPES);

}
}
}
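// -----------------------------------------------------------------------------
// Minimal usage sketch (illustration only, not part of the original file). It
// assumes the public libnd4j API (NDArrayFactory::create) and a valid
// LaunchContext pointer `ctx`; the function name `stackExample` is hypothetical.
//
//   #include <NDArrayFactory.h>
//
//   void stackExample(nd4j::LaunchContext* ctx) {
//       // three 1-D inputs of shape {4}
//       auto a = NDArrayFactory::create<float>('c', {4});
//       auto b = NDArrayFactory::create<float>('c', {4});
//       auto c = NDArrayFactory::create<float>('c', {4});
//
//       std::vector<const NDArray*> inputs = {&a, &b, &c};
//
//       // stacking along dim 0 produces a {3, 4} output
//       auto out = NDArrayFactory::create<float>('c', {3, 4});
//       nd4j::ops::helpers::stack(ctx, inputs, &out, 0);
//   }
// -----------------------------------------------------------------------------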