diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu
index 857ebed38..92e5b38b4 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu
@@ -50,8 +50,8 @@ namespace nd4j {
                 auto zShapeInfo = zShapeInfos[o];
                 auto zLength = shape::length(zShapeInfo);
 
-                // iLimit should be
-                auto iLimit = iLength <= blockIdx.x ? blockIdx.x : (iLength + (blockIdx.x - (iLength % blockIdx.x)));
+                // iLimit should be multiple of blockDim.x
+                auto iLimit = iLength <= blockDim.x ? blockDim.x : (iLength + (blockDim.x - (iLength % blockDim.x)));
                 int cnt = 0;
 
                 for (Nd4jLong e = threadIdx.x; e < iLimit; e += blockDim.x) {
@@ -75,8 +75,9 @@ namespace nd4j {
 
                     // doing actual update
                     if (e < iLength)
-                        if (trueIndices[threadIdx.x] >= 0)
+                        if (trueIndices[threadIdx.x] >= 0) {
                             z[trueIndices[threadIdx.x]] = x[shape::getIndexOffset(e, xShapeInfo, xLength)];
+                        }
 
                     __syncthreads();
                 }
@@ -148,13 +149,12 @@ namespace nd4j {
                 auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *)));
                 auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *)));
 
-                dynamicPartitionTadKernel<X,Y><<<256, 512, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);
+                dynamicPartitionTadKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);
             } else {
                 auto numThreads = 256;
                 auto shmemSize = numThreads * sizeof(Y) * 2 + 1024;
-
                 std::vector<void *> outBuffers;
                 std::vector<Nd4jLong *> outShapes;
@@ -203,6 +203,9 @@ namespace nd4j {
                 auto indices = reinterpret_cast<Y *>(vindices[e]);
                 auto iShapeInfo = iShapeInfos[e];
 
+                if (shape::isEmpty(iShapeInfo))
+                    continue;
+
                 auto iLength = shape::length(iShapeInfo);
 
                 auto zLength = shape::length(zTadShapeInfo);
@@ -310,8 +313,9 @@ namespace nd4j {
 
             NDArray::registerSpecialUse({}, {indices, input});
 
-            for (auto v:outputList)
+            for (auto v:outputList) {
                 v->tickWriteDevice();
+            }
         }
 
         template <typename X, typename Y>
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu
index dc91a2704..8830f37e7 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu
@@ -40,19 +40,16 @@ namespace nd4j {
     static __global__ void segmentMaxLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses, void *output, Nd4jLong *outputShape) {
-        __shared__
-        T *val;
-        __shared__
-        Nd4jLong xLen, zLen, segment, zIndex;
-        __shared__
-        T *x;
-        __shared__
-        T *z;
+        __shared__ T *val;
+        __shared__ Nd4jLong xLen, zLen, zIndex;
+        __shared__ T *x;
+        __shared__ T *z;
         __shared__ int threadsPerSegment, start, finish;
 
+        auto segment = blockIdx.x;
         if (threadIdx.x == 0) {
-            threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
-            segment = blockIdx.x / threadsPerSegment;
+//            threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
+//            segment = blockIdx.x / threadsPerSegment;
             x = reinterpret_cast<T *>(input);
             z = reinterpret_cast<T *>(output);
             extern __shared__ unsigned char shmem[];
@@ -83,19 +80,14 @@ namespace nd4j {
             unsortedSegmentMaxLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape, int *starts, int *lengths, Nd4jLong numOfClasses, void *output, Nd4jLong *outputShape) {
-        __shared__
-        T *val;
-        __shared__
-        Nd4jLong xLen, zLen, segment, zIndex;
-        __shared__
-        T *x;
-        __shared__
-        T *z;
-        __shared__
-        I *y; //int threadsPerSegment, start, finish;
+        __shared__ T *val;
+        __shared__ Nd4jLong xLen, zLen, zIndex;
+        __shared__ T *x;
+        __shared__ T *z;
+        __shared__ I *y; //int threadsPerSegment, start, finish;
+        auto segment = blockIdx.x;
         if (threadIdx.x == 0) {
-            segment = blockIdx.x;
             x = reinterpret_cast<T *>(input);
             z = reinterpret_cast<T *>(output);
             y = reinterpret_cast<I *>(indices);
@@ -127,9 +119,10 @@ namespace nd4j {
                           Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets, T filler = 0) {
         __shared__ T* val;
-        __shared__ Nd4jLong len, segment, zIndex, total;
+        __shared__ Nd4jLong len, zIndex, total;
         __shared__ T* z;
         __shared__ int start, finish;
+        __shared__ I segment;
 
         if (threadIdx.x == 0) {
             segment = indices[blockIdx.x]; // / threadsPerSegment;
@@ -143,20 +136,22 @@ namespace nd4j {
         __syncthreads();
 
         auto idx = blockIdx.x;
-        if (blockIdx.x <= total) {
+        if (idx <= total) {
             auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
             if (blockIdx.x == start) {
                 for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                     auto xIndex = shape::getIndexOffset(e, inputTads, len);
                     auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                    z[zIndex] = x[xIndex];
+                    nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
+                    //z[zIndex] = x[xIndex];
                 }
             }
             else {
                 for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                     auto xIndex = shape::getIndexOffset(e, inputTads, len);
                     auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                    nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
+                    if (lengths[segment])
+                        nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
                 }
             }
         }
@@ -168,6 +163,7 @@ namespace nd4j {
        //int numClasses = output->sizeAt(0);
        // if input is a vector: (as if in doc sample)
        //Nd4jLong idx = indices->e(0);
+       output->assign(-DataTypeUtils::infOrMax<T>());
        auto stream = context->getCudaStream();
        indices->syncToHost();
        Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
@@ -211,6 +207,8 @@ namespace nd4j {
    static void unsortedSegmentMaxFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
        auto stream = context->getCudaStream();
//        NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2});
+       output->assign(DataTypeUtils::infOrMax<T>());
+
        NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
        NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
//        NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
@@ -243,6 +241,7 @@ namespace nd4j {
    // -------------------------------------------------------------------------------------------------------------- //
    void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
        NDArray::prepareSpecialUse({output}, {input, indices});
+       output->nullify();
        BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES);
        NDArray::registerSpecialUse({output}, {input, indices});
    }
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu
index 950abde67..e5ea2eb91 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu
@@ -38,19 +38,16 @@ namespace helpers {
     static __global__ void segmentMinLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses, void *output, Nd4jLong *outputShape) {
-        __shared__
-        T *val;
-        __shared__
-        Nd4jLong xLen, zLen, segment, zIndex;
-        __shared__
-        T *x;
-        __shared__
-        T *z;
+        __shared__ T *val;
+        __shared__ Nd4jLong xLen, zLen, zIndex;
+        __shared__ T *x;
+        __shared__ T *z;
         __shared__ int threadsPerSegment, start, finish;
 
+        auto segment = blockIdx.x;
         if (threadIdx.x == 0) {
-            threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
-            segment = blockIdx.x / threadsPerSegment;
+//            threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
+//            segment = blockIdx.x / threadsPerSegment;
             x = reinterpret_cast<T *>(input);
             z = reinterpret_cast<T *>(output);
             extern __shared__ unsigned char shmem[];
@@ -123,12 +120,12 @@ namespace helpers {
     template <typename T, typename I>
     static __global__ void segmentMinTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
         __shared__ T* val;
-        __shared__ Nd4jLong len, segment, zIndex, total;
+        __shared__ Nd4jLong len, zIndex, total;
         __shared__ T* z;
         __shared__ int threadsPerSegment, start, finish;
+        auto segment = indices[blockIdx.x]; // / threadsPerSegment;
 
         if (threadIdx.x == 0) {
-            segment = indices[blockIdx.x]; // / threadsPerSegment;
             z = reinterpret_cast<T *>(outputBuf) + outputTadOffsets[segment];
             len = shape::length(inputTads);
             start = starts[segment];
@@ -145,14 +142,15 @@ namespace helpers {
             for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                 auto xIndex = shape::getIndexOffset(e, inputTads, len);
                 auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                z[zIndex] = x[xIndex];
+                nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
             }
         }
         else {
             for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                 auto xIndex = shape::getIndexOffset(e, inputTads, len);
                 auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
+//                if (lengths[indices[idx]])
+                    nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
             }
         }
     }
@@ -165,7 +163,7 @@ namespace helpers {
         Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
-
+        output->assign(DataTypeUtils::infOrMax<T>());
         classesRangesBegs.assign(indices->lengthOf());
         classesRangesLens.assign(0);
@@ -193,6 +191,7 @@ namespace helpers {
     // -------------------------------------------------------------------------------------------------------------- //
     void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
         NDArray::prepareSpecialUse({output}, {input, indices});
+        output->nullify();
         BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
         NDArray::registerSpecialUse({output}, {input, indices});
     }
@@ -207,6 +206,7 @@ namespace helpers {
         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
//        NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
//        classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
+        output->assign(DataTypeUtils::infOrMax<T>());
         classesRangesBegs.assign(indices->lengthOf());
         classesRangesLens.assign(0);
         dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
@@ -236,6 +236,7 @@ namespace helpers {
     // -------------------------------------------------------------------------------------------------------------- //
     void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
         NDArray::prepareSpecialUse({output}, {input, indices});
+        output->nullify();
         BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES);
         NDArray::registerSpecialUse({output}, {input, indices});
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu
index 3ae4ebcb8..5709a63ea 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu
@@ -146,14 +146,15 @@ namespace helpers {
             for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                 auto xIndex = shape::getIndexOffset(e, inputTads, len);
                 auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                z[zIndex] = x[xIndex];
+                nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
             }
         }
         else {
             for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                 auto xIndex = shape::getIndexOffset(e, inputTads, len);
                 auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
+                if (lengths[segment] > 0)
+                    nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
             }
         }
     }
@@ -166,7 +167,7 @@ namespace helpers {
         Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
-
+        output->assign(1);
         classesRangesBegs.assign(indices->lengthOf());
         classesRangesLens.assign(0);
@@ -373,6 +374,7 @@ namespace helpers {
     template <typename T, typename I>
     static int unsortedSegmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
         auto stream = context->getCudaStream();
+
         NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
         unsortedSegmentProdFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
         NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu
index 08b36253a..4b8976f4e 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu
@@ -121,12 +121,12 @@ namespace helpers {
     template <typename T, typename I>
     static __global__ void segmentSumTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
         __shared__ T* val;
-        __shared__ Nd4jLong len, segment, zIndex, total;
+        __shared__ Nd4jLong len, zIndex, total;
         __shared__ T* z;
-        __shared__ int threadsPerSegment, start, finish;
+        __shared__ int start, finish;
 
         if (threadIdx.x == 0) {
-            segment = indices[blockIdx.x]; // / threadsPerSegment;
+            auto segment = indices[blockIdx.x]; // / threadsPerSegment;
             z = reinterpret_cast<T *>(outputBuf) + outputTadOffsets[segment];
             len = shape::length(inputTads);
             start = starts[segment];
@@ -143,14 +143,14 @@ namespace helpers {
             for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                 auto xIndex = shape::getIndexOffset(e, inputTads, len);
                 auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                z[zIndex] = x[xIndex];
+                nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
             }
         }
         else {
             for (auto e = threadIdx.x; e < len; e += blockDim.x) {
                 auto xIndex = shape::getIndexOffset(e, inputTads, len);
                 auto zIndex = shape::getIndexOffset(e, outputTads, len);
-                if (lengths[segment])
+                if (lengths[indices[idx]])
                     nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
             }
         }
@@ -191,6 +191,7 @@ namespace helpers {
     // -------------------------------------------------------------------------------------------------------------- //
     void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
         NDArray::prepareSpecialUse({output}, {input, indices});
+        output->nullify();
         BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
         NDArray::registerSpecialUse({output}, {input, indices});
     }
@@ -232,6 +233,7 @@ namespace helpers {
     // -------------------------------------------------------------------------------------------------------------- //
     void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
         NDArray::prepareSpecialUse({output}, {input, indices});
+        output->nullify();
         BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES);
         NDArray::registerSpecialUse({output}, {input, indices});
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu
index da4f5cc86..0695119da 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu
@@ -602,6 +602,8 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArray*>& outArrs, const bool fullUV, const bool calcUV, const int switchNum) {
     // NDArray VT = outArrs[2]->transpose();
     NDArray* V = outArrs[2];
 
+    NDArray::prepareSpecialUse({S, U, V}, {x});
+
     if(x->rankOf() == 2) {
         // svdQR(context, x, S, U, VT, fullUV, calcUV);
         svdJcb(context, x, S, U, V, fullUV, calcUV);
@@ -631,6 +633,8 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArray*>& outArrs, const bool fullUV, const bool calcUV, const int switchNum) {
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp
+TEST_F(DeclarableOpsTests5, DynamicPartition_01) {
+
+    auto x = NDArrayFactory::create<int>({2,1,2,0});
+
+    auto y = NDArrayFactory::create<int>({0,2,1,0});
+
+    int numPartition = 3;
+    std::vector<NDArray> exp( { NDArrayFactory::create<int>('c', {2}, {2, 0}),
+                                NDArrayFactory::create<int>('c', {1}, {2}),
+                                NDArrayFactory::create<int>('c', {1}, {1})});
+
+    nd4j::ops::dynamic_partition op;
+    auto result = op.execute({&x, &y}, {}, {numPartition});
+
+    ASSERT_EQ(ND4J_STATUS_OK, result->status());
+    ASSERT_EQ(result->size(), numPartition); // result has the same size as given param 4
+
+    for (int e = 0; e < result->size(); e++) {
+        auto output = result->at(e);
+        // output->printShapeInfo("Output shape> ");
+        // output->printIndexedBuffer("Output data> ");
+        ASSERT_TRUE(exp[e].isSameShape(output));
+        ASSERT_TRUE(exp[e].equalsTo(output));
+    }
+
+    delete result;
+}
 TEST_F(DeclarableOpsTests5, DynamicPartition_1) {
@@ -2031,6 +2075,38 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_3) {
     delete result;
 }
 
+TEST_F(DeclarableOpsTests5, DynamicStitch_empty_1) {
+    auto i0 = NDArrayFactory::create<int>('c', {2}, {2, 3});
+    auto i1 = NDArrayFactory::empty<int>();
+    auto i2 = NDArrayFactory::create<int>('c', {2}, {0, 1});
+
+    auto d0 = NDArrayFactory::create<double>('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126});
+    auto d1 = NDArrayFactory::empty<double>();
+    auto d2 = NDArrayFactory::create<double>('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326});
+
+    nd4j::ops::dynamic_stitch op;
+    auto result = op.execute({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    delete result;
+}
+
+TEST_F(DeclarableOpsTests5, DynamicStitch_empty_2) {
+    auto i0 = NDArrayFactory::create<int>('c', {2}, {2, 3});
+    auto i1 = NDArrayFactory::create<int>('c', {0});
+    auto i2 = NDArrayFactory::create<int>('c', {2}, {0, 1});
+
+    auto d0 = NDArrayFactory::create<double>('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126});
+    auto d1 = NDArrayFactory::create<double>('c', {0, 5});
+    auto d2 = NDArrayFactory::create<double>('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326});
+
+    nd4j::ops::dynamic_stitch op;
+    auto result = op.execute({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    delete result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, DynamicStitch_1) {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java
index 99a2f57ac..eb228bf1f 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java
@@ -1728,6 +1728,24 @@ public class MiscOpValidation extends BaseOpValidation {
         assertEquals(exp, out);
     }
 
+    @Test
+    public void testDynamicPartition(){
+        INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
+        INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
+        INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
+                .addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1))
+                .addIntegerArguments(3) //3 partitions
+                .addInputs(data, partitions).build());
+
+        INDArray exp0 = Nd4j.createFromArray(2, 0);
+        INDArray exp1 = Nd4j.createFromArray(2);
+        INDArray exp2 = Nd4j.createFromArray(1);
+
+        assertEquals(exp0, out[0]);  //Usually just gives [0,0]
+        assertEquals(exp1, out[1]);
+        assertEquals(exp2, out[2]);
+    }
+
     @Test
     public void testListDiff(){
         INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
index ab86f829e..ecc81c981 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
@@ -344,6 +344,8 @@ public class TFGraphTestAllHelper {
                         System.out.println("Pass: " + varName);
                     } else {
                         System.out.println("FAIL: " + varName);
+                        System.out.println("TF:\n" + tfValue);
+                        System.out.println("SD:\n" + sdVal);
                     }
                 }
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
index 4f3520d31..1fef4a07b 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
@@ -180,8 +180,8 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
 
         try {
-            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH,
-                    TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+            //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
         } catch (Throwable t){
             log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);
             throw t;