2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
// @author sgazeos@gmail.com
//
#include <ops/declarable/helpers/axis.h>
#include <helpers/PointersManager.h>
#include <helpers/TAD.h>
#include <array>
#include <helpers/ConstantTadHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
2019-09-11 20:04:43 +02:00
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// extract patches kernel
// - theSame - SAME or VALID - output format
// - batchCount - batches - the first dimension of input
// - sizeRow, sizeCol - rows and cols sizes for batch
// - rowDim, colDim - rows and cols dimensions for input patches
// - outRowDim, outColDim - rows and cols dimensions for output patches
// - strideRow, strideCol - step between input elements with patches
// - rateRow, rateCol - counts for input patches
// - rowCast, colCast - shifts for output placement (1 or 0)
// - lastDim - last dimension of input/output
// - input - input tensor buffer
// - patchShape - input patch TAD shape
// - inputOffsets - input TAD offsets
// - output - output tensor buffer
// - outTadShape - output TAD shape
// - outputOffsets - output TAD offsets
2019-08-07 14:29:17 +02:00
//
2019-06-06 14:21:15 +02:00
template <typename T>
2019-08-07 14:29:17 +02:00
static __global__ void globalExtractPatchesKernel(bool theSame, int batchCount, int sizeRow, int sizeCol, int rowDim, int colDim, int outRowDim, int outColDim, int strideRow, int strideCol, int rateRow, int rateCol, int rowCast, int colCast, int lastDim, T* input, Nd4jLong* patchShape, Nd4jLong* inputOffsets, T* output, Nd4jLong* outTadShape, Nd4jLong* outputOffsets) {
auto start = threadIdx.x + blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x;
2019-09-11 20:04:43 +02:00
// batch input by 3 last dims and extrapole input onto output with outColDim/outRowDim
2019-08-07 14:29:17 +02:00
for (Nd4jLong batch = start; batch < batchCount; batch += step) {
auto patch = input + inputOffsets[batch];// listOfMatricies->at(batch);
auto outMatrix = output + outputOffsets[batch]; //listOfOutputs->at(batch);
for (Nd4jLong i = 0; i < outRowDim; i++) {
for (Nd4jLong j = 0; j < outColDim; j++) {
Nd4jLong pos = 0;
auto rowStart = i * strideRow - (theSame?rowCast:0);
auto colStart = j * strideCol - (theSame?colCast:0);
auto rowEnd = rowStart + sizeRow * rateRow;
auto colEnd = colStart + sizeCol * rateCol;
if (!theSame) {
rowEnd = math::nd4j_min(rowStart + sizeRow * rateRow, Nd4jLong (rowDim));
colEnd = math::nd4j_min(colStart + sizeCol * rateCol, Nd4jLong (colDim));
}
2019-09-11 20:04:43 +02:00
for (auto row = rowStart; row < rowEnd; row += rateRow) {
for (auto col = colStart; col < colEnd; col += rateCol) {
2019-08-07 14:29:17 +02:00
for (auto pixel = 0; pixel < lastDim; pixel++) {
Nd4jLong zPos[] = {i, j, pos};
Nd4jLong xPos[] = {row, col, pixel};
2019-09-11 20:04:43 +02:00
bool setUp =
(theSame && row >= 0 && col >= 0 && row < rowDim && col < colDim) || (!theSame);
2019-08-07 14:29:17 +02:00
if (setUp) { // VALID or SAME cases
2019-09-11 19:12:09 +02:00
outMatrix[shape::getOffset(outTadShape, zPos)] = patch[shape::getOffset(patchShape, xPos)];
2019-08-07 14:29:17 +02:00
}
pos++;
}
2019-09-11 20:04:43 +02:00
}
}
2019-08-07 14:29:17 +02:00
}
}
2019-06-06 14:21:15 +02:00
}
2019-08-07 14:29:17 +02:00
2019-06-06 14:21:15 +02:00
}
2019-09-11 20:04:43 +02:00
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
2019-06-06 14:21:15 +02:00
template <typename T>
2019-08-07 14:29:17 +02:00
static void _extractPatches(nd4j::LaunchContext * context, NDArray* images, NDArray* output, int sizeRow, int sizeCol, int strideRow, int strideCol, int rateRow, int rateCol, bool theSame){
NDArray::prepareSpecialUse({output}, {images});
std::vector<int> restDims({1, 2, 3}); // the first and the last dims
// 3D matricies - 2D matricies of vectors (if last dim is greater than 1)
//int e = 0;
const int ksizeRowsEffective = sizeRow + (sizeRow - 1) * (rateRow - 1);
const int ksizeColsEffective = sizeCol + (sizeCol - 1) * (rateCol - 1);
const int ksize = ksizeRowsEffective * ksizeColsEffective;
Nd4jLong lastDim = images->sizeAt(3);
Nd4jLong outLastDim = output->sizeAt(3);
Nd4jLong rowDim = images->sizeAt(1);
Nd4jLong colDim = images->sizeAt(2);
Nd4jLong outRowDim = output->sizeAt(1);
Nd4jLong outColDim = output->sizeAt(2);
2019-09-11 20:04:43 +02:00
auto rowCast = 1;
auto colCast = 1;
// validate shifts
2019-08-07 14:29:17 +02:00
if (sizeRow * rateRow < 3)
rowCast = 0;
if (sizeCol * rateCol < 3)
colCast = 0;
2019-06-06 14:21:15 +02:00
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(images->getShapeInfo(), restDims.data(), restDims.size());
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), restDims.data(), restDims.size());
2019-09-11 20:04:43 +02:00
int batchCount = packX.numberOfTads();
2019-06-06 14:21:15 +02:00
PointersManager manager(context, "helpers::extractPatches");
2019-08-07 14:29:17 +02:00
auto stream = context->getCudaStream();
auto imagesBuffer = reinterpret_cast<T*>(images->specialBuffer());
auto outputBuffer = reinterpret_cast<T*>(output->specialBuffer());
2019-09-11 20:04:43 +02:00
2019-08-07 14:29:17 +02:00
globalExtractPatchesKernel<T><<<128, 128, 1024, *stream>>>(theSame, batchCount, sizeRow, sizeCol,
rowDim, colDim, outRowDim, outColDim, strideRow, strideCol, rateRow, rateCol, rowCast, colCast, lastDim,
imagesBuffer, packX.specialShapeInfo(), packX.specialOffsets(), outputBuffer, packZ.specialShapeInfo(),
packZ.specialOffsets());
2019-09-11 20:04:43 +02:00
2019-06-06 14:21:15 +02:00
manager.synchronize();
2019-08-07 14:29:17 +02:00
NDArray::registerSpecialUse({output}, {images});
2019-06-06 14:21:15 +02:00
}
// Explicit instantiations of _extractPatches for every type in LIBND4J_TYPES.
BUILD_SINGLE_TEMPLATE(template void _extractPatches, (nd4j::LaunchContext * context, NDArray* input, NDArray* output, int sizeRow, int sizeCol, int strideRow, int strideCol, int rateRow, int rateCol, bool theSame), LIBND4J_TYPES);
// Public entry point: dispatches to _extractPatches<T> based on the data type of `images`
// (any type in LIBND4J_TYPES). The output array is expected to have the same data type as
// the input, since the kernel reinterprets both buffers as T.
void extractPatches(nd4j::LaunchContext * context, NDArray* images, NDArray* output, int sizeRow, int sizeCol, int strideRow, int strideCol, int rateRow, int rateCol, bool theSame){
    const auto xType = images->dataType();
    BUILD_SINGLE_SELECTOR(xType, _extractPatches, (context, images, output, sizeRow, sizeCol, strideRow, strideCol, rateRow, rateCol, theSame), LIBND4J_TYPES);
}
}
}
}