[WIP] More of CUDA (#63)

* less spam

Signed-off-by: raver119 <raver119@gmail.com>

* flatten kernel

Signed-off-by: raver119 <raver119@gmail.com>

* adjust_hue/adjust_saturation tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* adjust_hue cuda single

Signed-off-by: raver119 <raver119@gmail.com>

* adjust_hue cuda batch

Signed-off-by: raver119 <raver119@gmail.com>

* adjust_saturation cuda

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-07-17 14:02:01 +03:00 committed by AlexDBlack
parent cf2311859a
commit 91a8fb0d90
11 changed files with 267 additions and 48 deletions

View File

@ -1327,7 +1327,7 @@ void NativeOps::concat(
// take into account indices for first array // take into account indices for first array
auto axisSize = shape::sizeAt(reinterpret_cast<Nd4jLong*>(inputShapeInfo[0]), axis); auto axisSize = shape::sizeAt(reinterpret_cast<Nd4jLong*>(inputShapeInfo[0]), axis);
indices[0][2 * axis + 1] = axisSize; indices[0][2 * axis + 1] = axisSize;
printf("The axe size is %lld\n", axisSize); //printf("The axe size is %lld\n", axisSize);
// loop through the rest of input arrays // loop through the rest of input arrays
for(int i = 1; i < numArrays; ++i) { for(int i = 1; i < numArrays; ++i) {
indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from

View File

@ -42,7 +42,7 @@ namespace nd4j {
} }
char order = (char) INT_ARG(0); char order = (char) INT_ARG(0);
helpers::flatten(arrays, output, order); helpers::flatten(block.launchContext(), arrays, output, order);
return Status::OK(); return Status::OK();
} }

View File

@ -25,7 +25,7 @@ namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> template <typename T>
static FORCEINLINE void rgb_to_hv(nd4j::LaunchContext * context, T r, T g, T b, T* h, T* v_min, T* v_max) { static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) {
T v_mid; T v_mid;
int h_category; int h_category;
// According to the figures in: // According to the figures in:
@ -84,7 +84,7 @@ namespace helpers {
} }
template <typename T> template <typename T>
static FORCEINLINE void hv_to_rgb(nd4j::LaunchContext * context, T h, T v_min, T v_max, T* r, T* g, T* b) { static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* b) {
int h_category = static_cast<int>(h); int h_category = static_cast<int>(h);
T ratio = h - (T)h_category; T ratio = h - (T)h_category;
bool increase = ((h_category & 0x1) == 0); bool increase = ((h_category & 0x1) == 0);

View File

@ -26,7 +26,7 @@ namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> template <typename T>
static FORCEINLINE void rgb_to_hsv(nd4j::LaunchContext * context, T r, T g, T b, T* h, T* s, T* v) { static FORCEINLINE _CUDA_HD void rgb_to_hsv(T r, T g, T b, T* h, T* s, T* v) {
T vv = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b)); T vv = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b));
T range = vv - nd4j::math::nd4j_min<T>(r, nd4j::math::nd4j_min<T>(g, b)); T range = vv - nd4j::math::nd4j_min<T>(r, nd4j::math::nd4j_min<T>(g, b));
if (vv > 0) { if (vv > 0) {
@ -54,7 +54,7 @@ namespace helpers {
} }
template <typename T> template <typename T>
static FORCEINLINE void hsv_to_rgb(nd4j::LaunchContext * context, T h, T s, T v, T* r, T* g, T* b) { static FORCEINLINE _CUDA_HD void hsv_to_rgb(T h, T s, T v, T* r, T* g, T* b) {
T c = s * v; T c = s * v;
T m = v - c; T m = v - c;
T dh = h * 6; T dh = h * 6;

View File

@ -42,7 +42,7 @@ namespace helpers {
auto o = bOut + e * numChannels; auto o = bOut + e * numChannels;
T h, v_min, v_max; T h, v_min, v_max;
helpers::rgb_to_hv(context, i[0], i[1], i[2], &h, &v_min, &v_max); helpers::rgb_to_hv(i[0], i[1], i[2], &h, &v_min, &v_max);
h += delta * kChannelRange; h += delta * kChannelRange;
while (h < (T) 0.) while (h < (T) 0.)
@ -51,7 +51,7 @@ namespace helpers {
while (h >= (T) kChannelRange) while (h >= (T) kChannelRange)
h -= (T) kChannelRange; h -= (T) kChannelRange;
helpers::hv_to_rgb(context, h, v_min, v_max, o, o + 1, o + 2); helpers::hv_to_rgb(h, v_min, v_max, o, o + 1, o + 2);
} }
} else { } else {
auto tadsChannelsIn = array->allTensorsAlongDimension({0}); auto tadsChannelsIn = array->allTensorsAlongDimension({0});
@ -76,7 +76,7 @@ namespace helpers {
auto _bo = outputB + e; auto _bo = outputB + e;
T h, v_min, v_max; T h, v_min, v_max;
helpers::rgb_to_hv(context, _ri[0], _gi[0], _bi[0], &h, &v_min, &v_max); helpers::rgb_to_hv(_ri[0], _gi[0], _bi[0], &h, &v_min, &v_max);
h += delta * kChannelRange; h += delta * kChannelRange;
while (h < (T) 0) while (h < (T) 0)
@ -85,7 +85,7 @@ namespace helpers {
while (h >= (T) kChannelRange) while (h >= (T) kChannelRange)
h -= (T) kChannelRange; h -= (T) kChannelRange;
helpers::hv_to_rgb(context, h, v_min, v_max, _ro, _go, _bo); helpers::hv_to_rgb(h, v_min, v_max, _ro, _go, _bo);
} }
delete tadsChannelsIn; delete tadsChannelsIn;

View File

@ -43,10 +43,10 @@ namespace helpers {
T h, s, v; T h, s, v;
// Convert the RGB color to Hue/V-range. // Convert the RGB color to Hue/V-range.
helpers::rgb_to_hsv(context, i[0], i[1], i[2], &h, &s, &v); helpers::rgb_to_hsv(i[0], i[1], i[2], &h, &s, &v);
s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta)); s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta));
// Convert the hue and v-range back into RGB. // Convert the hue and v-range back into RGB.
helpers::hsv_to_rgb(context, h, s, v, o, o + 1, o + 2); helpers::hsv_to_rgb(h, s, v, o, o + 1, o + 2);
} }
} else { } else {
auto tadsChannelsIn = array->allTensorsAlongDimension({0}); auto tadsChannelsIn = array->allTensorsAlongDimension({0});
@ -72,10 +72,10 @@ namespace helpers {
T h, s, v; T h, s, v;
// Convert the RGB color to Hue/V-range. // Convert the RGB color to Hue/V-range.
helpers::rgb_to_hsv(context, _ri[0], _gi[0], _bi[0], &h, &s, &v); helpers::rgb_to_hsv(_ri[0], _gi[0], _bi[0], &h, &s, &v);
s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta)); s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta));
// Convert the hue and v-range back into RGB. // Convert the hue and v-range back into RGB.
helpers::hsv_to_rgb(context, h, s, v, _ro, _go, _bo); helpers::hsv_to_rgb(h, s, v, _ro, _go, _bo);
} }
delete tadsChannelsIn; delete tadsChannelsIn;

View File

@ -58,7 +58,7 @@ namespace nd4j {
} }
} }
void flatten(std::vector<NDArray*> &inputs, NDArray *output, char order) { void flatten(nd4j::LaunchContext *context, std::vector<NDArray*> &inputs, NDArray *output, char order) {
BUILD_SINGLE_SELECTOR(output->dataType(), flatten_, (inputs, output, order), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(output->dataType(), flatten_, (inputs, output, order), LIBND4J_TYPES);
} }
} }

View File

@ -19,14 +19,114 @@
// //
#include <ops/declarable/helpers/adjust_hue.h> #include <ops/declarable/helpers/adjust_hue.h>
#include <helpers/ConstantTadHelper.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T>
// Adjusts hue of an NHWC image: input/output are flat arrays of interleaved (r, g, b) tuples.
// Launched with a 1D grid; uses a grid-stride loop, so any launch config is valid.
static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) {
    // each tuple is one pixel: 3 interleaved channel values
    int numChannels = 3;
    auto tid = threadIdx.x + blockIdx.x * blockDim.x;

    auto bIn = reinterpret_cast<T*>(xBuffer);
    auto bOut = reinterpret_cast<T*>(zBuffer);
    // hue range used by rgb_to_hv/hv_to_rgb is [0, 6)
    static const int kChannelRange = 6;

    // grid-stride loop over pixels
    for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
        auto i = bIn + e * numChannels;
        auto o = bOut + e * numChannels;

        T h, v_min, v_max;
        helpers::rgb_to_hv(i[0], i[1], i[2], &h, &v_min, &v_max);

        h += delta * kChannelRange;
        // wrap hue back into [0, kChannelRange)
        while (h < (T) 0.)
            h += (T) kChannelRange;

        while (h >= (T) kChannelRange)
            h -= (T) kChannelRange;

        helpers::hv_to_rgb(h, v_min, v_max, o, o + 1, o + 2);
    }
}
template <typename T>
// Adjusts hue of an NCHW image: each channel is a separate plane, addressed via TADs.
// xOffsets/zOffsets hold the base offsets of the R, G, B planes; tuples is the plane length.
static void _CUDA_G adjustHueSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) {
    int numChannels = 3;
    auto tid = threadIdx.x + blockIdx.x * blockDim.x;
    // hue range used by rgb_to_hv/hv_to_rgb is [0, 6)
    static const int kChannelRange = 6;

    // per-channel base pointers for input and output planes
    auto bufferR = reinterpret_cast<T *>(xBuffer) + xOffsets[0];
    auto bufferG = reinterpret_cast<T *>(xBuffer) + xOffsets[1];
    auto bufferB = reinterpret_cast<T *>(xBuffer) + xOffsets[2];

    auto outputR = reinterpret_cast<T *>(zBuffer) + zOffsets[0];
    auto outputG = reinterpret_cast<T *>(zBuffer) + zOffsets[1];
    auto outputB = reinterpret_cast<T *>(zBuffer) + zOffsets[2];

    // grid-stride loop over pixels within the planes
    for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
        auto _ri = bufferR + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
        auto _gi = bufferG + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
        auto _bi = bufferB + shape::getIndexOffset(e, xTadShapeInfo, tadLength);

        // FIX: output offsets must come from the output TAD shape info; zTadShapeInfo
        // was previously unused, which silently mis-addressed z whenever the output
        // TAD layout differed from the input's
        auto _ro = outputR + shape::getIndexOffset(e, zTadShapeInfo, tadLength);
        auto _go = outputG + shape::getIndexOffset(e, zTadShapeInfo, tadLength);
        auto _bo = outputB + shape::getIndexOffset(e, zTadShapeInfo, tadLength);

        T h, v_min, v_max;
        helpers::rgb_to_hv(_ri[0], _gi[0], _bi[0], &h, &v_min, &v_max);

        h += delta * kChannelRange;
        // wrap hue back into [0, kChannelRange)
        while (h < (T) 0)
            h += (T) kChannelRange;

        while (h >= (T) kChannelRange)
            h -= (T) kChannelRange;

        helpers::hv_to_rgb(h, v_min, v_max, _ro, _go, _bo);
    }
}
template <typename T>
// Host-side launcher: adjusts hue of a single image, dispatching on memory layout.
static void _adjust_hue_single(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
    // every pixel is a 3-channel tuple, regardless of layout
    auto pixels = array->lengthOf() / 3;
    auto stream = context->getCudaStream();

    if (isNHWC) {
        // interleaved layout: a flat pass over (r, g, b) triples is enough
        adjustHueSingleNHWCKernel<T><<<256, 256, 1024, *stream>>>(array->specialBuffer(), array->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), pixels, delta);
    } else {
        // TODO: check this one
        // planar layout: build one TAD per channel plane (dims {1, 2} are H and W)
        auto inputTads = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(array->getShapeInfo(), {1, 2});
        auto outputTads = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {1, 2});
        auto tadLength = shape::length(inputTads.primaryShapeInfo());

        adjustHueSingleNCHWKernel<T><<<256, 256, 1024, *stream>>>(array->specialBuffer(), inputTads.platformShapeInfo(), inputTads.platformOffsets(), output->specialBuffer(), outputTads.platformShapeInfo(), outputTads.platformOffsets(), tadLength, pixels, delta);
    }
}
template <typename T>
// Host-side launcher: adjusts hue of a 4D batch of images, dispatching on memory layout.
static void _adjust_hue_batch(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
    auto xType = array->dataType();

    // numChannels is always 3
    auto tuples = array->lengthOf() / 3;

    if (isNHWC) {
        // in case of nhwc batch, we don't really care about examples: it's still bunch of RGB values
        BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, delta, isNHWC);, FLOAT_TYPES);
    } else {
        // TODO: check this one
        // NCHW batch: channel planes are TADs over {0, 2, 3} (batch, height, width)
        auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(array->getShapeInfo(), {0, 2, 3});
        auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {0, 2, 3});

        auto tadLength = shape::length(packX.primaryShapeInfo());

        adjustHueSingleNCHWKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta);
    }
}
void _adjust_hue(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { void _adjust_hue(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
@ -34,25 +134,11 @@ namespace helpers {
float d = delta->e<float>(0); float d = delta->e<float>(0);
if (array->rankOf() == 4) { if (array->rankOf() == 4) {
auto tadsIn = array->allTensorsAlongDimension({0});
auto tadsOut = output->allTensorsAlongDimension({0});
// FIXME: template selector should be moved out of loop
PRAGMA_OMP_PARALLEL_FOR
for (int e = 0; e < tadsIn->size(); e++) {
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES);
}
delete tadsIn;
delete tadsOut;
} else { } else {
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (array, output, d, isNHWC);, FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES);
} }
} }
BUILD_SINGLE_TEMPLATE(template void _adjust_hue_single, (NDArray *array, NDArray *output, float delta, bool isNHWC);, FLOAT_TYPES);
} }
} }
} }

View File

@ -19,6 +19,7 @@
// //
#include <ops/declarable/helpers/adjust_saturation.h> #include <ops/declarable/helpers/adjust_saturation.h>
#include <helpers/ConstantTadHelper.h>
namespace nd4j { namespace nd4j {
@ -26,8 +27,96 @@ namespace ops {
namespace helpers { namespace helpers {
template <typename T>
// Adjusts saturation of an NHWC image: input/output are flat arrays of interleaved
// (r, g, b) tuples. Launched with a 1D grid; grid-stride loop tolerates any config.
static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) {
    // each tuple is one pixel: 3 interleaved channel values
    int numChannels = 3;
    auto tid = threadIdx.x + blockIdx.x * blockDim.x;

    auto bIn = reinterpret_cast<T*>(xBuffer);
    auto bOut = reinterpret_cast<T*>(zBuffer);
    // note: the hue-range constant kChannelRange from the hue kernel is not needed here

    // grid-stride loop over pixels
    for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
        auto i = bIn + e * numChannels;
        auto o = bOut + e * numChannels;

        T h, s, v;
        // Convert the RGB color to Hue/V-range.
        helpers::rgb_to_hsv(i[0], i[1], i[2], &h, &s, &v);
        // scale saturation, clamped to [0, 1]
        s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta));
        // Convert the hue and v-range back into RGB.
        helpers::hsv_to_rgb(h, s, v, o, o + 1, o + 2);
    }
}
template <typename T>
// Adjusts saturation of an NCHW image: each channel is a separate plane, addressed via TADs.
// xOffsets/zOffsets hold the base offsets of the R, G, B planes; tuples is the plane length.
static void _CUDA_G adjustSaturationSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) {
    int numChannels = 3;
    auto tid = threadIdx.x + blockIdx.x * blockDim.x;
    // note: the hue-range constant kChannelRange from the hue kernel is not needed here

    // per-channel base pointers for input and output planes
    auto bufferR = reinterpret_cast<T *>(xBuffer) + xOffsets[0];
    auto bufferG = reinterpret_cast<T *>(xBuffer) + xOffsets[1];
    auto bufferB = reinterpret_cast<T *>(xBuffer) + xOffsets[2];

    auto outputR = reinterpret_cast<T *>(zBuffer) + zOffsets[0];
    auto outputG = reinterpret_cast<T *>(zBuffer) + zOffsets[1];
    auto outputB = reinterpret_cast<T *>(zBuffer) + zOffsets[2];

    // grid-stride loop over pixels within the planes
    for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
        auto _ri = bufferR + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
        auto _gi = bufferG + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
        auto _bi = bufferB + shape::getIndexOffset(e, xTadShapeInfo, tadLength);

        // FIX: output offsets must come from the output TAD shape info; zTadShapeInfo
        // was previously unused, which silently mis-addressed z whenever the output
        // TAD layout differed from the input's
        auto _ro = outputR + shape::getIndexOffset(e, zTadShapeInfo, tadLength);
        auto _go = outputG + shape::getIndexOffset(e, zTadShapeInfo, tadLength);
        auto _bo = outputB + shape::getIndexOffset(e, zTadShapeInfo, tadLength);

        T h, s, v;
        // Convert the RGB color to Hue/V-range.
        helpers::rgb_to_hsv(_ri[0], _gi[0], _bi[0], &h, &s, &v);
        // scale saturation, clamped to [0, 1]
        s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta));
        // Convert the hue and v-range back into RGB.
        helpers::hsv_to_rgb(h, s, v, _ro, _go, _bo);
    }
}
template <typename T>
// Host-side launcher: adjusts saturation of a single image, dispatching on memory layout.
static void _adjust_saturation_single(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
    // every pixel is a 3-channel tuple, regardless of layout
    auto pixels = array->lengthOf() / 3;
    auto stream = context->getCudaStream();

    if (isNHWC) {
        // interleaved layout: a flat pass over (r, g, b) triples is enough
        adjustSaturationSingleNHWCKernel<T><<<256, 256, 1024, *stream>>>(array->specialBuffer(), array->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), pixels, delta);
    } else {
        // planar layout: build one TAD per channel plane (dims {1, 2} are H and W)
        auto inputTads = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(array->getShapeInfo(), {1, 2});
        auto outputTads = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {1, 2});
        auto tadLength = shape::length(inputTads.primaryShapeInfo());

        adjustSaturationSingleNCHWKernel<T><<<256, 256, 1024, *stream>>>(array->specialBuffer(), inputTads.platformShapeInfo(), inputTads.platformOffsets(), output->specialBuffer(), outputTads.platformShapeInfo(), outputTads.platformOffsets(), tadLength, pixels, delta);
    }
}
template <typename T>
// Host-side launcher: adjusts saturation of a 4D batch of images, dispatching on memory layout.
static void _adjust_saturation_batch(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
    auto xType = array->dataType();

    // numChannels is always 3
    auto tuples = array->lengthOf() / 3;

    if (isNHWC) {
        // in case of nhwc batch, we don't really care about examples: it's still bunch of RGB values
        BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, delta, isNHWC);, FLOAT_TYPES);
    } else {
        // TODO: check this one
        // NCHW batch: channel planes are TADs over {0, 2, 3} (batch, height, width)
        auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(array->getShapeInfo(), {0, 2, 3});
        auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {0, 2, 3});

        auto tadLength = shape::length(packX.primaryShapeInfo());

        adjustSaturationSingleNCHWKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta);
    }
}
void adjust_saturation(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { void adjust_saturation(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
@ -35,24 +124,12 @@ namespace helpers {
float d = delta->e<float>(0); float d = delta->e<float>(0);
if (array->rankOf() == 4) { if (array->rankOf() == 4) {
auto tadsIn = array->allTensorsAlongDimension({0}); BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_batch, (context, array, output, d, isNHWC);, FLOAT_TYPES);
auto tadsOut = output->allTensorsAlongDimension({0});
// FIXME: template selector should be moved out of loop
PRAGMA_OMP_PARALLEL_FOR
for (int e = 0; e < tadsIn->size(); e++) {
BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES);
}
delete tadsIn;
delete tadsOut;
} else { } else {
BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (array, output, d, isNHWC);, FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, d, isNHWC);, FLOAT_TYPES);
} }
} }
BUILD_SINGLE_TEMPLATE(template void _adjust_saturation_single, (NDArray *array, NDArray *output, float delta, bool isNHWC), FLOAT_TYPES);
} }
} }

View File

@ -19,12 +19,68 @@
// //
#include <ops/declarable/helpers/flatten.h> #include <ops/declarable/helpers/flatten.h>
#include <helpers/PointersManager.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T>
// Copies each input array into its slot of the flat output buffer, traversing every
// input in the requested order ('c' or 'f'). One block per input (grid-stride over
// inputs), threads stride over the elements of that input.
void _CUDA_G flattenKernel(void **xBuffers, Nd4jLong **xShapeInfos, Nd4jLong *offsets, Nd4jLong numInputs, void *zBuffer, Nd4jLong *zShapeInfo, char order) {
    // per-thread coordinate scratch for index2coords
    Nd4jLong xCoord[MAX_RANK];

    for (Nd4jLong e = blockIdx.x; e < numInputs; e += gridDim.x) {
        // destination slot of this input within the flat output
        auto z = reinterpret_cast<T*>(zBuffer) + offsets[e];

        auto xBuffer = reinterpret_cast<T*>(xBuffers[e]);
        auto xShapeInfo = xShapeInfos[e];
        auto xShape = shape::shapeOf(xShapeInfo);
        auto xStride = shape::stride(xShapeInfo);
        auto xRank = shape::rank(xShapeInfo);
        auto xLength = shape::length(xShapeInfo);

        // FIX: use Nd4jLong instead of uint for the element counter, so inputs
        // longer than 2^32 elements don't silently wrap
        for (Nd4jLong i = threadIdx.x; i < xLength; i += blockDim.x) {
            shape::index2coords(xRank, xShape, i, xLength, xCoord, order);
            auto xOffset = shape::getOffset(0, xShape, xStride, xCoord, xRank);
            z[i] = xBuffer[xOffset];
        }
    }
}
template <typename T>
// Typed implementation of flatten: computes each input's offset in the output,
// replicates the buffer/shape/offset tables to the device, and launches flattenKernel.
void flatten_(nd4j::LaunchContext *context, std::vector<NDArray*> &inputs, NDArray *output, char order) {
    PointersManager pm(context, "flatten");

    std::vector<void*> hdBuffers(inputs.size());
    std::vector<Nd4jLong> hOffsets(inputs.size());
    std::vector<Nd4jLong *> hdShapes(inputs.size());
    Nd4jLong cOffset = 0;

    // calculating offsets in output
    // FIX: size_t index avoids the signed/unsigned comparison against inputs.size()
    for (size_t e = 0; e < inputs.size(); e++) {
        hOffsets[e] = cOffset;
        cOffset += inputs[e]->lengthOf();

        hdBuffers[e] = inputs[e]->specialBuffer();
        hdShapes[e] = inputs[e]->specialShapeInfo();
    }

    // copy the host-side tables to device memory (freed by pm's destructor)
    auto dBuffers = (void **) pm.replicatePointer(hdBuffers.data(), inputs.size() * sizeof(void*));
    auto dShapes = (Nd4jLong **)pm.replicatePointer(hdShapes.data(), inputs.size() * sizeof(Nd4jLong*));
    auto dOffsets = (Nd4jLong *) pm.replicatePointer(hOffsets.data(), inputs.size() * sizeof(Nd4jLong));

    flattenKernel<T><<<256, 512, 8192, *context->getCudaStream()>>>(dBuffers, dShapes, dOffsets, inputs.size(), output->getSpecialBuffer(), output->getSpecialShapeInfo(), order);

    // block until the kernel (and the replicated pointers it reads) are done
    pm.synchronize();
}
// Public entry point: flattens a list of arrays into one contiguous output buffer.
void flatten(nd4j::LaunchContext *context, std::vector<NDArray*> &inputs, NDArray *output, char order) {
    // make sure every input's device copy is current before the kernel reads it
    for (auto v:inputs)
        v->syncToDevice();

    BUILD_SINGLE_SELECTOR(output->dataType(), flatten_, (context, inputs, output, order), LIBND4J_TYPES);

    // mark the output's device buffer as the authoritative copy
    NDArray::registerSpecialUse({output}, {});
}
} }
} }

View File

@ -27,7 +27,7 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
void flatten(std::vector<NDArray*> &inputs, NDArray *output, char order); void flatten(nd4j::LaunchContext *context, std::vector<NDArray*> &inputs, NDArray *output, char order);
} }
} }
} }