// @author Yurii Shyrma (iuriish@yahoo.com), created on 19.09.2018
//
#include <ops/declarable/helpers/im2col.h>
namespace nd4j {
namespace ops {
namespace helpers {
// im2col_: typed implementation of the im2col unfolding for element type T.
//
// Layout (original author's note): input [bS, iC, iH, iW] is convoluted to
// output [bS, iC, kH, kW, oH, oW].
//
// Parameters (meanings below are inferred from names -- TODO confirm against
// ops/declarable/helpers/im2col.h):
//   context        launch context; not referenced in the lines visible here.
//   input          source image array; its buffer is cast to T* (imBuff).
//   output         destination column array; its buffer is cast to T* (colBuff).
//   kH, kW         presumably kernel height / width.
//   sH, sW         presumably strides along H / W.
//   pH, pW         presumably paddings along H / W.
//   dH, dW         presumably dilations along H / W.
//   arrZeroPadVal  presumably the source of the zeroPadVal written into
//                  out-of-bounds positions below; confirm.
//
// NOTE(review): this body appears truncated. The names imRow, imCol, iH, iW,
// col, im and zeroPadVal used below are not declared anywhere in the visible
// lines. imBuff and colBuff are declared but never used here. Seven of the
// eight closing braces after the assignment have no matching opener in this
// text. As written, this will not compile. The missing declarations and the
// enclosing loop nest should be restored from version control.
template <typename T>
static void im2col_(nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) {
// Typed views of the raw input / output buffers.
auto imBuff = static_cast<T*>(input.getBuffer());
auto colBuff = static_cast<T*>(output.getBuffer());
// Bounds test. Converting to unsigned maps negative coordinates to large
// values, so a single >= comparison rejects both imRow < 0 and imRow >= iH.
// The same applies to imCol against iW.
if (static_cast<unsigned>(imRow) >= static_cast<unsigned>(iH) || static_cast<unsigned>(imCol) >= static_cast<unsigned>(iW))
// Out of bounds: write the padding value.
*col = zeroPadVal;
else
// In bounds: copy the image element at the computed position.
*col = *im;
}
}
}
}
}
}
}
}
void im2col(nd4j::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) {
BUILD_SINGLE_TEMPLATE(template void im2col_, (nd4j::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal), LIBND4J_TYPES);