cavis/libnd4j/include/ops/declarable/helpers/cuda/stack.cu

366 lines
14 KiB
Plaintext

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by Yurii Shyrma on 02.01.2018
//
#include <ops/declarable/helpers/stack.h>
#include <helpers/ShapeUtils.h>
#include <array/ResultSet.h>
#include <exceptions/cuda_exception.h>
#include <helpers/TAD.h>
#include <helpers/PointersManager.h>
#include <helpers/ConstantTadHelper.h>
namespace sd {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void stackScalarsCuda(void* pVx, void* vz, const Nd4jLong* zShapeInfo) {
T* z = reinterpret_cast<T*>(vz);
__shared__ Nd4jLong zLen, totalThreads;
if (threadIdx.x == 0) {
zLen = shape::length(zShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
const T *x = reinterpret_cast<const T*>(reinterpret_cast<void**>(pVx)[i]);
z[shape::getIndexOffset(i, zShapeInfo)] = *x;
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
__host__ static void stackScalarsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
void* pVx, void* vz, const Nd4jLong* zShapeInfo) {
stackScalarsCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(pVx, vz, zShapeInfo);
}
///////////////////////////////////////////////////////////////////
template <typename T>
static void stack_(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim) {
const int numOfSubArrs = inArrs.size();
NDArray::prepareSpecialUse({&output}, inArrs);
if(inArrs[0]->rankOf() == 0) {
std::vector<void*> hInBuffers(numOfSubArrs);
for(int i = 0; i < numOfSubArrs; ++i)
hInBuffers[i] = inArrs[i]->getSpecialBuffer();
PointersManager manager(context, "helpers::stack cuda");
void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
stackScalarsCudaLauncher<T>(blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, output.specialBuffer(), output.specialShapeInfo());
manager.synchronize();
}
else {
auto zTadPack = ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), ShapeUtils::evalDimsToExclude(output.rankOf(), {dim}));
Nd4jLong* zTadShapeInfo = zTadPack.primaryShapeInfo();
for (uint i = 0; i < numOfSubArrs; ++i) {
void* zBuff = output.specialBufferWithOffset(zTadPack.primaryOffsets()[i]);
NativeOpExecutioner::execTransformAny(context, transform::Assign,
nullptr, inArrs[i]->getShapeInfo(), inArrs[i]->getSpecialBuffer(), inArrs[i]->getSpecialShapeInfo(),
nullptr, zTadShapeInfo, zBuff, zTadPack.specialShapeInfo(),
nullptr, nullptr, nullptr, false/*allowParallelism*/);
}
}
NDArray::registerSpecialUse({&output}, inArrs);
}
////////////////////////////////////////////////////////////////////////
void stack(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim) {
BUILD_SINGLE_SELECTOR(output.dataType(), stack_, (context, inArrs, output, dim), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void stack_ , (sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim), LIBND4J_TYPES);
///////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void unstackScalarsCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz) {
const T* x = reinterpret_cast<const T*>(vx);
__shared__ Nd4jLong xLen, totalThreads;
if (threadIdx.x == 0) {
xLen = shape::length(xShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < xLen; i += totalThreads) {
T* z = reinterpret_cast<T*>(reinterpret_cast<void**>(pVz)[i]);
*z = x[shape::getIndexOffset(i, xShapeInfo)];
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
__host__ static void unstackScalarsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
const void* vx, const Nd4jLong* xShapeInfo, void* pVz) {
unstackScalarsCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, pVz);
}
///////////////////////////////////////////////////////////////////
template <typename T>
static void unstack_(sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim) {
const int numOfSubArrs = outArrs.size();
// NDArray::prepareSpecialUse(outArrs, {&input});
input.syncToDevice();
for (const auto a : outArrs)
a->getDataBuffer()->allocateSpecial();
if(outArrs[0]->rankOf() == 0) {
std::vector<void*> hOutBuffers(numOfSubArrs);
for(int i = 0; i < numOfSubArrs; ++i)
hOutBuffers[i] = outArrs[i]->getSpecialBuffer();
PointersManager manager(context, "helpers::unstack cuda");
void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
unstackScalarsCudaLauncher<T>(blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), dOutBuffers);
manager.synchronize();
}
else {
auto xTadPack = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), ShapeUtils::evalDimsToExclude(input.rankOf(), {dim}));
Nd4jLong* xTadShapeInfo = xTadPack.primaryShapeInfo();
for (uint i = 0; i < numOfSubArrs; ++i) {
void* xBuff = input.specialBufferWithOffset(xTadPack.primaryOffsets()[i]);
NativeOpExecutioner::execTransformAny(input.getContext(), transform::Assign,
nullptr, xTadShapeInfo, xBuff, xTadPack.specialShapeInfo(),
nullptr, outArrs[i]->getShapeInfo(), outArrs[i]->specialBuffer(), outArrs[i]->specialShapeInfo(),
nullptr, nullptr, nullptr, false/*allowParallelism*/);
}
}
// NDArray::registerSpecialUse(outArrs, {&input});
input.tickReadDevice();
for (const auto p : outArrs)
p->tickWriteDevice();
}
////////////////////////////////////////////////////////////////////////
void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim) {
BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, (context, input, outArrs, dim), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim), LIBND4J_TYPES);
///////////////////////////////////////////////////////////////////
// template <typename T>
// static __global__ void unstackCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) {
// const T* x = reinterpret_cast<const T*>(vx);
// __shared__ Nd4jLong xLen, totalThreads;
// __shared__ int xRank;
// if (threadIdx.x == 0) {
// xLen = shape::length(xShapeInfo);
// xRank = shape::rank(xShapeInfo);
// totalThreads = gridDim.x * blockDim.x;
// }
// __syncthreads();
// const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
// Nd4jLong coords[MAX_RANK];
// for (uint64_t i = tid; i < xLen; i += totalThreads) {
// shape::index2coords(i, xShapeInfo, coords);
// const auto xOffset = shape::getOffset(xShapeInfo, coords);
// T *z = reinterpret_cast<T*>(reinterpret_cast<void **>(pVz)[coords[axis]]);
// for (uint j = axis; j < xRank - 1; ++j) // shift coords staring from axis position
// coords[j] = coords[j + 1];
// const auto zOffset = shape::getOffset(zTadShapeInfo, coords);
// z[zOffset] = x[xOffset];
// }
// }
// ///////////////////////////////////////////////////////////////////
// template<typename T>
// __host__ static void unstackCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
// const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) {
// unstackCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, pVz, zTadShapeInfo, axis);
// }
// BUILD_SINGLE_TEMPLATE(template void unstackCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis), LIBND4J_TYPES);
// ///////////////////////////////////////////////////////////////////
// void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector<const NDArray*>& outArrs, const int axis) {
// const int threadsPerBlock = MAX_NUM_THREADS / 2;
// const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
// const int numOfSubArrs = outArrs.size();
// std::vector<void*> hOutBuffers(numOfSubArrs);
// for(int i = 0; i < numOfSubArrs; ++i)
// hOutBuffers[i] = outArrs[i]->getSpecialBuffer();
// PointersManager manager(context, "helpers::unstack");
// void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
// for(uint i = 0; i < numOfSubArrs; ++i)
// outArrs[i]->syncToDevice();
// input.syncToDevice();
// BUILD_SINGLE_SELECTOR(input.dataType(), unstackCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), dOutBuffers, outArrs[0]->getSpecialShapeInfo(), axis), LIBND4J_TYPES);
// manager.synchronize();
// for(uint i = 0; i < numOfSubArrs; ++i)
// outArrs[i]->tickReadDevice();
// input.tickWriteDevice();
// }
// ///////////////////////////////////////////////////////////////////
// template <typename T>
// static __global__ void stackCuda(void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) {
// T* z = reinterpret_cast<T*>(vz);
// __shared__ Nd4jLong zLen, totalThreads;
// __shared__ int zRank;
// if (threadIdx.x == 0) {
// zLen = shape::length(zShapeInfo);
// zRank = shape::rank(zShapeInfo);
// totalThreads = gridDim.x * blockDim.x;
// }
// __syncthreads();
// const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
// Nd4jLong coords[MAX_RANK];
// for (uint64_t i = tid; i < zLen; i += totalThreads) {
// shape::index2coords(i, zShapeInfo, coords);
// const auto zOffset = shape::getOffset(zShapeInfo, coords);
// const T *x = reinterpret_cast<const T*>(reinterpret_cast<void**>(pVx)[coords[axis]]);
// for (uint j = axis; j < zRank - 1; ++j) // shift coords staring from axis position
// coords[j] = coords[j + 1];
// const auto xOffset = shape::getOffset(xTadShapeInfo, coords);
// z[zOffset] = x[xOffset];
// }
// }
// ///////////////////////////////////////////////////////////////////
// template<typename T>
// __host__ static void stackCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
// void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) {
// stackCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(pVx, xTadShapeInfo, vz, zShapeInfo, axis);
// }
// BUILD_SINGLE_TEMPLATE(template void stackCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis), LIBND4J_TYPES);
// ///////////////////////////////////////////////////////////////////
// void stack(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
// const int threadsPerBlock = MAX_NUM_THREADS / 2;
// const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
// const int numOfSubArrs = inArrs.size();
// std::vector<void*> hInBuffers(numOfSubArrs);
// for(int i = 0; i < numOfSubArrs; ++i)
// hInBuffers[i] = inArrs[i]->getSpecialBuffer();
// PointersManager manager(context, "helpers::stack");
// void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
// for(uint i = 0; i < numOfSubArrs; ++i)
// inArrs[i]->syncToDevice();
// output.syncToDevice();
// BUILD_SINGLE_SELECTOR(output.dataType(), stackCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, inArrs[0]->getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), axis), LIBND4J_TYPES);
// manager.synchronize();
// for(uint i = 0; i < numOfSubArrs; ++i)
// inArrs[i]->tickReadDevice();
// output.tickWriteDevice();
// }
}
}
}