- cpu isMax fix for multidim case + test

- INDArray.wasClosed() fix for empty array edge case

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-23 18:44:37 +03:00
parent 2e99bc2dee
commit 99cdf6d42b
3 changed files with 16 additions and 40 deletions

View File

@ -152,7 +152,6 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
if (end > tads) end = tads;
for (int r = start; r < end; r++) {
if (tadEWS > 0 && zEWS > 0 && dimensionsLength == 1) {
auto rX = const_cast<NDArray*>(input)->bufferAsT<X>() + tadOffsets[r];
auto rZ = output->bufferAsT<Z>() + zOfsets[r];
@ -199,44 +198,6 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
}
}
}
else {
int tadsPerThread = tads / TAD_THRESHOLD;
int num_threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
num_threads = nd4j::math::nd4j_min<int>(num_threads, omp_get_max_threads());
Nd4jLong offset = tadOffsets[r];
Nd4jLong shapeIter[MAX_RANK];
Nd4jLong coord[MAX_RANK];
int dim;
Nd4jLong xStridesIter[MAX_RANK];
Nd4jLong resultStridesIter[MAX_RANK];
Nd4jLong *xShape = shape::shapeOf(tadShapeShapeInfo);
Nd4jLong *xStride = shape::stride(tadShapeShapeInfo);
Nd4jLong *resultStride = shape::stride(tadShapeShapeInfo);
int rank = shape::rank(tadShapeShapeInfo);
auto xPointer = const_cast<NDArray*>(input)->bufferAsT<X>() + offset;
auto resultPointer = output->bufferAsT<Z>() + offset;
auto maxValue = xPointer[0];
auto maxCursor = resultPointer;
Nd4jPointer maxCursorLong = reinterpret_cast<Nd4jPointer>(maxCursor);
if (PrepareTwoRawArrayIter<X,Z>(rank, xShape, xPointer, xStride, resultPointer, resultStride, &rank, shapeIter, &xPointer, xStridesIter, &resultPointer, resultStridesIter) >= 0) {
ND4J_RAW_ITER_START(dim, rank, coord, shapeIter);
{
if (maxValue < xPointer[0]) {
maxCursor = resultPointer;
maxCursorLong = reinterpret_cast<Nd4jPointer>(resultPointer);
maxValue = xPointer[0];
}
resultPointer[0] = (Z) 0;
}
ND4J_RAW_ITER_TWO_NEXT(dim, rank, coord, shapeIter, xPointer, xStridesIter, resultPointer, resultStridesIter);
maxCursor = reinterpret_cast<Z*>(maxCursorLong);
maxCursor[0] = (Z) 1;
}
}
}
}
}
}

View File

@ -6719,7 +6719,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override
public boolean wasClosed() {
if (released || data().wasClosed())
// data can be null if that's empty array
if (released || (data() != null && data().wasClosed()))
return true;
return false;

View File

@ -643,6 +643,20 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(result1, result2);
}
@Test
public void isMax4d_2dims(){
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 3, 4, 4).permute(0, 2, 3, 1);
INDArray out_permutedIn = in.like();
INDArray out_dupedIn = in.like();
Nd4j.exec(new IsMax(in.dup(), out_dupedIn, 2, 3));
Nd4j.exec(new IsMax(in, out_permutedIn, 2, 3));
assertEquals(out_dupedIn, out_permutedIn);
}
@Test
public void testSizeTypes(){
List<DataType> failed = new ArrayList<>();