cavis/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu

110 lines
5.1 KiB
Plaintext

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver119 on 30.11.17.
//
#include <ops/declarable/helpers/im2col.h>
namespace nd4j {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
// input [bS, iC, iH, iW] is convoluted to output [bS, iC, kH, kW, oH, oW]
template <typename T>
__global__ static void im2colCuda(const void *in, void *out,
const Nd4jLong *inShapeInfo, const Nd4jLong *outShapeInfo,
const int kH, const int kW,
const int sH, const int sW,
const int pH, const int pW,
const int dH, const int dW,
const double zeroPadValD) {
T zeroPadVal = static_cast<T>(zeroPadValD); //Value to use when value is padding. Usually 0 but not always
const auto im = reinterpret_cast<const T*>(in);
auto col = reinterpret_cast<T*>(out);
__shared__ Nd4jLong colLen, *colStrides, *imStrides, *colShape, *colIndices;
__shared__ int iH, iW, colRank;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
colIndices = reinterpret_cast<Nd4jLong*>(shmem);
colRank = shape::rank(outShapeInfo);
colLen = shape::length(outShapeInfo);
colShape = shape::shapeOf(const_cast<Nd4jLong*>(outShapeInfo));
colStrides = shape::stride(outShapeInfo);
imStrides = shape::stride(inShapeInfo);
iH = inShapeInfo[3];
iW = inShapeInfo[4];
}
__syncthreads();
const auto colInd = blockIdx.x * blockDim.x + threadIdx.x;
if(colInd >= colLen) return;
const auto indexes = colIndices + threadIdx.x * colRank;
shape::index2coords(colRank, colShape, colInd, colLen, indexes);
const auto imh = (-pH + indexes[2] * dH) + indexes[4] * sH;
const auto imw = (-pW + indexes[3] * dW) + indexes[5] * sW;
const auto colBuff = col + indexes[0]*colStrides[0] + indexes[1]*colStrides[1] + indexes[2]*colStrides[2] + indexes[3]*colStrides[3] + indexes[4]*colStrides[4] + indexes[5]*colStrides[5];
const auto imBuff = im + indexes[0]*imStrides[0] + indexes[1]*imStrides[1] + imh*imStrides[2] + imw*imStrides[3];
if (static_cast<unsigned>(imh) >= static_cast<unsigned>(iH) || static_cast<unsigned>(imw) >= static_cast<unsigned>(iW))
*colBuff = zeroPadVal;
else
*colBuff = *imBuff;
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void im2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, nd4j::LaunchContext & context, const void *in, void *out, const Nd4jLong *inShapeInfo, const Nd4jLong *outShapeInfo, int kY, int kX, int sH, int sW, int pH, int pW, int dH, int dW, double zeroPadVal) {
im2colCuda<T><<<blocksPerGrid, threadsPerBlock, threadsPerBlock * sizeof(Nd4jLong) * 6 /* rank of out = 6 */, *context.getCudaStream()>>>(in, out, inShapeInfo, outShapeInfo, kY, kX, sH, sW, pH, pW, dH, dW, zeroPadVal);
}
//////////////////////////////////////////////////////////////////////////
void im2col(nd4j::LaunchContext & context, const NDArray& in, NDArray& out, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) {
if(!in.isActualOnDeviceSide()) in.syncToDevice();
const int threadsPerBlock = 512;
const int blocksPerGrid = (out.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
BUILD_SINGLE_SELECTOR(out.dataType(), im2colCudaLauncher, (blocksPerGrid, threadsPerBlock, context, in.getSpecialBuffer(), out.getSpecialBuffer(), in.getSpecialShapeInfo(), out.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, arrZeroPadVal.e<double>(0)), FLOAT_TYPES);
in.tickReadDevice();
out.tickWriteDevice();
}
BUILD_SINGLE_TEMPLATE(template void im2colCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, nd4j::LaunchContext & context, const void *in, void *out, const Nd4jLong *inShapeInfo, const Nd4jLong *outShapeInfo, const int kY, const int kX, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const double zeroPadVal), FLOAT_TYPES);
}
}
}