[WIP] More fixes (#190)
* Refactored kernels for segment_max/min/sum ops.
* Refactored segment_prod kernels.
* Refactored segment_prod kernels.
* DynamicPartition test
  Signed-off-by: raver119 <raver119@gmail.com>
* Added linear test for dynamic_partition op.
* Refactored test with int datatype.
* some logging
  Signed-off-by: raver119 <raver119@gmail.com>
* some logging
  Signed-off-by: raver119 <raver119@gmail.com>
* some logging
  Signed-off-by: raver119 <raver119@gmail.com>
* dynamicPartition fix
  Signed-off-by: raver119 <raver119@gmail.com>
* get rid of some logging
  Signed-off-by: raver119 <raver119@gmail.com>
* one more test for dynamic_stitch
  Signed-off-by: raver119 <raver119@gmail.com>
* one more test for dynamic_stitch
  Signed-off-by: raver119 <raver119@gmail.com>
* empty check for stitch
  Signed-off-by: raver119 <raver119@gmail.com>
* minor print changes
  Signed-off-by: raver119 <raver119@gmail.com>
parent 3157ec110c
commit f4860574d7
@@ -50,8 +50,8 @@ namespace nd4j {
         auto zShapeInfo = zShapeInfos[o];
         auto zLength = shape::length(zShapeInfo);
 
-        // iLimit should be
-        auto iLimit = iLength <= blockIdx.x ? blockIdx.x : (iLength + (blockIdx.x - (iLength % blockIdx.x)));
+        // iLimit should be multiple of blockDim.x
+        auto iLimit = iLength <= blockDim.x ? blockDim.x : (iLength + (blockDim.x - (iLength % blockDim.x)));
         int cnt = 0;
 
         for (Nd4jLong e = threadIdx.x; e < iLimit; e += blockDim.x) {
@@ -75,8 +75,9 @@ namespace nd4j {
 
             // doing actual update
             if (e < iLength)
-                if (trueIndices[threadIdx.x] >= 0)
+                if (trueIndices[threadIdx.x] >= 0) {
                     z[trueIndices[threadIdx.x]] = x[shape::getIndexOffset(e, xShapeInfo, xLength)];
+                }
 
             __syncthreads();
         }
@@ -148,13 +149,12 @@ namespace nd4j {
         auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *)));
         auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *)));
 
-        dynamicPartitionTadKernel<X,Y><<<256, 512, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);
+        dynamicPartitionTadKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);
 
     } else {
         auto numThreads = 256;
         auto shmemSize = numThreads * sizeof(Y) * 2 + 1024;
 
-
         std::vector<void *> outBuffers;
         std::vector<Nd4jLong *> outShapes;
 
@@ -203,6 +203,9 @@ namespace nd4j {
         auto indices = reinterpret_cast<Y*>(vindices[e]);
         auto iShapeInfo = iShapeInfos[e];
 
+        if (shape::isEmpty(iShapeInfo))
+            continue;
+
         auto iLength = shape::length(iShapeInfo);
         auto zLength = shape::length(zTadShapeInfo);
 
@@ -310,8 +313,9 @@ namespace nd4j {
 
         NDArray::registerSpecialUse({}, {indices, input});
 
-        for (auto v:outputList)
+        for (auto v:outputList) {
             v->tickWriteDevice();
+        }
     }
 
     template <typename T>
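The iLimit fix in the first hunk swaps blockIdx.x (the block's index in the grid) for blockDim.x (the number of threads per block): the intent is to round iLength up to the next multiple of the block size, so every thread in the block runs the same number of loop iterations and the __syncthreads() inside the loop is reached uniformly. A minimal standalone sketch of that pattern, with a hypothetical kernel name and signature (not the library's dynamicPartitionTadKernel):

    // Round the loop bound up to a multiple of blockDim.x so the in-loop
    // __syncthreads() is executed by every thread the same number of times.
    __global__ void roundedLoopSketch(const int* in, int* out, long length) {
        // smallest multiple of blockDim.x that is >= length (and at least blockDim.x)
        long limit = length <= blockDim.x
                         ? blockDim.x
                         : (length + (blockDim.x - (length % blockDim.x)));

        for (long e = threadIdx.x; e < limit; e += blockDim.x) {
            if (e < length)            // only valid indices do real work
                out[e] = in[e];
            __syncthreads();           // reached by all threads on every iteration
        }
    }
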
@@ -40,19 +40,16 @@ namespace nd4j {
        static __global__ void
        segmentMaxLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
                               void *output, Nd4jLong *outputShape) {
-           __shared__
-           T *val;
-           __shared__
-           Nd4jLong xLen, zLen, segment, zIndex;
-           __shared__
-           T *x;
-           __shared__
-           T *z;
+           __shared__ T *val;
+           __shared__ Nd4jLong xLen, zLen, zIndex;
+           __shared__ T *x;
+           __shared__ T *z;
            __shared__ int threadsPerSegment, start, finish;
 
+           auto segment = blockIdx.x;
            if (threadIdx.x == 0) {
-               threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
-               segment = blockIdx.x / threadsPerSegment;
+               // threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
+               // segment = blockIdx.x / threadsPerSegment;
                x = reinterpret_cast<T *>(input);
                z = reinterpret_cast<T *>(output);
                extern __shared__ unsigned char shmem[];
@@ -83,19 +80,14 @@ namespace nd4j {
        unsortedSegmentMaxLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape,
                                       int *starts, int *lengths, Nd4jLong numOfClasses, void *output,
                                       Nd4jLong *outputShape) {
-           __shared__
-           T *val;
-           __shared__
-           Nd4jLong xLen, zLen, segment, zIndex;
-           __shared__
-           T *x;
-           __shared__
-           T *z;
-           __shared__
-           I *y; //int threadsPerSegment, start, finish;
+           __shared__ T *val;
+           __shared__ Nd4jLong xLen, zLen, zIndex;
+           __shared__ T *x;
+           __shared__ T *z;
+           __shared__ I *y; //int threadsPerSegment, start, finish;
+           auto segment = blockIdx.x;
 
            if (threadIdx.x == 0) {
-               segment = blockIdx.x;
                x = reinterpret_cast<T *>(input);
                z = reinterpret_cast<T *>(output);
                y = reinterpret_cast<I *>(indices);
@@ -127,9 +119,10 @@ namespace nd4j {
                             Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets, T filler = 0) {
 
           __shared__ T* val;
-          __shared__ Nd4jLong len, segment, zIndex, total;
+          __shared__ Nd4jLong len, zIndex, total;
           __shared__ T* z;
           __shared__ int start, finish;
+          __shared__ I segment;
 
           if (threadIdx.x == 0) {
               segment = indices[blockIdx.x]; // / threadsPerSegment;
@@ -143,20 +136,22 @@ namespace nd4j {
           __syncthreads();
 
           auto idx = blockIdx.x;
-          if (blockIdx.x <= total) {
+          if (idx <= total) {
              auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
              if (blockIdx.x == start) {
                 for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                     auto xIndex = shape::getIndexOffset(e, inputTads, len);
                     auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                    z[zIndex] = x[xIndex];
+                    nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
+                    //z[zIndex] = x[xIndex];
                 }
              }
              else {
                 for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                     auto xIndex = shape::getIndexOffset(e, inputTads, len);
                     auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                    nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
+                    if (lengths[segment])
+                        nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
                 }
              }
           }
@@ -168,6 +163,7 @@ namespace nd4j {
            //int numClasses = output->sizeAt(0);
            // if input is a vector: (as if in doc sample)
            //Nd4jLong idx = indices->e<Nd4jLong>(0);
+           output->assign(-DataTypeUtils::infOrMax<T>());
            auto stream = context->getCudaStream();
            indices->syncToHost();
            Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
@@ -211,6 +207,8 @@ namespace nd4j {
        static void unsortedSegmentMaxFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
            auto stream = context->getCudaStream();
            // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
+           output->assign(DataTypeUtils::infOrMax<T>());
+
            NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
            NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
            // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
@@ -243,6 +241,7 @@ namespace nd4j {
        // -------------------------------------------------------------------------------------------------------------- //
        void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
            NDArray::prepareSpecialUse({output}, {input, indices});
+           output->nullify();
            BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES);
            NDArray::registerSpecialUse({output}, {input, indices});
        }
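The segment kernels above replace plain stores (z[zIndex] = x[xIndex]) with atomic updates, which is only correct if the output already holds the reduction identity before the kernel runs; that is what the added output->assign(-DataTypeUtils::infOrMax<T>()) and output->nullify() calls provide on the host side. A minimal sketch of the same idea in plain CUDA, assuming int data and a caller that has pre-filled z with INT_MIN (hypothetical kernel, not the library's segmentMaxTadKernel):

    // Segment max via atomics: z[s] must be pre-filled with INT_MIN by the caller,
    // then any number of threads/blocks can fold elements into their segment safely.
    __global__ void segmentMaxSketch(const int* x, const int* segmentIds,
                                     long length, int* z) {
        // grid-stride loop over all input elements
        for (long e = blockIdx.x * blockDim.x + threadIdx.x; e < length;
             e += (long) gridDim.x * blockDim.x) {
            atomicMax(&z[segmentIds[e]], x[e]);   // native int overload of atomicMax
        }
    }
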
@@ -38,19 +38,16 @@ namespace helpers {
        static __global__ void
        segmentMinLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
                               void *output, Nd4jLong *outputShape) {
-           __shared__
-           T *val;
-           __shared__
-           Nd4jLong xLen, zLen, segment, zIndex;
-           __shared__
-           T *x;
-           __shared__
-           T *z;
+           __shared__ T *val;
+           __shared__ Nd4jLong xLen, zLen, zIndex;
+           __shared__ T *x;
+           __shared__ T *z;
            __shared__ int threadsPerSegment, start, finish;
 
+           auto segment = blockIdx.x;
            if (threadIdx.x == 0) {
-               threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
-               segment = blockIdx.x / threadsPerSegment;
+               // threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
+               // segment = blockIdx.x / threadsPerSegment;
                x = reinterpret_cast<T *>(input);
                z = reinterpret_cast<T *>(output);
                extern __shared__ unsigned char shmem[];
@@ -123,12 +120,12 @@ namespace helpers {
        template <typename T, typename I>
        static __global__ void segmentMinTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
            __shared__ T* val;
-           __shared__ Nd4jLong len, segment, zIndex, total;
+           __shared__ Nd4jLong len, zIndex, total;
            __shared__ T* z;
            __shared__ int threadsPerSegment, start, finish;
 
+           auto segment = indices[blockIdx.x]; // / threadsPerSegment;
            if (threadIdx.x == 0) {
-               segment = indices[blockIdx.x]; // / threadsPerSegment;
                z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
                len = shape::length(inputTads);
                start = starts[segment];
@@ -145,14 +142,15 @@ namespace helpers {
                for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                    auto xIndex = shape::getIndexOffset(e, inputTads, len);
                    auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                   z[zIndex] = x[xIndex];
+                   nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
                }
            }
            else {
                for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                    auto xIndex = shape::getIndexOffset(e, inputTads, len);
                    auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                   nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
+                   // if (lengths[indices[idx]])
+                   nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
                }
            }
        }
@@ -165,7 +163,7 @@ namespace helpers {
            Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
            NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
            NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
+           output->assign(DataTypeUtils::infOrMax<T>());
            classesRangesBegs.assign(indices->lengthOf());
            classesRangesLens.assign(0);
 
@@ -193,6 +191,7 @@ namespace helpers {
        // -------------------------------------------------------------------------------------------------------------- //
        void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
            NDArray::prepareSpecialUse({output}, {input, indices});
+           output->nullify();
            BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
            NDArray::registerSpecialUse({output}, {input, indices});
        }
@@ -207,6 +206,7 @@ namespace helpers {
            NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
            // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
            // classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
+           output->assign(DataTypeUtils::infOrMax<T>());
            classesRangesBegs.assign(indices->lengthOf());
            classesRangesLens.assign(0);
            dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
@@ -236,6 +236,7 @@ namespace helpers {
        // -------------------------------------------------------------------------------------------------------------- //
        void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
            NDArray::prepareSpecialUse({output}, {input, indices});
+           output->nullify();
            BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output),
                                  NUMERIC_TYPES, INDEXING_TYPES);
            NDArray::registerSpecialUse({output}, {input, indices});
@@ -146,14 +146,15 @@ namespace helpers {
                for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                    auto xIndex = shape::getIndexOffset(e, inputTads, len);
                    auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                   z[zIndex] = x[xIndex];
+                   nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
                }
            }
            else {
                for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                    auto xIndex = shape::getIndexOffset(e, inputTads, len);
                    auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                   nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
+                   if (lengths[segment] > 0)
+                       nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
                }
            }
        }
@@ -166,7 +167,7 @@ namespace helpers {
            Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
            NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
            NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
+           output->assign(1);
            classesRangesBegs.assign(indices->lengthOf());
            classesRangesLens.assign(0);
 
@@ -373,6 +374,7 @@ namespace helpers {
        template <typename T, typename I>
        static int unsortedSegmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
            auto stream = context->getCudaStream();
+
            NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
            unsortedSegmentProdFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
            NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
@@ -121,12 +121,12 @@ namespace helpers {
        template <typename T, typename I>
        static __global__ void segmentSumTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
            __shared__ T* val;
-           __shared__ Nd4jLong len, segment, zIndex, total;
+           __shared__ Nd4jLong len, zIndex, total;
            __shared__ T* z;
-           __shared__ int threadsPerSegment, start, finish;
+           __shared__ int start, finish;
 
            if (threadIdx.x == 0) {
-               segment = indices[blockIdx.x]; // / threadsPerSegment;
+               auto segment = indices[blockIdx.x]; // / threadsPerSegment;
                z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
                len = shape::length(inputTads);
                start = starts[segment];
@@ -143,14 +143,14 @@ namespace helpers {
                for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                    auto xIndex = shape::getIndexOffset(e, inputTads, len);
                    auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                   z[zIndex] = x[xIndex];
+                   nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
                }
            }
            else {
                for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                    auto xIndex = shape::getIndexOffset(e, inputTads, len);
                    auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                   if (lengths[segment])
+                   if (lengths[indices[idx]])
                        nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
                }
            }
@@ -191,6 +191,7 @@ namespace helpers {
        // -------------------------------------------------------------------------------------------------------------- //
        void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
            NDArray::prepareSpecialUse({output}, {input, indices});
+           output->nullify();
            BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
            NDArray::registerSpecialUse({output}, {input, indices});
        }
@@ -232,6 +233,7 @@ namespace helpers {
        // -------------------------------------------------------------------------------------------------------------- //
        void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
            NDArray::prepareSpecialUse({output}, {input, indices});
+           output->nullify();
            BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output),
                                  NUMERIC_TYPES, INDEXING_TYPES);
            NDArray::registerSpecialUse({output}, {input, indices});
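The wrapper functors above all follow the same host-side sequence: prepareSpecialUse to sync and mark the device buffers, pre-fill the output with the reduction identity (nullify for sum, assign(1) for prod, ±infOrMax for max/min), dispatch on the (data type, index type) pair, then registerSpecialUse. A sketch of that shape using the calls as they appear in this diff; the wrapper name is hypothetical, and the library's headers, NDArray API and BUILD_DOUBLE_SELECTOR macro are assumed to be available:

    // Host-side wrapper pattern shared by the segment functors in this change (sketch only).
    void segmentSumFunctorSketch(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
        // sync inputs to device and mark the output as about to be written there
        NDArray::prepareSpecialUse({output}, {input, indices});

        // start from the reduction identity (0 for sum) so the atomic adds compose correctly
        output->nullify();

        // dispatch by (numeric type, index type), as the real wrappers do
        BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_,
                              (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);

        // record that the device copy of the output is now the current one
        NDArray::registerSpecialUse({output}, {input, indices});
    }
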
@@ -602,6 +602,8 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArr
    // NDArray VT = outArrs[2]->transpose();
    NDArray* V = outArrs[2];
 
+   NDArray::prepareSpecialUse({S, U, V}, {x});
+
    if(x->rankOf() == 2) {
        // svdQR(context, x, S, U, VT, fullUV, calcUV);
        svdJcb(context, x, S, U, V, fullUV, calcUV);
@@ -631,6 +633,8 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArr
            delete tadsV;
        }
    }
 
+   NDArray::registerSpecialUse({S, U, V}, {x});
+
 }
 
@@ -1920,6 +1920,50 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) {
 
     delete result;
 }
+/* @Test
+   public void testDynamicPartition(){
+       INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
+       INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
+       INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
+               .addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1))
+               .addIntegerArguments(3) //3 partitions
+               .addInputs(data, partitions).build());
+
+       INDArray exp0 = Nd4j.createFromArray(2, 0);
+       INDArray exp1 = Nd4j.createFromArray(2);
+       INDArray exp2 = Nd4j.createFromArray(1);
+
+       assertEquals(exp0, out[0]); //Usually just gives [0,0]
+       assertEquals(exp1, out[1]);
+       assertEquals(exp2, out[2]);
+   }*/
+TEST_F(DeclarableOpsTests5, DynamicPartition_01) {
+
+    auto x = NDArrayFactory::create<int>({2,1,2,0});
+
+    auto y = NDArrayFactory::create<int>({0,2,1,0});
+
+    int numPartition = 3;
+    std::vector<NDArray> exp( { NDArrayFactory::create<int>('c', {2}, {2, 0}),
+                                NDArrayFactory::create<int>('c', {1}, {2}),
+                                NDArrayFactory::create<int>('c', {1}, {1})});
+
+    nd4j::ops::dynamic_partition op;
+    auto result = op.execute({&x, &y}, {}, {numPartition});
+
+    ASSERT_EQ(ND4J_STATUS_OK, result->status());
+    ASSERT_EQ(result->size(), numPartition); // result has the same size as given param 4
+
+    for (int e = 0; e < result->size(); e++) {
+        auto output = result->at(e);
+        // output->printShapeInfo("Output shape> ");
+        // output->printIndexedBuffer("Output data> ");
+        ASSERT_TRUE(exp[e].isSameShape(output));
+        ASSERT_TRUE(exp[e].equalsTo(output));
+    }
+
+    delete result;
+}
 
 TEST_F(DeclarableOpsTests5, DynamicPartition_1) {
 
@@ -2031,6 +2075,38 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_3) {
     delete result;
 }
+
+TEST_F(DeclarableOpsTests5, DynamicStitch_empty_1) {
+    auto i0 = NDArrayFactory::create<int>('c', {2}, {2, 3});
+    auto i1 = NDArrayFactory::empty<int>();
+    auto i2 = NDArrayFactory::create<int>('c', {2}, {0, 1});
+
+    auto d0 = NDArrayFactory::create<double>('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126});
+    auto d1 = NDArrayFactory::empty<double>();
+    auto d2 = NDArrayFactory::create<double>('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326});
+
+    nd4j::ops::dynamic_stitch op;
+    auto result = op.execute({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    delete result;
+}
+
+TEST_F(DeclarableOpsTests5, DynamicStitch_empty_2) {
+    auto i0 = NDArrayFactory::create<int>('c', {2}, {2, 3});
+    auto i1 = NDArrayFactory::create<int>('c', {0});
+    auto i2 = NDArrayFactory::create<int>('c', {2}, {0, 1});
+
+    auto d0 = NDArrayFactory::create<double>('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126});
+    auto d1 = NDArrayFactory::create<double>('c', {0, 5});
+    auto d2 = NDArrayFactory::create<double>('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326});
+
+    nd4j::ops::dynamic_stitch op;
+    auto result = op.execute({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    delete result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 TEST_F(DeclarableOpsTests5, DynamicStitch_1) {
 
@@ -1728,6 +1728,24 @@ public class MiscOpValidation extends BaseOpValidation {
        assertEquals(exp, out);
    }
 
+    @Test
+    public void testDynamicPartition(){
+        INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
+        INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
+        INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
+                .addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1))
+                .addIntegerArguments(3) //3 partitions
+                .addInputs(data, partitions).build());
+
+        INDArray exp0 = Nd4j.createFromArray(2, 0);
+        INDArray exp1 = Nd4j.createFromArray(2);
+        INDArray exp2 = Nd4j.createFromArray(1);
+
+        assertEquals(exp0, out[0]); //Usually just gives [0,0]
+        assertEquals(exp1, out[1]);
+        assertEquals(exp2, out[2]);
+    }
+
    @Test
    public void testListDiff(){
        INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
 
@@ -344,6 +344,8 @@ public class TFGraphTestAllHelper {
                    System.out.println("Pass: " + varName);
                } else {
                    System.out.println("FAIL: " + varName);
+                   System.out.println("TF:\n" + tfValue);
+                   System.out.println("SD:\n" + sdVal);
                }
 
            }
 
@@ -180,8 +180,8 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
        Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
 
        try {
-            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH,
-                    TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+            //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
        } catch (Throwable t){
            log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);
            throw t;
 