diff --git a/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp index 45ec30b58..af4e96e2e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp @@ -152,7 +152,6 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector if (end > tads) end = tads; for (int r = start; r < end; r++) { - if (tadEWS > 0 && zEWS > 0 && dimensionsLength == 1) { auto rX = const_cast(input)->bufferAsT() + tadOffsets[r]; auto rZ = output->bufferAsT() + zOfsets[r]; @@ -198,44 +197,6 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector rZ[zOffset] = maxIdx == i ? (Z) 1 : (Z) 0; } } - } - else { - int tadsPerThread = tads / TAD_THRESHOLD; - int num_threads = nd4j::math::nd4j_max(1, tadsPerThread); - num_threads = nd4j::math::nd4j_min(num_threads, omp_get_max_threads()); - - Nd4jLong offset = tadOffsets[r]; - Nd4jLong shapeIter[MAX_RANK]; - Nd4jLong coord[MAX_RANK]; - int dim; - Nd4jLong xStridesIter[MAX_RANK]; - Nd4jLong resultStridesIter[MAX_RANK]; - Nd4jLong *xShape = shape::shapeOf(tadShapeShapeInfo); - Nd4jLong *xStride = shape::stride(tadShapeShapeInfo); - Nd4jLong *resultStride = shape::stride(tadShapeShapeInfo); - int rank = shape::rank(tadShapeShapeInfo); - auto xPointer = const_cast(input)->bufferAsT() + offset; - auto resultPointer = output->bufferAsT() + offset; - auto maxValue = xPointer[0]; - - auto maxCursor = resultPointer; - Nd4jPointer maxCursorLong = reinterpret_cast(maxCursor); - - if (PrepareTwoRawArrayIter(rank, xShape, xPointer, xStride, resultPointer, resultStride, &rank, shapeIter, &xPointer, xStridesIter, &resultPointer, resultStridesIter) >= 0) { - ND4J_RAW_ITER_START(dim, rank, coord, shapeIter); - { - if (maxValue < xPointer[0]) { - maxCursor = resultPointer; - maxCursorLong = reinterpret_cast(resultPointer); - maxValue = xPointer[0]; - } - resultPointer[0] = (Z) 0; - } - ND4J_RAW_ITER_TWO_NEXT(dim, rank, coord, shapeIter, xPointer, xStridesIter, resultPointer, resultStridesIter); - maxCursor = reinterpret_cast(maxCursorLong); - maxCursor[0] = (Z) 1; - } - } } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 012b974fb..121ca2b43 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -6719,7 +6719,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public boolean wasClosed() { - if (released || data().wasClosed()) + // data can be null if that's empty array + if (released || (data() != null && data().wasClosed())) return true; return false; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index a1bc56703..d7f5746cc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -643,6 +643,20 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(result1, result2); } + @Test + public void isMax4d_2dims(){ + Nd4j.getRandom().setSeed(12345); + INDArray in = Nd4j.rand(DataType.FLOAT, 3, 3, 4, 4).permute(0, 2, 3, 1); + + INDArray out_permutedIn = in.like(); + INDArray out_dupedIn = in.like(); + + Nd4j.exec(new IsMax(in.dup(), out_dupedIn, 2, 3)); + Nd4j.exec(new IsMax(in, out_permutedIn, 2, 3)); + + assertEquals(out_dupedIn, out_permutedIn); + } + @Test public void testSizeTypes(){ List failed = new ArrayList<>();