[WIP] More fixes (#190)
* Refactored kernels for segment_max/min/sum ops.
* Refactored segment_prod kernels.
* Refactored segment_prod kernels.
* DynamicPartition test
  Signed-off-by: raver119 <raver119@gmail.com>
* Added linear test for dynamic_partition op.
* Refactored test with int datatype.
* some logging
  Signed-off-by: raver119 <raver119@gmail.com>
* some logging
  Signed-off-by: raver119 <raver119@gmail.com>
* some logging
  Signed-off-by: raver119 <raver119@gmail.com>
* dynamicPartition fix
  Signed-off-by: raver119 <raver119@gmail.com>
* get rid of some logging
  Signed-off-by: raver119 <raver119@gmail.com>
* one more test for dynamic_stitch
  Signed-off-by: raver119 <raver119@gmail.com>
* one more test for dynamic_stitch
  Signed-off-by: raver119 <raver119@gmail.com>
* empty check for stitch
  Signed-off-by: raver119 <raver119@gmail.com>
* minor print changes
  Signed-off-by: raver119 <raver119@gmail.com>
parent 3157ec110c
commit f4860574d7
@@ -50,8 +50,8 @@ namespace nd4j {
auto zShapeInfo = zShapeInfos[o];
auto zLength = shape::length(zShapeInfo);

// iLimit should be
auto iLimit = iLength <= blockIdx.x ? blockIdx.x : (iLength + (blockIdx.x - (iLength % blockIdx.x)));
// iLimit should be multiple of blockDim.x
auto iLimit = iLength <= blockDim.x ? blockDim.x : (iLength + (blockDim.x - (iLength % blockDim.x)));
int cnt = 0;

for (Nd4jLong e = threadIdx.x; e < iLimit; e += blockDim.x) {
@@ -75,8 +75,9 @@ namespace nd4j {

// doing actual update
if (e < iLength)
if (trueIndices[threadIdx.x] >= 0)
if (trueIndices[threadIdx.x] >= 0) {
z[trueIndices[threadIdx.x]] = x[shape::getIndexOffset(e, xShapeInfo, xLength)];
}

__syncthreads();
}
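Note on the two hunks above: rounding iLimit up to a multiple of blockDim.x (and using blockDim.x rather than blockIdx.x) makes every thread run the same number of loop iterations, so the trailing __syncthreads() is reached uniformly; the out-of-range iterations are then masked by the new `if (e < iLength)` guard. A minimal standalone sketch of that idiom (helper name is illustrative, not from the codebase):

    // Pad the iteration limit up to the next multiple of blockDim.x so a
    // block-strided loop performs the same number of passes in every thread.
    __device__ inline Nd4jLong roundUpToBlockDim(Nd4jLong length) {
        return length <= blockDim.x ? blockDim.x
                                    : length + (blockDim.x - (length % blockDim.x));
    }

    // Usage skeleton inside a kernel:
    //   auto iLimit = roundUpToBlockDim(iLength);
    //   for (Nd4jLong e = threadIdx.x; e < iLimit; e += blockDim.x) {
    //       if (e < iLength) { /* real work only for in-range elements */ }
    //       __syncthreads();  // safe: every thread reaches this the same number of times
    //   }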
@@ -148,13 +149,12 @@ namespace nd4j {
auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *)));
auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *)));

dynamicPartitionTadKernel<X,Y><<<256, 512, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);
dynamicPartitionTadKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);

} else {
auto numThreads = 256;
auto shmemSize = numThreads * sizeof(Y) * 2 + 1024;

std::vector<void *> outBuffers;
std::vector<Nd4jLong *> outShapes;
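The TAD path above now launches with 256 threads per block (down from 512) while keeping 1024 bytes of dynamic shared memory, and the linear path sizes its scratch from the thread count (numThreads * sizeof(Y) * 2 + 1024). A self-contained toy example of how the third launch parameter backs `extern __shared__` scratch; the kernel and sizes here are illustrative only:

    #include <cstdio>

    // Toy kernel: the dynamic shared memory requested at launch time backs the
    // unsized extern __shared__ array declared inside the kernel.
    __global__ void shmemDemo(int bytesAvailable) {
        extern __shared__ unsigned char scratch[];
        if (blockIdx.x == 0 && threadIdx.x == 0) {
            scratch[0] = 0;  // touch the scratch so it is actually used
            printf("dynamic shared memory per block: %d bytes\n", bytesAvailable);
        }
    }

    int main() {
        const int numBlocks    = 256;
        const int numThreads   = 256;                                     // mirrors the refactored launch
        const size_t shmemSize = numThreads * sizeof(float) * 2 + 1024;   // per-block scratch, as above
        shmemDemo<<<numBlocks, numThreads, shmemSize>>>(static_cast<int>(shmemSize));
        cudaDeviceSynchronize();
        return 0;
    }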
@@ -203,6 +203,9 @@ namespace nd4j {
auto indices = reinterpret_cast<Y*>(vindices[e]);
auto iShapeInfo = iShapeInfos[e];

if (shape::isEmpty(iShapeInfo))
continue;

auto iLength = shape::length(iShapeInfo);
auto zLength = shape::length(zTadShapeInfo);
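The added `shape::isEmpty` check is the "empty check for stitch" from the commit message: an (indices, data) pair backed by an empty array is skipped instead of being indexed into. A host-side analogue using only std:: types (not the nd4j API), just to illustrate the control flow:

    #include <cstdio>
    #include <vector>

    int main() {
        // Three (indices, data) pairs; the middle one is empty and must be skipped.
        std::vector<std::vector<int>>   indices = {{2, 3}, {}, {0, 1}};
        std::vector<std::vector<float>> data    = {{10.f, 11.f}, {}, {20.f, 21.f}};
        std::vector<float> out(4, 0.f);

        for (size_t e = 0; e < indices.size(); e++) {
            if (indices[e].empty())
                continue;                          // same idea as shape::isEmpty(iShapeInfo)
            for (size_t i = 0; i < indices[e].size(); i++)
                out[indices[e][i]] = data[e][i];   // stitch this pair into the output
        }
        for (float v : out) std::printf("%g ", v); // prints: 20 21 10 11
        std::printf("\n");
        return 0;
    }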
@@ -310,9 +313,10 @@ namespace nd4j {

NDArray::registerSpecialUse({}, {indices, input});

for (auto v:outputList)
for (auto v:outputList) {
v->tickWriteDevice();
}
}

template <typename T>
static int _dynamicStitchFunctorBP(std::vector<NDArray*> const& inputs, std::vector<NDArray*> const& indices, NDArray const* gradInput, std::vector<NDArray*>& outputList){
@@ -40,19 +40,16 @@ namespace nd4j {
static __global__ void
segmentMaxLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
void *output, Nd4jLong *outputShape) {
__shared__
T *val;
__shared__
Nd4jLong xLen, zLen, segment, zIndex;
__shared__
T *x;
__shared__
T *z;
__shared__ T *val;
__shared__ Nd4jLong xLen, zLen, zIndex;
__shared__ T *x;
__shared__ T *z;
__shared__ int threadsPerSegment, start, finish;

auto segment = blockIdx.x;
if (threadIdx.x == 0) {
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x / threadsPerSegment;
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
// segment = blockIdx.x / threadsPerSegment;
x = reinterpret_cast<T *>(input);
z = reinterpret_cast<T *>(output);
extern __shared__ unsigned char shmem[];
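In the refactored linear kernel each block now simply takes segment = blockIdx.x instead of deriving it from a threadsPerSegment ratio. A minimal sketch of that one-block-per-segment mapping (int-only and illustrative; the real kernel handles arbitrary types and shape info):

    // Grid is launched with gridDim.x == numOfClasses; each block reduces the
    // contiguous range [starts[seg], starts[seg] + lengths[seg]) of the input.
    __global__ void perSegmentMaxSketch(const int* x, const int* starts, const int* lengths, int* z) {
        const int segment = blockIdx.x;
        const int start   = starts[segment];
        const int len     = lengths[segment];
        for (int e = threadIdx.x; e < len; e += blockDim.x)
            atomicMax(&z[segment], x[start + e]);   // native int atomicMax; z pre-filled with INT_MIN
    }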
@@ -83,19 +80,14 @@ namespace nd4j {
unsortedSegmentMaxLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape,
int *starts, int *lengths, Nd4jLong numOfClasses, void *output,
Nd4jLong *outputShape) {
__shared__
T *val;
__shared__
Nd4jLong xLen, zLen, segment, zIndex;
__shared__
T *x;
__shared__
T *z;
__shared__
I *y; //int threadsPerSegment, start, finish;
__shared__ T *val;
__shared__ Nd4jLong xLen, zLen, zIndex;
__shared__ T *x;
__shared__ T *z;
__shared__ I *y; //int threadsPerSegment, start, finish;
auto segment = blockIdx.x;

if (threadIdx.x == 0) {
segment = blockIdx.x;
x = reinterpret_cast<T *>(input);
z = reinterpret_cast<T *>(output);
y = reinterpret_cast<I *>(indices);
@@ -127,9 +119,10 @@ namespace nd4j {
Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets, T filler = 0) {

__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total;
__shared__ Nd4jLong len, zIndex, total;
__shared__ T* z;
__shared__ int start, finish;
__shared__ I segment;

if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
@@ -143,19 +136,21 @@ namespace nd4j {
__syncthreads();

auto idx = blockIdx.x;
if (blockIdx.x <= total) {
if (idx <= total) {
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
if (blockIdx.x == start) {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex];
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
//z[zIndex] = x[xIndex];
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
if (lengths[segment])
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
}
}
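The TAD kernels above replace the plain store z[zIndex] = x[xIndex] with nd4j_atomicMax (and the min/sum/prod kernels further down get the matching atomicMin/atomicAdd/atomicMul), since several blocks can now write into the same output TAD concurrently. For floating-point types with no native atomic max, one common construction is a compare-and-swap loop over the raw bits; this is only an illustration, nd4j ships its own nd4j_atomicMax:

    // Illustrative CAS-based atomic max for float (not the nd4j implementation):
    // retry until our candidate maximum survives a compare-and-swap.
    __device__ float atomicMaxFloat(float* address, float val) {
        int* addressAsInt = reinterpret_cast<int*>(address);
        int old = *addressAsInt, assumed;
        do {
            assumed = old;
            old = atomicCAS(addressAsInt, assumed,
                            __float_as_int(fmaxf(val, __int_as_float(assumed))));
        } while (assumed != old);
        return __int_as_float(old);
    }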
@@ -168,6 +163,7 @@ namespace nd4j {
//int numClasses = output->sizeAt(0);
// if input is a vector: (as if in doc sample)
//Nd4jLong idx = indices->e<Nd4jLong>(0);
output->assign(-DataTypeUtils::infOrMax<T>());
auto stream = context->getCudaStream();
indices->syncToHost();
Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
@@ -211,6 +207,8 @@ namespace nd4j {
static void unsortedSegmentMaxFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
output->assign(DataTypeUtils::infOrMax<T>());

NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
@@ -243,6 +241,7 @@ namespace nd4j {
// -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
output->nullify();
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
}
@@ -38,19 +38,16 @@ namespace helpers {
static __global__ void
segmentMinLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
void *output, Nd4jLong *outputShape) {
__shared__
T *val;
__shared__
Nd4jLong xLen, zLen, segment, zIndex;
__shared__
T *x;
__shared__
T *z;
__shared__ T *val;
__shared__ Nd4jLong xLen, zLen, zIndex;
__shared__ T *x;
__shared__ T *z;
__shared__ int threadsPerSegment, start, finish;

auto segment = blockIdx.x;
if (threadIdx.x == 0) {
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x / threadsPerSegment;
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
// segment = blockIdx.x / threadsPerSegment;
x = reinterpret_cast<T *>(input);
z = reinterpret_cast<T *>(output);
extern __shared__ unsigned char shmem[];
@@ -123,12 +120,12 @@ namespace helpers {
template <typename T, typename I>
static __global__ void segmentMinTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total;
__shared__ Nd4jLong len, zIndex, total;
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish;

auto segment = indices[blockIdx.x]; // / threadsPerSegment;
if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
len = shape::length(inputTads);
start = starts[segment];
@@ -145,13 +142,14 @@ namespace helpers {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex];
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
// if (lengths[indices[idx]])
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
}
}
@@ -165,7 +163,7 @@ namespace helpers {
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});

output->assign(DataTypeUtils::infOrMax<T>());
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
@@ -193,6 +191,7 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
output->nullify();
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
}
@@ -207,6 +206,7 @@ namespace helpers {
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
output->assign(DataTypeUtils::infOrMax<T>());
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
@@ -236,6 +236,7 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
output->nullify();
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output),
NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
@@ -146,13 +146,14 @@ namespace helpers {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex];
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
if (lengths[segment] > 0)
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
}
}
@@ -166,7 +167,7 @@ namespace helpers {
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});

output->assign(1);
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
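The product path starts from output->assign(1), the multiplicative identity, before the nd4j_atomicMul updates shown earlier. A tiny host-side analogue (std:: types only, not the nd4j API) of why the identity matters when a segment receives no elements:

    #include <cstdio>
    #include <vector>

    int main() {
        // Two segments; segment 1 receives no elements, so it keeps the identity value 1.
        std::vector<double> out(2, 1.0);                 // analogue of output->assign(1)
        std::vector<double> x      = {2.0, 3.0, 4.0};
        std::vector<int>    segIds = {0, 0, 0};          // all elements land in segment 0

        for (size_t i = 0; i < x.size(); i++)
            out[segIds[i]] *= x[i];                      // analogue of nd4j_atomicMul

        std::printf("%g %g\n", out[0], out[1]);          // prints: 24 1
        return 0;
    }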
@@ -373,6 +374,7 @@ namespace helpers {
template <typename T, typename I>
static int unsortedSegmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();

NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
unsortedSegmentProdFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
@@ -121,12 +121,12 @@ namespace helpers {
template <typename T, typename I>
static __global__ void segmentSumTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total;
__shared__ Nd4jLong len, zIndex, total;
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish;
__shared__ int start, finish;

if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
auto segment = indices[blockIdx.x]; // / threadsPerSegment;
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
len = shape::length(inputTads);
start = starts[segment];
@@ -143,14 +143,14 @@ namespace helpers {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex];
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
if (lengths[segment])
if (lengths[indices[idx]])
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
}
}
@@ -191,6 +191,7 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
output->nullify();
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
}
@@ -232,6 +233,7 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
output->nullify();
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output),
NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
@@ -602,6 +602,8 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArr
// NDArray VT = outArrs[2]->transpose();
NDArray* V = outArrs[2];

NDArray::prepareSpecialUse({S, U, V}, {x});

if(x->rankOf() == 2) {
// svdQR(context, x, S, U, VT, fullUV, calcUV);
svdJcb(context, x, S, U, V, fullUV, calcUV);
@@ -631,6 +633,8 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArr
delete tadsV;
}
}

NDArray::registerSpecialUse({S, U, V}, {x});
}
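The svd change brackets the device-side work with NDArray::prepareSpecialUse / NDArray::registerSpecialUse over {S, U, V} and {x}, roughly: make the inputs' device buffers current before launching, and record the outputs' device buffers as freshly written afterwards. A sketch of that bracketing pattern; the helper body is hypothetical, only the two NDArray calls are taken from the hunks above:

    // Illustrative helper shape, not the actual svd implementation.
    static void someCudaHelper(nd4j::LaunchContext* context, const NDArray* x,
                               NDArray* S, NDArray* U, NDArray* V) {
        NDArray::prepareSpecialUse({S, U, V}, {x});
        // ... launch cuSOLVER / custom kernels on *context->getCudaStream() ...
        NDArray::registerSpecialUse({S, U, V}, {x});
    }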
@@ -1920,6 +1920,50 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) {

delete result;
}
/* @Test
public void testDynamicPartition(){
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
.addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1))
.addIntegerArguments(3) //3 partitions
.addInputs(data, partitions).build());

INDArray exp0 = Nd4j.createFromArray(2, 0);
INDArray exp1 = Nd4j.createFromArray(2);
INDArray exp2 = Nd4j.createFromArray(1);

assertEquals(exp0, out[0]); //Usually just gives [0,0]
assertEquals(exp1, out[1]);
assertEquals(exp2, out[2]);
}*/
TEST_F(DeclarableOpsTests5, DynamicPartition_01) {

auto x = NDArrayFactory::create<int>({2,1,2,0});

auto y = NDArrayFactory::create<int>({0,2,1,0});

int numPartition = 3;
std::vector<NDArray> exp( { NDArrayFactory::create<int>('c', {2}, {2, 0}),
NDArrayFactory::create<int>('c', {1}, {2}),
NDArrayFactory::create<int>('c', {1}, {1})});

nd4j::ops::dynamic_partition op;
auto result = op.execute({&x, &y}, {}, {numPartition});

ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(result->size(), numPartition); // result has the same size as given param 4

for (int e = 0; e < result->size(); e++) {
auto output = result->at(e);
// output->printShapeInfo("Output shape> ");
// output->printIndexedBuffer("Output data> ");
ASSERT_TRUE(exp[e].isSameShape(output));
ASSERT_TRUE(exp[e].equalsTo(output));
}

delete result;
}

TEST_F(DeclarableOpsTests5, DynamicPartition_1) {
@@ -2031,6 +2075,38 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_3) {
delete result;
}

TEST_F(DeclarableOpsTests5, DynamicStitch_empty_1) {
auto i0 = NDArrayFactory::create<int>('c', {2}, {2, 3});
auto i1 = NDArrayFactory::empty<int>();
auto i2 = NDArrayFactory::create<int>('c', {2}, {0, 1});

auto d0 = NDArrayFactory::create<double>('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126});
auto d1 = NDArrayFactory::empty<double>();
auto d2 = NDArrayFactory::create<double>('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326});

nd4j::ops::dynamic_stitch op;
auto result = op.execute({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {});
ASSERT_EQ(Status::OK(), result->status());

delete result;
}

TEST_F(DeclarableOpsTests5, DynamicStitch_empty_2) {
auto i0 = NDArrayFactory::create<int>('c', {2}, {2, 3});
auto i1 = NDArrayFactory::create<int>('c', {0});
auto i2 = NDArrayFactory::create<int>('c', {2}, {0, 1});

auto d0 = NDArrayFactory::create<double>('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126});
auto d1 = NDArrayFactory::create<double>('c', {0, 5});
auto d2 = NDArrayFactory::create<double>('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326});

nd4j::ops::dynamic_stitch op;
auto result = op.execute({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {});
ASSERT_EQ(Status::OK(), result->status());

delete result;
}

////////////////////////////////////////////////////////////////////////////////

TEST_F(DeclarableOpsTests5, DynamicStitch_1) {
@@ -1728,6 +1728,24 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(exp, out);
}

@Test
public void testDynamicPartition(){
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
.addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1))
.addIntegerArguments(3) //3 partitions
.addInputs(data, partitions).build());

INDArray exp0 = Nd4j.createFromArray(2, 0);
INDArray exp1 = Nd4j.createFromArray(2);
INDArray exp2 = Nd4j.createFromArray(1);

assertEquals(exp0, out[0]); //Usually just gives [0,0]
assertEquals(exp1, out[1]);
assertEquals(exp2, out[2]);
}

@Test
public void testListDiff(){
INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
@@ -344,6 +344,8 @@ public class TFGraphTestAllHelper {
System.out.println("Pass: " + varName);
} else {
System.out.println("FAIL: " + varName);
System.out.println("TF:\n" + tfValue);
System.out.println("SD:\n" + sdVal);
}

}
@@ -180,8 +180,8 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());

try {
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH,
TFGraphTestAllHelper.LOADER, maxRE, minAbs);
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs);
//TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
} catch (Throwable t){
log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);
throw t;