[WIP] More tweaks (#173)

* CUDA empty reduction

Signed-off-by: raver119 <raver119@gmail.com>

* - listdiff synchronization fix for CUDA
- listdiff test

Signed-off-by: raver119 <raver119@gmail.com>

* - IndexReduce ops now allow INDEXING_TYPES output
- topK op accepts only INDEXING_TYPES as output

Signed-off-by: raver119 <raver119@gmail.com>
Branch: master
Author: raver119, 2019-08-27 10:37:10 +03:00 (committed by GitHub)
Parent: e92f7218f3
Commit: df84bc7255
14 changed files with 217 additions and 137 deletions
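In practice, the INDEXING_TYPES change means the index output of ops such as top_k may be either INT32 or INT64, where previously only INT64 passed validation. A minimal sketch in the style of the testTopK1 case added below, assuming an INT64 (DataType.LONG) index buffer is accepted alongside the INT32 one used in that test:

INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
INDArray k = Nd4j.scalar(1);
INDArray outValue = Nd4j.create(DataType.DOUBLE, 1);
INDArray outIdx = Nd4j.create(DataType.LONG, 1);   // INT64 index output; DataType.INT is also valid
Nd4j.exec(DynamicCustomOp.builder("top_k")
        .addInputs(x, k)
        .addOutputs(outValue, outIdx)
        .addBooleanArguments(false)                // unsorted
        .addIntegerArguments(1)                    // k = 1
        .build());
// expected: outValue = [10.0], outIdx = [3]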

View File

@ -3590,8 +3590,8 @@ void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const
if (isS())
throw std::runtime_error("NDArray::applyIndexReduce: you can't use this method on String array!");
if (target->dataType() != nd4j::DataType::INT64)
throw std::runtime_error("NDArray::applyIndexReduce operations return INT64");
if (target->dataType() != nd4j::DataType::INT64 && target->dataType() != nd4j::DataType::INT32)
throw std::runtime_error("NDArray::applyIndexReduce operations return INT32/INT64");
void* params = extraParams != nullptr ? const_cast<ExtraArguments*>(extraParams)->argumentsAsT(this->dataType()) : nullptr;

View File

@ -79,9 +79,10 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc, int op
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
auto hz = reinterpret_cast<Nd4jLong*>(hZ);
BUILD_SINGLE_SELECTOR(xType, hz[0] = functions::indexreduce::IndexReduce, ::execScalar(opNum,hX,hXShapeInfo,extraParams), LIBND4J_TYPES);
BUILD_DOUBLE_SELECTOR(xType, zType, hz[0] = functions::indexreduce::IndexReduce, ::execScalar(opNum,hX,hXShapeInfo,extraParams), LIBND4J_TYPES, INDEXING_TYPES);
}
////////////////////////////////////////////////////////////////////////
@ -111,9 +112,10 @@ void NativeOpExecutioner::execIndexReduce(nd4j::LaunchContext *lc,
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
Nd4jLong* hz = reinterpret_cast<Nd4jLong*>(hZ);
BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES);
// BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
}

View File

@ -475,12 +475,12 @@ void NativeOpExecutioner::execIndexReduce(nd4j::LaunchContext *lc,
auto numBlocks = shape::length(hZShapeInfo);
dim3 launchDims(numBlocks, 256, 32768);
if (zType != nd4j::DataType::INT64)
throw datatype_exception::build("NativeOpExecutioner::execIndexReduce requires Z operand to have INT64 type", zType);
if (zType != nd4j::DataType::INT64 && zType != nd4j::DataType::INT32)
throw datatype_exception::build("NativeOpExecutioner::execIndexReduce requires Z operand to have INT32/INT64 type", zType);
auto dz = reinterpret_cast<Nd4jLong*>(dZ);
BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::executeIndexReduce(launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), extraParams, dz, dZShapeInfo, shape::rank(hZShapeInfo), dimension, dimensionLength, 1, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::executeIndexReduce(launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), extraParams, dz, dZShapeInfo, shape::rank(hZShapeInfo), dimension, dimensionLength, 1, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES);
// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);
@ -567,12 +567,12 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc,
// FIXME: we want Z to be one of integer types
//if (!DataTypeUtils::isZ(zType))
// throw nd4j::datatype_exception("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have one of integer types")
if (zType != nd4j::DataType::INT64)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have INT64 data type", zType);
if (zType != nd4j::DataType::INT64 && zType != nd4j::DataType::INT32)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have INT32/INT64 data type", zType);
auto dz = reinterpret_cast<Nd4jLong*>(dZ);
BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::executeIndexReduceScalar(launchDims, stream,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::executeIndexReduceScalar(launchDims, stream,
opNum,
dX, dXShapeInfo, shape::rank(hXShapeInfo),
extraParams,
@ -580,7 +580,7 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc,
nullptr, 0,
1,
allocationPointer, reductionPointer,
nullptr, nullptr), LIBND4J_TYPES);
nullptr, nullptr), LIBND4J_TYPES, INDEXING_TYPES);
// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);
if (res != 0)

View File

@ -80,14 +80,14 @@ namespace nd4j {
};
template <typename X>
template <typename X, typename Z>
class ND4J_EXPORT IndexReductionLoops {
private:
public:
static void wrapIndexReduce(const int opNum, void* x, Nd4jLong* xShapeInfo, Nd4jLong* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* extraParams);
static void wrapIndexReduce(const int opNum, void* x, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* extraParams);
template <typename OpType>
static void loopIndexReduce(X* x, Nd4jLong* xShapeInfo, Nd4jLong* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams);
static void loopIndexReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams);
};

View File

@ -24,10 +24,10 @@ using namespace simdOps;
//////////////////////////////////////////////////////////////////////////////
template <typename X>
template <typename X, typename Z>
template <typename OpType>
void nd4j::IndexReductionLoops<X>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
Nd4jLong* z, Nd4jLong* zShapeInfo,
void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
Z* z, Nd4jLong* zShapeInfo,
Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets,
X* extraParams) {
@ -62,7 +62,7 @@ void nd4j::IndexReductionLoops<X>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
indexValue = OpType::update(indexValue, comp, extraParams);
}
z[i] = indexValue.index;
z[i] = (Z) indexValue.index;
}
}
break;
@ -80,7 +80,7 @@ void nd4j::IndexReductionLoops<X>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
indexValue = OpType::update(indexValue, comp, extraParams);
}
z[i * zEws] = indexValue.index;
z[i * zEws] = (Z) indexValue.index;
}
}
break;
@ -98,7 +98,7 @@ void nd4j::IndexReductionLoops<X>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
indexValue = OpType::update(indexValue, comp, extraParams);
}
z[i] = indexValue.index;
z[i] = (Z) indexValue.index;
}
}
break;
@ -122,7 +122,7 @@ void nd4j::IndexReductionLoops<X>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
}
z[i] = indexValue.index;
z[i] = (Z) indexValue.index;
}
}
break;
@ -148,7 +148,7 @@ void nd4j::IndexReductionLoops<X>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
}
z[i] = indexValue.index;
z[i] = (Z) indexValue.index;
}
}
break;
@ -176,7 +176,7 @@ void nd4j::IndexReductionLoops<X>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
}
z[i] = indexValue.index;
z[i] = (Z) indexValue.index;
}
}
break;
@ -206,7 +206,7 @@ void nd4j::IndexReductionLoops<X>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
}
z[i] = indexValue.index;
z[i] = (Z) indexValue.index;
}
}
break;
@ -227,7 +227,7 @@ void nd4j::IndexReductionLoops<X>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, zLen, canCastZ);
z[zOffset] = indexValue.index;
z[zOffset] = (Z) indexValue.index;
}
}
break;
@ -248,7 +248,7 @@ void nd4j::IndexReductionLoops<X>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
indexValue = OpType::update(indexValue, comp, extraParams);
}
z[i * zEws] = indexValue.index;
z[i * zEws] = (Z) indexValue.index;
}
}
break;
@ -272,18 +272,19 @@ void nd4j::IndexReductionLoops<X>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, zLen, canCastZ);
z[zOffset] = indexValue.index;
z[zOffset] = (Z) indexValue.index;
}
}
}
}
template <typename X>
void nd4j::IndexReductionLoops<X>::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, Nd4jLong* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams) {
template <typename X, typename Y>
void nd4j::IndexReductionLoops<X, Y>::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* vz, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Y *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);
DISPATCH_BY_OPNUM_T(loopIndexReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), INDEX_REDUCE_OPS);
DISPATCH_BY_OPNUM_TT(loopIndexReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), INDEX_REDUCE_OPS);
}
BUILD_SINGLE_TEMPLATE(template void nd4j::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, Nd4jLong* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void nd4j::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES, INDEXING_TYPES);

View File

@ -31,26 +31,27 @@ namespace functions {
namespace indexreduce {
////////////////////////////////////////////////////////////////////////
template <typename X> Nd4jLong IndexReduce<X>::execScalar( const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams), INDEX_REDUCE_OPS);
template <typename X, typename Y>
Nd4jLong IndexReduce<X,Y>::execScalar( const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) {
RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), INDEX_REDUCE_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X>
void IndexReduce<X>::exec(const int opNum,
template <typename X, typename Y>
void IndexReduce<X,Y>::exec(const int opNum,
void *x, Nd4jLong *xShapeInfo,
void *extraParams,
Nd4jLong *z, Nd4jLong *zShapeInfo,
void *z, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X>
template <typename X, typename Y>
template<typename OpType>
Nd4jLong IndexReduce<X>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) {
Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -105,15 +106,16 @@ Nd4jLong IndexReduce<X>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextra
////////////////////////////////////////////////////////////////////////
template <typename X>
template <typename X, typename Z>
template<typename OpType>
void IndexReduce<X>::exec(void *vx, Nd4jLong *xShapeInfo,
void IndexReduce<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo,
void *vextraParams,
Nd4jLong *z, Nd4jLong *zShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams);
const Nd4jLong zLen = shape::length(zShapeInfo);
@ -124,12 +126,12 @@ void IndexReduce<X>::exec(void *vx, Nd4jLong *xShapeInfo,
const auto indexValue = OpType::startingIndexValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(zLen > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < zLen; i++)
z[i] = indexValue.index;;
z[i] = (Z) indexValue.index;;
return;
}
if(shape::isScalar(zShapeInfo)) {
z[0] = execScalar<OpType>(x,xShapeInfo,extraParams);
z[0] = (Z) execScalar<OpType>(x,xShapeInfo,extraParams);
return;
}
@ -146,11 +148,11 @@ void IndexReduce<X>::exec(void *vx, Nd4jLong *xShapeInfo,
tadOffsets = tadPack.primaryOffsets();
}
nd4j::IndexReductionLoops<X>::template loopIndexReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams);
nd4j::IndexReductionLoops<X,Z>::template loopIndexReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams);
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES);
}
}

View File

@ -29,37 +29,37 @@
using namespace simdOps;
template <typename T>
template <typename X, typename Z>
static __global__ void simpleIndexReduceGeneric(const int op,
void *dx,
Nd4jLong *xShapeInfo, int xRank,
void *extraParams,
Nd4jLong *result,
void *result,
Nd4jLong *resultShapeInfo, int zRank,
int *dimension,
int dimensionLength,
int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {
functions::indexreduce::IndexReduce<T>::transform(op,dx,xShapeInfo,extraParams,result,resultShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets);
functions::indexreduce::IndexReduce<X, Z>::transform(op,dx,xShapeInfo,extraParams,result,resultShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets);
}
namespace functions {
namespace indexreduce {
template <typename T>
_CUDA_H void IndexReduce<T>::executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream,
template <typename X, typename Z>
_CUDA_H void IndexReduce<X,Z>::executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream,
const int opNum,
void *dx, Nd4jLong *xShapeInfo,
int xRank,
void *extraParams,
Nd4jLong *result, Nd4jLong *resultShapeInfo,
void *result, Nd4jLong *resultShapeInfo,
int zRank,
int *dimension, int dimensionLength,
int postProcessOrNot,
int *allocationBuffer, void *reductionBuffer,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {
simpleIndexReduceGeneric<T><<<launchDims.x,launchDims.y,launchDims.z, *stream>>>(opNum,
simpleIndexReduceGeneric<X, Z><<<launchDims.x,launchDims.y,launchDims.z, *stream>>>(opNum,
dx, xShapeInfo, xRank,
extraParams,
result, resultShapeInfo, 0,
@ -67,13 +67,11 @@ namespace functions {
1,
allocationBuffer, reductionBuffer,
tadOnlyShapeInfo, tadOffsets);
nd4j::DebugHelper::checkErrorCode(stream, "execIndexReduceScalar(...) failed");
}
template <typename T>
_CUDA_H void IndexReduce<T>::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {
simpleIndexReduceGeneric<T><<<launchDims.x,launchDims.y,launchDims.z, *stream>>>(
template <typename X, typename Z>
_CUDA_H void IndexReduce<X, Z>::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {
simpleIndexReduceGeneric<X, Z><<<launchDims.x,launchDims.y,launchDims.z, *stream>>>(
opNum,
dx,
xShapeInfo, xRank,
@ -83,8 +81,6 @@ namespace functions {
dimension,
dimensionLength,
1, allocationBuffer, reductionBuffer, tadOnlyShapeInfo, tadOffsets);
DEBUG_KERNEL(stream, opNum);
}
// This is the un-specialized struct. Note that we prevent instantiation of this
@ -122,14 +118,14 @@ namespace functions {
}
};
template <typename T>
template <typename X, typename Z>
template <typename OpType>
__device__ void IndexReduce<T>::aggregatePartials(IndexValue<T> **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *vextraParams) {
__device__ void IndexReduce<X, Z>::aggregatePartials(IndexValue<X> **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *vextraParams) {
// start the shared memory loop on the next power of 2 less
// than the block size. If block size is not a power of 2,
// accumulate the intermediate sums in the remainder range.
auto extraParams = static_cast<T*>(vextraParams);
IndexValue<T> *sPartials = *sPartialsRef;
auto extraParams = static_cast<X*>(vextraParams);
IndexValue<X> *sPartials = *sPartialsRef;
Nd4jLong floorPow2 = blockDim.x;
if (floorPow2 & (floorPow2 - 1)) {
@ -138,8 +134,8 @@ namespace functions {
}
if (tid >= floorPow2) {
IndexValue<T> prev = sPartials[tid - floorPow2];
IndexValue<T> curr = sPartials[tid];
IndexValue<X> prev = sPartials[tid - floorPow2];
IndexValue<X> curr = sPartials[tid];
sPartials[tid - floorPow2] = OpType::update(prev,curr,extraParams);
}
__syncthreads();
@ -147,21 +143,21 @@ namespace functions {
for (int activeThreads = floorPow2 >> 1;activeThreads; activeThreads >>= 1) {
if (tid < activeThreads && tid + activeThreads < numElements) {
IndexValue<T> curr = sPartials[tid];
IndexValue<T> next = sPartials[tid + activeThreads];
IndexValue<X> curr = sPartials[tid];
IndexValue<X> next = sPartials[tid + activeThreads];
sPartials[tid] = OpType::update(curr,next,extraParams);
}
__syncthreads();
}
}
template <typename X>
__device__ void IndexReduce<X>::transform(
template <typename X, typename Y>
__device__ void IndexReduce<X, Y>::transform(
const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
Nd4jLong *result,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
@ -170,15 +166,15 @@ namespace functions {
void *reductionBuffer,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset) {
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, result, resultShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, result, resultShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
}
template <typename T>
template <typename X, typename Z>
template <typename OpType>
__device__ void IndexReduce<T>::transform(void *vdx, Nd4jLong *xShapeInfo,
__device__ void IndexReduce<X, Z>::transform(void *vdx, Nd4jLong *xShapeInfo,
void *vextraParams,
Nd4jLong *result, Nd4jLong *resultShapeInfo,
void *vresult, Nd4jLong *resultShapeInfo,
int *dimension, int dimensionLength,
int postProcessOrNot,
int *allocationBuffer, void *vreductionBuffer,
@ -186,18 +182,19 @@ namespace functions {
/**int
* Gpu information for the problem
*/
auto dx = static_cast<T*>(vdx);
auto extraParams = static_cast<T*>(vextraParams);
auto reductionBuffer = static_cast<T*>(vreductionBuffer);
auto dx = reinterpret_cast<X*>(vdx);
auto result = reinterpret_cast<Z*>(vresult);
auto extraParams = static_cast<X*>(vextraParams);
auto reductionBuffer = static_cast<X*>(vreductionBuffer);
auto order = shape::order(xShapeInfo);
int tid = blockIdx.x * blockDim.x + threadIdx.x;
__shared__ volatile int resultScalar;
//shared memory space for storing intermediate results
__shared__ IndexValue<T>* sPartials;
__shared__ IndexValue<X>* sPartials;
if(threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sPartials = reinterpret_cast<IndexValue<T>*>(shmem);
sPartials = reinterpret_cast<IndexValue<X>*>(shmem);
}
__syncthreads();
@ -210,7 +207,7 @@ namespace functions {
//only compute the tad indexes once
IndexValue <T> reduction = OpType::startingIndexValue(dx);
IndexValue<X> reduction = OpType::startingIndexValue(dx);
if (threadIdx.x == 0) {
if (resultShapeInfo != nullptr)
@ -255,7 +252,7 @@ namespace functions {
for(int i = threadIdx.x;i < tadLength; i += blockDim.x) {
auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
IndexValue<T> comp {dx[xOffset], i};
IndexValue<X> comp {dx[xOffset], i};
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams);
}
@ -264,7 +261,7 @@ namespace functions {
__syncthreads();
if (threadIdx.x == 0) {
result[r] = sPartials[threadIdx.x].index;
result[r] = (Z) sPartials[threadIdx.x].index;
}
__syncthreads();
}
@ -276,7 +273,7 @@ namespace functions {
sPartials[threadIdx.x] = OpType::startingIndexValue(dx);
for (int x = threadIdx.x; x < tadLength; x+= blockDim.x) {
IndexValue<T> comp {dx[tadOffsetForBlock + x * tadEWS], x};
IndexValue<X> comp {dx[tadOffsetForBlock + x * tadEWS], x};
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams);
}
@ -285,7 +282,7 @@ namespace functions {
__syncthreads();
if (threadIdx.x == 0) {
result[i] = sPartials[threadIdx.x].index; //postProcess(sPartials[0],tadLength ,extraParams);
result[i] = (Z) sPartials[threadIdx.x].index; //postProcess(sPartials[0],tadLength ,extraParams);
}
__syncthreads();
}
@ -296,14 +293,14 @@ namespace functions {
if(xElementWiseStride >= 1 && order == 'c') {
for(Nd4jLong i = tid;i < n; i += (blockDim.x * gridDim.x)) {
IndexValue <T> indexVal = {dx[i * xElementWiseStride], i};
IndexValue<X> indexVal = {dx[i * xElementWiseStride], i};
reduction = OpType::update(reduction, indexVal, extraParams);
}
} else {
for(Nd4jLong i = tid;i < n; i += blockDim.x * gridDim.x) {
auto offset = shape::getIndexOffset(i, xShapeInfo, n);
IndexValue <T> indexVal = {dx[offset], i};
IndexValue<X> indexVal = {dx[offset], i};
reduction = OpType::update(reduction, indexVal, extraParams);
}
}
@ -320,7 +317,7 @@ namespace functions {
unsigned int *tc = (unsigned int *) reductionBuffer;
tid = threadIdx.x;
if (threadIdx.x == 0) {
auto pBuffer = reinterpret_cast<IndexValue<T> *>(reductionBuffer);
auto pBuffer = reinterpret_cast<IndexValue<X> *>(reductionBuffer);
pBuffer[blockIdx.x] = {sPartials[0].value, sPartials[0].index};
}
__threadfence();
@ -335,7 +332,7 @@ namespace functions {
if (amLast) {
tc[16384] = 0;
IndexValue<T> *pBuffer = (IndexValue<T> *) reductionBuffer;
IndexValue<X> *pBuffer = (IndexValue<X> *) reductionBuffer;
sPartials[threadIdx.x] = OpType::startingIndexValue(dx);
@ -348,14 +345,14 @@ namespace functions {
__syncthreads();
if (tid == 0) {
result[0] = sPartials[0].index;
result[0] = (Z) sPartials[0].index;
}
}
} else {
if (tid == 0) {
auto tc = reinterpret_cast<unsigned int *>(reductionBuffer);
tc[16384] = 0;
result[0] = sPartials[0].index;
result[0] = (Z) sPartials[0].index;
}
}
@ -365,30 +362,30 @@ namespace functions {
template <typename T>
Nd4jLong IndexReduce<T>::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) {
template <typename X, typename Z>
Nd4jLong IndexReduce<X,Z>::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) {
return 0;
}
template <typename T>
void IndexReduce<T>::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
template <typename X, typename Z>
void IndexReduce<X,Z>::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
}
template <typename T>
template <typename X, typename Z>
template<typename OpType>
Nd4jLong IndexReduce<T>:: execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams) {
Nd4jLong IndexReduce<X,Z>:: execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams) {
return 0;
}
template <typename T>
template <typename X, typename Z>
template<typename OpType>
_CUDA_H void IndexReduce<T>::exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
_CUDA_H void IndexReduce<X,Z>::exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES);
}
}

View File

@ -52,35 +52,35 @@
namespace functions {
namespace indexreduce {
template<typename T>
template<typename X, typename Z>
class IndexReduce {
public:
#ifdef __CUDACC__
static __device__ void transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int *dimension,int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset);
static __device__ void transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int *dimension,int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset);
template<typename OpType>
static __device__ void aggregatePartials(IndexValue<T> **sPartialsRef, Nd4jLong tid, Nd4jLong numElements,void *extraParams);
static __device__ void aggregatePartials(IndexValue<X> **sPartialsRef, Nd4jLong tid, Nd4jLong numElements,void *extraParams);
template<typename OpType>
static __device__ void transform(void *dx, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets);
static __device__ void transform(void *dx, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets);
static _CUDA_H void executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets);
static _CUDA_H void executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets);
static _CUDA_H void executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets);
static _CUDA_H void executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets);
#endif
static Nd4jLong execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams);
static void exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset);
static void exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset);
template<typename OpType>
static _CUDA_H Nd4jLong execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams);
template<typename OpType>
static _CUDA_H void exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset);
static _CUDA_H void exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset);
};
}
}

View File

@ -87,7 +87,7 @@ namespace nd4j {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(0, nd4j::DataType::ANY)
->setAllowedOutputTypes(1, {ALL_INTS});
->setAllowedOutputTypes(1, {ALL_INDICES});
}
}
}

View File

@ -42,10 +42,11 @@ namespace helpers {
Nd4jLong listDiffCount(nd4j::LaunchContext * context, NDArray* values, NDArray* keep) {
auto xType = values->dataType();
values->syncToHost();
keep->syncToHost();
NDArray::preparePrimaryUse({},{values, keep});
BUILD_SINGLE_SELECTOR(xType, return listDiffCount_, (values, keep), LIBND4J_TYPES);
NDArray::registerPrimaryUse({},{values, keep});
}
BUILD_SINGLE_TEMPLATE(template Nd4jLong listDiffCount_, (NDArray* values, NDArray* keep);, LIBND4J_TYPES);
@ -97,16 +98,7 @@ namespace helpers {
int listDiffFunctor(nd4j::LaunchContext * context, NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2) {
auto xType = values->dataType();
values->syncToHost();
if (keep != nullptr)
keep->syncToHost();
if (output1 != nullptr)
output1->syncToHost();
if (output2 != nullptr)
output2->syncToHost();
NDArray::preparePrimaryUse({output1, output2}, {values, keep});
int result = 0;
@ -118,14 +110,7 @@ namespace helpers {
throw std::runtime_error("ListDiff: Only integer and floating point data types are supported");
}
if (keep != nullptr)
keep->syncToDevice();
if (output1 != nullptr)
output1->syncToDevice();
if (output2 != nullptr)
output2->syncToDevice();
NDArray::registerPrimaryUse({output1, output2}, {values, keep});
return result;
}
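
With the host/device synchronization above expressed through preparePrimaryUse/registerPrimaryUse, listdiff can be exercised from Java exactly as in the testListDiff case added below. A hedged usage sketch with inputs chosen so the value and index outputs differ (op name and builder calls are taken from that test; the expected results assume TF ListDiff semantics, i.e. the values of x absent from y plus their positions in x):

INDArray x = Nd4j.createFromArray(5, 1, 2, 8);
INDArray y = Nd4j.createFromArray(8, 1);
INDArray out = Nd4j.create(DataType.INT, 2);      // values of x not present in y
INDArray outIdx = Nd4j.create(DataType.INT, 2);   // their positions within x
Nd4j.exec(DynamicCustomOp.builder("listdiff")
        .addInputs(x, y)
        .addOutputs(out, outIdx)
        .build());
// expected: out = [5, 2], outIdx = [0, 2]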

View File

@ -3746,7 +3746,7 @@ namespace simdOps {
};
template <typename X>
template <typename X, typename Z>
class IndexAbsoluteMax {
public:
static _CUDA_HD inline functions::indexreduce::IndexValue<X> op(functions::indexreduce::IndexValue<X> val, X *extraParams) {
@ -3799,7 +3799,7 @@ namespace simdOps {
}
};
template <typename X>
template <typename X, typename Z>
class FirstIndex {
public:
static _CUDA_HD inline functions::indexreduce::IndexValue<X> op(functions::indexreduce::IndexValue<X> val, X *extraParams) {
@ -3861,7 +3861,7 @@ namespace simdOps {
};
template <typename X>
template <typename X, typename Z>
class LastIndex {
public:
static _CUDA_HD inline functions::indexreduce::IndexValue<X> op(functions::indexreduce::IndexValue<X> val, X *extraParams) {
@ -3920,7 +3920,7 @@ namespace simdOps {
};
template <typename X>
template <typename X, typename Z>
class IndexMax {
public:
@ -3974,7 +3974,7 @@ namespace simdOps {
};
template <typename X>
template <typename X, typename Z>
class IndexAbsoluteMin {
public:
static _CUDA_HD inline functions::indexreduce::IndexValue<X> op(
@ -4030,7 +4030,7 @@ namespace simdOps {
};
template <typename X>
template <typename X, typename Z>
class IndexMin {
public:
static _CUDA_HD inline functions::indexreduce::IndexValue<X> op(

View File

@ -226,6 +226,21 @@ public class CudaExecutioner extends DefaultOpExecutioner {
*/
protected INDArray naiveExec(ReduceOp op, int... dimension) {
long st = profilingConfigurableHookIn(op);
if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){
//Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y]
//Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions"
if(op.z() != null){
Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." +
" Got: x=%ndShape, z=%ndShape", op.x(), op.z());
op.z().assign(op.x());
return op.z();
} else {
op.setZ(op.x().dup());
return op.z();
}
}
INDArray ret = op.z();
checkForCompression(op);
@ -482,6 +497,20 @@ public class CudaExecutioner extends DefaultOpExecutioner {
public INDArray exec(ReduceOp op) {
checkForCompression(op);
if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){
//Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y]
//Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions"
if(op.z() != null){
Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." +
" Got: x=%ndShape, z=%ndShape", op.x(), op.z());
op.z().assign(op.x());
return op.z();
} else {
op.setZ(op.x().dup());
return op.z();
}
}
val dimension = op.dimensions().toIntVector();
if (extraz.get() == null)
@ -890,6 +919,22 @@ public class CudaExecutioner extends DefaultOpExecutioner {
protected CudaContext invoke(ReduceOp op, int[] dimension) {
CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){
//Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y]
//Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions"
if(op.z() != null){
Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." +
" Got: x=%ndShape, z=%ndShape", op.x(), op.z());
op.z().assign(op.x());
return context;
} else {
op.setZ(op.x().dup());
return context;
}
}
long st = profilingConfigurableHookIn(op);
checkForCompression(op);
@ -913,8 +958,6 @@ public class CudaExecutioner extends DefaultOpExecutioner {
throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension)
+ " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
lastOp.set(op.opName());
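
The empty-reduce branch added to naiveExec, exec(ReduceOp) and invoke(...) above separates a TF-style empty axis tensor (a no-op that preserves the input shape) from a zero-length int[] axis argument, which still means a full reduction. A hedged sketch of the two behaviours, reusing the All op and setEmptyReduce exactly as in the testAllEmptyReduce case added below (imports as in that test):

INDArray m = Nd4j.createFromArray(1.0, 2.0, 3.0, 4.0);
INDArray full = m.sum(new int[0]);        // length-0 axis array: still reduces all dimensions -> 10.0

INDArray x = Nd4j.createFromArray(true, false, true);
All emptyAxis = new All(x);
emptyAxis.setEmptyReduce(true);           // TF-import case: empty axis tensor means identity
INDArray same = Nd4j.exec(emptyAxis);     // same shape and values as x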

View File

@ -733,6 +733,46 @@ public class CustomOpsTests extends BaseNd4jTest {
}
}
@Test
public void testListDiff(){
INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
INDArray y = Nd4j.createFromArray(3, 1);
INDArray out = Nd4j.create(DataType.INT, 2);
INDArray outIdx = Nd4j.create(DataType.INT, 2);
Nd4j.exec(DynamicCustomOp.builder("listdiff")
.addInputs(x, y)
.addOutputs(out, outIdx)
.build());
INDArray exp = Nd4j.createFromArray(0, 2);
assertEquals(exp, out); //Values in x not in y
assertEquals(exp, outIdx); //Indices of the values in x not in y
}
@Test
public void testTopK1(){
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
INDArray k = Nd4j.scalar(1);
INDArray outValue = Nd4j.create(DataType.DOUBLE, 1);
INDArray outIdx = Nd4j.create(DataType.INT, 1);
Nd4j.exec(DynamicCustomOp.builder("top_k")
.addInputs(x, k)
.addOutputs(outValue, outIdx)
.addBooleanArguments(false) //not sorted
.addIntegerArguments(1)
.build());
INDArray expValue = Nd4j.createFromArray(10.0);
INDArray expIdx = Nd4j.createFromArray(3);
assertEquals(expValue, outValue);
assertEquals(expIdx, outIdx);
}
@Test
public void testMaxPool2Dbp_1() {
val x = Nd4j.create(DataType.HALF, 2,3,16,16).assign(Double.NaN);

View File

@ -25,6 +25,7 @@ import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -299,6 +300,15 @@ public class EmptyTests extends BaseNd4jTest {
assertNotNull(result[0].shapeInfoDataBuffer().asLong());
}
@Test
public void testAllEmptyReduce(){
INDArray x = Nd4j.createFromArray(true, true, true);
val all = new All(x);
all.setEmptyReduce(true); //For TF compatibility - empty array for axis (which means no-op - and NOT all array reduction)
INDArray out = Nd4j.exec(all);
assertEquals(x, out);
}
@Override
public char ordering() {
return 'c';