[WIP] More fixes (#190)

* Refactored kernels for segment_max/min/sum ops.

* Refactored segment_prod kernels.

* Refactored segment_prod kernels.

* DynamicPartition test

Signed-off-by: raver119 <raver119@gmail.com>

* Added linear test for dynamic_partition op.

* Refactored test with int datatype.

* some logging

Signed-off-by: raver119 <raver119@gmail.com>

* some logging

Signed-off-by: raver119 <raver119@gmail.com>

* some logging

Signed-off-by: raver119 <raver119@gmail.com>

* dynamicPartition fix

Signed-off-by: raver119 <raver119@gmail.com>

* get rid of some logging

Signed-off-by: raver119 <raver119@gmail.com>

* one more test for dynamic_stitch

Signed-off-by: raver119 <raver119@gmail.com>

* one more test for dynamic_stitch

Signed-off-by: raver119 <raver119@gmail.com>

* empty check for stitch

Signed-off-by: raver119 <raver119@gmail.com>

* minor print changes

Signed-off-by: raver119 <raver119@gmail.com>
Branch: master
Author: raver119, 2019-08-28 15:38:57 +03:00 (committed by GitHub)
Parent: 3157ec110c
Commit: f4860574d7
10 changed files with 164 additions and 56 deletions


@@ -50,8 +50,8 @@ namespace nd4j {
auto zShapeInfo = zShapeInfos[o];
auto zLength = shape::length(zShapeInfo);
// iLimit should be
auto iLimit = iLength <= blockIdx.x ? blockIdx.x : (iLength + (blockIdx.x - (iLength % blockIdx.x)));
// iLimit should be multiple of blockDim.x
auto iLimit = iLength <= blockDim.x ? blockDim.x : (iLength + (blockDim.x - (iLength % blockDim.x)));
int cnt = 0;
for (Nd4jLong e = threadIdx.x; e < iLimit; e += blockDim.x) {
@@ -75,8 +75,9 @@ namespace nd4j {
// doing actual update
if (e < iLength)
if (trueIndices[threadIdx.x] >= 0)
if (trueIndices[threadIdx.x] >= 0) {
z[trueIndices[threadIdx.x]] = x[shape::getIndexOffset(e, xShapeInfo, xLength)];
}
__syncthreads();
}
@@ -148,13 +149,12 @@ namespace nd4j {
auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *)));
auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *)));
dynamicPartitionTadKernel<X,Y><<<256, 512, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);
dynamicPartitionTadKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);
} else {
auto numThreads = 256;
auto shmemSize = numThreads * sizeof(Y) * 2 + 1024;
std::vector<void *> outBuffers;
std::vector<Nd4jLong *> outShapes;
@@ -203,6 +203,9 @@ namespace nd4j {
auto indices = reinterpret_cast<Y*>(vindices[e]);
auto iShapeInfo = iShapeInfos[e];
if (shape::isEmpty(iShapeInfo))
continue;
auto iLength = shape::length(iShapeInfo);
auto zLength = shape::length(zTadShapeInfo);
@@ -310,9 +313,10 @@ namespace nd4j {
NDArray::registerSpecialUse({}, {indices, input});
for (auto v:outputList)
for (auto v:outputList) {
v->tickWriteDevice();
}
}
template <typename T>
static int _dynamicStitchFunctorBP(std::vector<NDArray*> const& inputs, std::vector<NDArray*> const& indices, NDArray const* gradInput, std::vector<NDArray*>& outputList){
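
The core change in the first hunk above is that the padded loop bound now uses blockDim.x rather than blockIdx.x: the bound is rounded up so that every thread in a block executes the same number of loop iterations, which keeps the __syncthreads() inside the loop from deadlocking while the `if (e < iLength)` guard skips the surplus work. A minimal host-side sketch of that bound, with a hypothetical helper name (padToBlock) and example values, assuming nothing beyond the expression shown in the diff:

#include <cstdio>

// Illustrative only: mirrors the kernel's iLimit expression. The padded
// iterations exist solely so all threads reach the in-loop barrier together;
// real work is still guarded by `if (e < iLength)` in the kernel.
static long padToBlock(long iLength, long blockDimX) {
    return iLength <= blockDimX
           ? blockDimX
           : (iLength + (blockDimX - (iLength % blockDimX)));
}

int main() {
    printf("%ld\n", padToBlock(5,   256)); // 256
    printf("%ld\n", padToBlock(300, 256)); // 512
    printf("%ld\n", padToBlock(512, 256)); // 768 (already a multiple: one extra, fully guarded stride)
    return 0;
}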


@@ -40,19 +40,16 @@ namespace nd4j {
static __global__ void
segmentMaxLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
void *output, Nd4jLong *outputShape) {
__shared__
T *val;
__shared__
Nd4jLong xLen, zLen, segment, zIndex;
__shared__
T *x;
__shared__
T *z;
__shared__ T *val;
__shared__ Nd4jLong xLen, zLen, zIndex;
__shared__ T *x;
__shared__ T *z;
__shared__ int threadsPerSegment, start, finish;
auto segment = blockIdx.x;
if (threadIdx.x == 0) {
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x / threadsPerSegment;
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
// segment = blockIdx.x / threadsPerSegment;
x = reinterpret_cast<T *>(input);
z = reinterpret_cast<T *>(output);
extern __shared__ unsigned char shmem[];
@@ -83,19 +80,14 @@ namespace nd4j {
unsortedSegmentMaxLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape,
int *starts, int *lengths, Nd4jLong numOfClasses, void *output,
Nd4jLong *outputShape) {
__shared__
T *val;
__shared__
Nd4jLong xLen, zLen, segment, zIndex;
__shared__
T *x;
__shared__
T *z;
__shared__
I *y; //int threadsPerSegment, start, finish;
__shared__ T *val;
__shared__ Nd4jLong xLen, zLen, zIndex;
__shared__ T *x;
__shared__ T *z;
__shared__ I *y; //int threadsPerSegment, start, finish;
auto segment = blockIdx.x;
if (threadIdx.x == 0) {
segment = blockIdx.x;
x = reinterpret_cast<T *>(input);
z = reinterpret_cast<T *>(output);
y = reinterpret_cast<I *>(indices);
@@ -127,9 +119,10 @@ namespace nd4j {
Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets, T filler = 0) {
__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total;
__shared__ Nd4jLong len, zIndex, total;
__shared__ T* z;
__shared__ int start, finish;
__shared__ I segment;
if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
@@ -143,19 +136,21 @@ namespace nd4j {
__syncthreads();
auto idx = blockIdx.x;
if (blockIdx.x <= total) {
if (idx <= total) {
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
if (blockIdx.x == start) {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex];
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
//z[zIndex] = x[xIndex];
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
if (lengths[segment])
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
}
}
@@ -168,6 +163,7 @@ namespace nd4j {
//int numClasses = output->sizeAt(0);
// if input is a vector: (as if in doc sample)
//Nd4jLong idx = indices->e<Nd4jLong>(0);
output->assign(-DataTypeUtils::infOrMax<T>());
auto stream = context->getCudaStream();
indices->syncToHost();
Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
@@ -211,6 +207,8 @@ namespace nd4j {
static void unsortedSegmentMaxFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
output->assign(DataTypeUtils::infOrMax<T>());
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
@@ -243,6 +241,7 @@ namespace nd4j {
// -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
output->nullify();
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
}
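
Conceptually, the segment_max hunks above (and the matching segment_min, segment_prod and segment_sum hunks below) replace the direct `z[zIndex] = x[xIndex]` store with an atomic update into an output that has been pre-filled with the reduction's identity (-inf for max, +inf for min, 1 for prod, 0 for sum), so the result no longer depends on which block happens to write first. A standalone sketch of that pattern for a flat float input, not the library kernel itself; atomicMaxFloat and segmentMaxSketch are illustrative names, and the caller is assumed to pre-fill z with -INFINITY:

// CUDA has no native atomicMax for float, so the usual compare-and-swap loop
// is used to emulate it.
__device__ float atomicMaxFloat(float* address, float val) {
    int* addressAsInt = reinterpret_cast<int*>(address);
    int old = *addressAsInt, assumed;
    do {
        assumed = old;
        old = atomicCAS(addressAsInt, assumed,
                        __float_as_int(fmaxf(val, __int_as_float(assumed))));
    } while (assumed != old);
    return __int_as_float(old);
}

// Grid-stride loop: each element atomically folds itself into the output slot
// of its segment, so any launch geometry produces the same result.
__global__ void segmentMaxSketch(const float* x, const int* segmentIds,
                                 long n, float* z /* pre-filled with -INFINITY */) {
    long i = blockIdx.x * (long)blockDim.x + threadIdx.x;
    long stride = (long)gridDim.x * blockDim.x;
    for (long e = i; e < n; e += stride)
        atomicMaxFloat(&z[segmentIds[e]], x[e]);
}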


@@ -38,19 +38,16 @@ namespace helpers {
static __global__ void
segmentMinLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
void *output, Nd4jLong *outputShape) {
__shared__
T *val;
__shared__
Nd4jLong xLen, zLen, segment, zIndex;
__shared__
T *x;
__shared__
T *z;
__shared__ T *val;
__shared__ Nd4jLong xLen, zLen, zIndex;
__shared__ T *x;
__shared__ T *z;
__shared__ int threadsPerSegment, start, finish;
auto segment = blockIdx.x;
if (threadIdx.x == 0) {
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x / threadsPerSegment;
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
// segment = blockIdx.x / threadsPerSegment;
x = reinterpret_cast<T *>(input);
z = reinterpret_cast<T *>(output);
extern __shared__ unsigned char shmem[];
@@ -123,12 +120,12 @@ namespace helpers {
template <typename T, typename I>
static __global__ void segmentMinTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total;
__shared__ Nd4jLong len, zIndex, total;
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish;
auto segment = indices[blockIdx.x]; // / threadsPerSegment;
if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
len = shape::length(inputTads);
start = starts[segment];
@@ -145,13 +142,14 @@ namespace helpers {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex];
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
// if (lengths[indices[idx]])
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
}
}
@@ -165,7 +163,7 @@ namespace helpers {
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
output->assign(DataTypeUtils::infOrMax<T>());
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
@@ -193,6 +191,7 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
output->nullify();
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
}
@@ -207,6 +206,7 @@ namespace helpers {
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
output->assign(DataTypeUtils::infOrMax<T>());
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
@@ -236,6 +236,7 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
output->nullify();
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output),
NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});


@@ -146,13 +146,14 @@ namespace helpers {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex];
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
if (lengths[segment] > 0)
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
}
}
@@ -166,7 +167,7 @@ namespace helpers {
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
output->assign(1);
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
@@ -373,6 +374,7 @@ namespace helpers {
template <typename T, typename I>
static int unsortedSegmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
unsortedSegmentProdFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});


@@ -121,12 +121,12 @@ namespace helpers {
template <typename T, typename I>
static __global__ void segmentSumTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total;
__shared__ Nd4jLong len, zIndex, total;
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish;
__shared__ int start, finish;
if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
auto segment = indices[blockIdx.x]; // / threadsPerSegment;
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
len = shape::length(inputTads);
start = starts[segment];
@@ -143,14 +143,14 @@ namespace helpers {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex];
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
if (lengths[segment])
if (lengths[indices[idx]])
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
}
}
@@ -191,6 +191,7 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
output->nullify();
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
}
@@ -232,6 +233,7 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
output->nullify();
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output),
NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});


@@ -602,6 +602,8 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArr
// NDArray VT = outArrs[2]->transpose();
NDArray* V = outArrs[2];
NDArray::prepareSpecialUse({S, U, V}, {x});
if(x->rankOf() == 2) {
// svdQR(context, x, S, U, VT, fullUV, calcUV);
svdJcb(context, x, S, U, V, fullUV, calcUV);
@@ -631,6 +633,8 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArr
delete tadsV;
}
}
NDArray::registerSpecialUse({S, U, V}, {x});
}
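
The two svd hunks above add the same prepareSpecialUse / registerSpecialUse bracketing that the segment and dynamic-op helpers already use. Roughly, the first call declares which arrays the device code will write and read before any kernels launch, and the second records afterwards that the outputs were modified on device so they are synced back on the next host access. A hedged sketch of the pattern only; someDeviceHelper and its arguments are illustrative, not code from this commit:

void someDeviceHelper(nd4j::LaunchContext* context, const NDArray* x, NDArray* out) {
    // declare outputs ({out}) and inputs ({x}) before launching device work
    NDArray::prepareSpecialUse({out}, {x});

    // ... kernel launches on context->getCudaStream() that read x and write out ...

    // mark out as written on device so it is synchronized back when next read on host
    NDArray::registerSpecialUse({out}, {x});
}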


@@ -1920,6 +1920,50 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) {
delete result;
}
/* @Test
public void testDynamicPartition(){
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
.addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1))
.addIntegerArguments(3) //3 partitions
.addInputs(data, partitions).build());
INDArray exp0 = Nd4j.createFromArray(2, 0);
INDArray exp1 = Nd4j.createFromArray(2);
INDArray exp2 = Nd4j.createFromArray(1);
assertEquals(exp0, out[0]); //Usually just gives [0,0]
assertEquals(exp1, out[1]);
assertEquals(exp2, out[2]);
}*/
TEST_F(DeclarableOpsTests5, DynamicPartition_01) {
auto x = NDArrayFactory::create<int>({2,1,2,0});
auto y = NDArrayFactory::create<int>({0,2,1,0});
int numPartition = 3;
std::vector<NDArray> exp( { NDArrayFactory::create<int>('c', {2}, {2, 0}),
NDArrayFactory::create<int>('c', {1}, {2}),
NDArrayFactory::create<int>('c', {1}, {1})});
nd4j::ops::dynamic_partition op;
auto result = op.execute({&x, &y}, {}, {numPartition});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(result->size(), numPartition); // result has the same size as given param 4
for (int e = 0; e < result->size(); e++) {
auto output = result->at(e);
// output->printShapeInfo("Output shape> ");
// output->printIndexedBuffer("Output data> ");
ASSERT_TRUE(exp[e].isSameShape(output));
ASSERT_TRUE(exp[e].equalsTo(output));
}
delete result;
}
TEST_F(DeclarableOpsTests5, DynamicPartition_1) {
@@ -2031,6 +2075,38 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_3) {
delete result;
}
TEST_F(DeclarableOpsTests5, DynamicStitch_empty_1) {
auto i0 = NDArrayFactory::create<int>('c', {2}, {2, 3});
auto i1 = NDArrayFactory::empty<int>();
auto i2 = NDArrayFactory::create<int>('c', {2}, {0, 1});
auto d0 = NDArrayFactory::create<double>('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126});
auto d1 = NDArrayFactory::empty<double>();
auto d2 = NDArrayFactory::create<double>('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326});
nd4j::ops::dynamic_stitch op;
auto result = op.execute({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
delete result;
}
TEST_F(DeclarableOpsTests5, DynamicStitch_empty_2) {
auto i0 = NDArrayFactory::create<int>('c', {2}, {2, 3});
auto i1 = NDArrayFactory::create<int>('c', {0});
auto i2 = NDArrayFactory::create<int>('c', {2}, {0, 1});
auto d0 = NDArrayFactory::create<double>('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126});
auto d1 = NDArrayFactory::create<double>('c', {0, 5});
auto d2 = NDArrayFactory::create<double>('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326});
nd4j::ops::dynamic_stitch op;
auto result = op.execute({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, DynamicStitch_1) {


@@ -1728,6 +1728,24 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(exp, out);
}
@Test
public void testDynamicPartition(){
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
.addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1))
.addIntegerArguments(3) //3 partitions
.addInputs(data, partitions).build());
INDArray exp0 = Nd4j.createFromArray(2, 0);
INDArray exp1 = Nd4j.createFromArray(2);
INDArray exp2 = Nd4j.createFromArray(1);
assertEquals(exp0, out[0]); //Usually just gives [0,0]
assertEquals(exp1, out[1]);
assertEquals(exp2, out[2]);
}
@Test
public void testListDiff(){
INDArray x = Nd4j.createFromArray(0, 1, 2, 3);


@@ -344,6 +344,8 @@ public class TFGraphTestAllHelper {
System.out.println("Pass: " + varName);
} else {
System.out.println("FAIL: " + varName);
System.out.println("TF:\n" + tfValue);
System.out.println("SD:\n" + sdVal);
}
}


@@ -180,8 +180,8 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
try {
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH,
TFGraphTestAllHelper.LOADER, maxRE, minAbs);
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs);
//TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
} catch (Throwable t){
log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);
throw t;