- cpu isMax fix for multidim case + test
- INDArray.wasClosed() fix for empty array edge case Signed-off-by: raver119 <raver119@gmail.com>master
parent
2e99bc2dee
commit
99cdf6d42b
|
@ -152,7 +152,6 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
||||||
if (end > tads) end = tads;
|
if (end > tads) end = tads;
|
||||||
|
|
||||||
for (int r = start; r < end; r++) {
|
for (int r = start; r < end; r++) {
|
||||||
if (tadEWS > 0 && zEWS > 0 && dimensionsLength == 1) {
|
|
||||||
auto rX = const_cast<NDArray*>(input)->bufferAsT<X>() + tadOffsets[r];
|
auto rX = const_cast<NDArray*>(input)->bufferAsT<X>() + tadOffsets[r];
|
||||||
auto rZ = output->bufferAsT<Z>() + zOfsets[r];
|
auto rZ = output->bufferAsT<Z>() + zOfsets[r];
|
||||||
|
|
||||||
|
@ -199,44 +198,6 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
|
||||||
int tadsPerThread = tads / TAD_THRESHOLD;
|
|
||||||
int num_threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
|
||||||
num_threads = nd4j::math::nd4j_min<int>(num_threads, omp_get_max_threads());
|
|
||||||
|
|
||||||
Nd4jLong offset = tadOffsets[r];
|
|
||||||
Nd4jLong shapeIter[MAX_RANK];
|
|
||||||
Nd4jLong coord[MAX_RANK];
|
|
||||||
int dim;
|
|
||||||
Nd4jLong xStridesIter[MAX_RANK];
|
|
||||||
Nd4jLong resultStridesIter[MAX_RANK];
|
|
||||||
Nd4jLong *xShape = shape::shapeOf(tadShapeShapeInfo);
|
|
||||||
Nd4jLong *xStride = shape::stride(tadShapeShapeInfo);
|
|
||||||
Nd4jLong *resultStride = shape::stride(tadShapeShapeInfo);
|
|
||||||
int rank = shape::rank(tadShapeShapeInfo);
|
|
||||||
auto xPointer = const_cast<NDArray*>(input)->bufferAsT<X>() + offset;
|
|
||||||
auto resultPointer = output->bufferAsT<Z>() + offset;
|
|
||||||
auto maxValue = xPointer[0];
|
|
||||||
|
|
||||||
auto maxCursor = resultPointer;
|
|
||||||
Nd4jPointer maxCursorLong = reinterpret_cast<Nd4jPointer>(maxCursor);
|
|
||||||
|
|
||||||
if (PrepareTwoRawArrayIter<X,Z>(rank, xShape, xPointer, xStride, resultPointer, resultStride, &rank, shapeIter, &xPointer, xStridesIter, &resultPointer, resultStridesIter) >= 0) {
|
|
||||||
ND4J_RAW_ITER_START(dim, rank, coord, shapeIter);
|
|
||||||
{
|
|
||||||
if (maxValue < xPointer[0]) {
|
|
||||||
maxCursor = resultPointer;
|
|
||||||
maxCursorLong = reinterpret_cast<Nd4jPointer>(resultPointer);
|
|
||||||
maxValue = xPointer[0];
|
|
||||||
}
|
|
||||||
resultPointer[0] = (Z) 0;
|
|
||||||
}
|
|
||||||
ND4J_RAW_ITER_TWO_NEXT(dim, rank, coord, shapeIter, xPointer, xStridesIter, resultPointer, resultStridesIter);
|
|
||||||
maxCursor = reinterpret_cast<Z*>(maxCursorLong);
|
|
||||||
maxCursor[0] = (Z) 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6719,7 +6719,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean wasClosed() {
|
public boolean wasClosed() {
|
||||||
if (released || data().wasClosed())
|
// data can be null if that's empty array
|
||||||
|
if (released || (data() != null && data().wasClosed()))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -643,6 +643,20 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertEquals(result1, result2);
|
assertEquals(result1, result2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void isMax4d_2dims(){
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 3, 4, 4).permute(0, 2, 3, 1);
|
||||||
|
|
||||||
|
INDArray out_permutedIn = in.like();
|
||||||
|
INDArray out_dupedIn = in.like();
|
||||||
|
|
||||||
|
Nd4j.exec(new IsMax(in.dup(), out_dupedIn, 2, 3));
|
||||||
|
Nd4j.exec(new IsMax(in, out_permutedIn, 2, 3));
|
||||||
|
|
||||||
|
assertEquals(out_dupedIn, out_permutedIn);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSizeTypes(){
|
public void testSizeTypes(){
|
||||||
List<DataType> failed = new ArrayList<>();
|
List<DataType> failed = new ArrayList<>();
|
||||||
|
|
Loading…
Reference in New Issue