/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // Created by raver119 on 30.11.17. // #include namespace nd4j { namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// // input [bS, iC, iH, iW] is convoluted to output [bS, iC, kH, kW, oH, oW] template __global__ static void im2colCuda(const void *in, void *out, const Nd4jLong *inShapeInfo, const Nd4jLong *outShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const double zeroPadValD) { T zeroPadVal = static_cast(zeroPadValD); //Value to use when value is padding. Usually 0 but not always const auto im = reinterpret_cast(in); auto col = reinterpret_cast(out); __shared__ Nd4jLong colLen, *colStrides, *imStrides, *colShape, *colIndices; __shared__ int iH, iW, colRank; if (threadIdx.x == 0) { extern __shared__ unsigned char shmem[]; colIndices = reinterpret_cast(shmem); colRank = shape::rank(outShapeInfo); colLen = shape::length(outShapeInfo); colShape = shape::shapeOf(const_cast(outShapeInfo)); colStrides = shape::stride(outShapeInfo); imStrides = shape::stride(inShapeInfo); iH = inShapeInfo[3]; iW = inShapeInfo[4]; } __syncthreads(); const auto colInd = blockIdx.x * blockDim.x + threadIdx.x; if(colInd >= colLen) return; const auto indexes = colIndices + threadIdx.x * colRank; shape::index2coords(colRank, colShape, colInd, colLen, indexes); const auto imh = (-pH + indexes[2] * dH) + indexes[4] * sH; const auto imw = (-pW + indexes[3] * dW) + indexes[5] * sW; const auto colBuff = col + indexes[0]*colStrides[0] + indexes[1]*colStrides[1] + indexes[2]*colStrides[2] + indexes[3]*colStrides[3] + indexes[4]*colStrides[4] + indexes[5]*colStrides[5]; const auto imBuff = im + indexes[0]*imStrides[0] + indexes[1]*imStrides[1] + imh*imStrides[2] + imw*imStrides[3]; if (static_cast(imh) >= static_cast(iH) || static_cast(imw) >= static_cast(iW)) *colBuff = zeroPadVal; else *colBuff = *imBuff; } ////////////////////////////////////////////////////////////////////////// template static void im2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, nd4j::LaunchContext & context, const void *in, void *out, const Nd4jLong *inShapeInfo, const Nd4jLong *outShapeInfo, int kY, int kX, int sH, int sW, int pH, int pW, int dH, int dW, double zeroPadVal) { im2colCuda<<>>(in, out, inShapeInfo, outShapeInfo, kY, kX, sH, sW, pH, pW, dH, dW, zeroPadVal); } ////////////////////////////////////////////////////////////////////////// void im2col(nd4j::LaunchContext & context, const NDArray& in, NDArray& out, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) { if(!in.isActualOnDeviceSide()) in.syncToDevice(); const int threadsPerBlock = 512; const int blocksPerGrid = (out.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; BUILD_SINGLE_SELECTOR(out.dataType(), im2colCudaLauncher, (blocksPerGrid, threadsPerBlock, context, in.getSpecialBuffer(), out.getSpecialBuffer(), in.getSpecialShapeInfo(), out.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, arrZeroPadVal.e(0)), FLOAT_TYPES); in.tickReadDevice(); out.tickWriteDevice(); } BUILD_SINGLE_TEMPLATE(template void im2colCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, nd4j::LaunchContext & context, const void *in, void *out, const Nd4jLong *inShapeInfo, const Nd4jLong *outShapeInfo, const int kY, const int kX, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const double zeroPadVal), FLOAT_TYPES); } } }