// [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
// Typed col2im kernel: adds elements of the 6-D column array `input` into the
// 4-D image array `output` (see the `*im += *col` accumulation below).
//
// T       - element type used to view both the input and output buffers.
// context - launch context; not referenced in the visible lines.
// input   - column array; dimensions 2..5 are read as kH, kW, oH, oW.
// output  - image array; dimensions 0..1 are read as bS, iC.
// iH, iW  - image height/width, used as the bounds in the in-range test below.
// sH, sW, pH, pW, dH, dW - presumably stride / padding / dilation per axis;
//           they are not used in the visible lines - TODO confirm.
//
// NOTE(review): in this view the code between the `imEWS` computation and the
// bounds check appears to be missing (the zeroing pass and the loops that
// would declare `imRow`, `imCol`, `im` and `col`), and the run of closing
// braces at the end has no visible matching openers. Verify against the full
// file before changing anything here.
template <typename T>
void col2im_(nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) {
// raw element buffers, viewed as T: image (destination) and columns (source)
auto imBuff = output.bufferAsT<T>();
auto colBuff = input.bufferAsT<T>();
// shape-info descriptors, from which the shape and stride arrays are taken
auto imShapeBuffer = output.getShapeInfo();
auto colShapeBuffer = input.getShapeInfo();
auto colShape = shape::shapeOf(colShapeBuffer);
auto colStride = shape::stride(colShapeBuffer);
auto imShape = shape::shapeOf(imShapeBuffer);
auto imStride = shape::stride(imShapeBuffer);
// batch size and channel count come from the image shape
const int bS = imShape[0];
const int iC = imShape[1];
// kernel extents (kH, kW) and output extents (oH, oW) come from the column shape
const int kH = colShape[2];
const int kW = colShape[3];
const int oH = colShape[4];
const int oW = colShape[5];
// per-dimension strides copied into const locals (six for columns, four for image)
const Nd4jLong colStride0 = colStride[0];
const Nd4jLong colStride1 = colStride[1];
const Nd4jLong colStride2 = colStride[2];
const Nd4jLong colStride3 = colStride[3];
const Nd4jLong colStride4 = colStride[4];
const Nd4jLong colStride5 = colStride[5];
const Nd4jLong imStride0 = imStride[0];
const Nd4jLong imStride1 = imStride[1];
const Nd4jLong imStride2 = imStride[2];
const Nd4jLong imStride3 = imStride[3];
// initial zeroing of image content
// element-wise stride of the image, as reported by its shape info
const auto imEWS = shape::elementWiseStride(imShapeBuffer);
// Both operands are compared as unsigned, so one comparison per axis tests
// the upper bound; assuming imRow/imCol are signed ints (their declarations
// are not visible here), a negative value wraps to a large unsigned number
// and fails the test as well, i.e. this checks 0 <= imRow < iH and
// 0 <= imCol < iW.
if (static_cast<unsigned>(imRow) < static_cast<unsigned>(iH) && static_cast<unsigned>(imCol) < static_cast<unsigned>(iW))
// accumulate the column element into the image element
*im += *col;
// NOTE(review): the scopes closed below are opened in lines not visible here.
}
}
}
}
}
}
}
}
void col2im(nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) {
BUILD_SINGLE_TEMPLATE(template void col2im_, (nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW), LIBND4J_TYPES);