diff --git a/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp index a26e0be20..330e23d5b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp @@ -28,43 +28,43 @@ namespace nd4j { namespace ops { namespace helpers { -template +template static void ismax_(const NDArray* input, NDArray* output, const std::vector& dimensions) { if (input->isVector()) { int dimensionsLength = dimensions.size(); int length = input->lengthOf(); - if ((input->shapeOf())[dimensions[0]] == 1) { + if (!dimensions.empty() && (input->shapeOf())[dimensions[0]] == 1) { for (int i = 0; i < length; i++) - output->p(i, 1.f); + output->p(i, 1); } else { int eleStride = shape::elementWiseStride(input->getShapeInfo()); if (eleStride == 1) { int maxIdx = 0; - T currMax = input->e(0); + auto currMax = input->e(0); if (length < ELEMENT_THRESHOLD) { for (int i = 0; i < length; i++) { - if (currMax < input->e(i)) { - currMax = input->e(i); + if (currMax < input->e(i)) { + currMax = input->e(i); maxIdx = i; } - output->p(i, 0.f); + output->p(i, 0); } } else { { int maxIdxLocal = maxIdx; - T currMaxLocal = currMax; + auto currMaxLocal = currMax; for (int i = 0; i < length; i++) { - if (currMaxLocal < input->e(i)) { - currMaxLocal = input->e(i); + if (currMaxLocal < input->e(i)) { + currMaxLocal = input->e(i); maxIdxLocal = i; } - output->p(i, 0.f); + output->p(i, 0); } PRAGMA_OMP_CRITICAL @@ -76,32 +76,32 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector } } } - output->p(maxIdx, 1.f); + output->p(maxIdx, 1); } else { int maxIdx = 0; - T currMax = input->e(0); + auto currMax = input->e(0); if (length < ELEMENT_THRESHOLD) { for (int i = 0; i < length; i++) { - if (currMax < input->e(i*eleStride)) { - currMax = input->e(i*eleStride); + if (currMax < input->e(i*eleStride)) { + currMax = input->e(i*eleStride); maxIdx = i; } - output->p(i, 0.f); + 
output->p(i, 0); } } else { { int maxIdxLocal = maxIdx; - T currMaxLocal = currMax; + auto currMaxLocal = currMax; for (int i = 0; i < length; i++) { - if (currMaxLocal < input->e(i*eleStride)) { - currMaxLocal = input->e(i*eleStride); + if (currMaxLocal < input->e(i*eleStride)) { + currMaxLocal = input->e(i*eleStride); maxIdxLocal = i; } - output->p(i, 0.f); + output->p(i, 0); } PRAGMA_OMP_CRITICAL @@ -113,7 +113,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector } } } - output->p(maxIdx, 1.f); + output->p(maxIdx, 1); } } } @@ -150,10 +150,10 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector for (int r = start; r < end; r++) { if (tadEWS > 0 && zEWS > 0 && dimensionsLength == 1) { - T *rX = const_cast(input)->bufferAsT() + tadOffsets[r]; - T *rZ = output->bufferAsT() + tadOffsets[r]; + auto rX = const_cast(input)->bufferAsT() + tadOffsets[r]; + auto rZ = output->bufferAsT() + tadOffsets[r]; - T maxValue = rX[0]; + auto maxValue = rX[0]; int maxIdx = 0; if (tadEWS == 1 && zEWS == 1) { for (int i = 0; i < tadLength; i++) { @@ -165,7 +165,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector PRAGMA_OMP_SIMD for (int i = 0; i < tadLength; i++) { - rZ[i] = maxIdx == i ? (T) 1.0 : (T) 0.0; + rZ[i] = maxIdx == i ? (Z) 1 : (Z) 0; } } else { @@ -178,7 +178,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector PRAGMA_OMP_SIMD for (int i = 0; i < tadLength; i++) { - rZ[i * zEWS] = maxIdx == i ? (T) 1.0 : (T) 0.0; + rZ[i * zEWS] = maxIdx == i ? 
(Z) 1 : (Z) 0; } } } @@ -197,14 +197,14 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector Nd4jLong *xStride = shape::stride(tadShapeShapeInfo); Nd4jLong *resultStride = shape::stride(tadShapeShapeInfo); int rank = shape::rank(tadShapeShapeInfo); - T *xPointer = const_cast(input)->bufferAsT() + offset; - T *resultPointer = output->bufferAsT() + offset; - T maxValue = xPointer[0]; + auto xPointer = const_cast(input)->bufferAsT() + offset; + auto resultPointer = output->bufferAsT() + offset; + auto maxValue = xPointer[0]; - T *maxCursor = resultPointer; + auto maxCursor = resultPointer; Nd4jPointer maxCursorLong = reinterpret_cast(maxCursor); - if (PrepareTwoRawArrayIter(rank, xShape, xPointer, xStride, resultPointer, resultStride, &rank, shapeIter, &xPointer, xStridesIter, &resultPointer, resultStridesIter) >= 0) { + if (PrepareTwoRawArrayIter(rank, xShape, xPointer, xStride, resultPointer, resultStride, &rank, shapeIter, &xPointer, xStridesIter, &resultPointer, resultStridesIter) >= 0) { ND4J_RAW_ITER_START(dim, rank, coord, shapeIter); { if (maxValue < xPointer[0]) { @@ -212,11 +212,11 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector maxCursorLong = reinterpret_cast(resultPointer); maxValue = xPointer[0]; } - resultPointer[0] = 0.0; + resultPointer[0] = (Z) 0; } ND4J_RAW_ITER_TWO_NEXT(dim, rank, coord, shapeIter, xPointer, xStridesIter, resultPointer, resultStridesIter); - maxCursor = reinterpret_cast(maxCursorLong); - maxCursor[0] = 1.0; + maxCursor = reinterpret_cast(maxCursorLong); + maxCursor[0] = (Z) 1; } } } @@ -226,10 +226,9 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector void ismax(nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector& dimensions) { - BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, (input, output, dimensions), LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), ismax_, (input, output, 
dimensions), LIBND4J_TYPES, LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template void ismax_, (const NDArray *input, NDArray *output, const std::vector& dimensions), LIBND4J_TYPES); } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index c94c9c123..b731a37b4 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -2346,6 +2346,18 @@ TEST_F(DeclarableOpsTests1, IsMax3) { delete result; } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, IsMax4) { + auto x = NDArrayFactory::create('c', {6}, {0, 0, 0, 2, 2, 0}); + auto z = NDArrayFactory::create('c', {6}); + auto e = NDArrayFactory::create('c', {6}, {false, false, false, true, false, false}); + + nd4j::ops::ismax op; + auto result = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + + ASSERT_EQ(e, z); +} ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, CompactLaunchTests1) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java index 81da2cbf8..19085629b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java @@ -104,7 +104,7 @@ public class MixedDataTypesTests extends BaseNd4jTest { @Test public void testBasicCreation_5() { - val scalar = Nd4j.scalar(new Integer(1)); + val scalar = Nd4j.scalar(Integer.valueOf(1)); assertNotNull(scalar); assertEquals(0, scalar.rank()); assertEquals(1, scalar.length()); @@ -112,6 +112,56 @@ public class MixedDataTypesTests extends BaseNd4jTest { assertEquals(1.0, scalar.getInt(0), 1e-5); } + @Test + public void 
testBasicCreation_5_0() { + val scalar = Nd4j.scalar(Long.valueOf(1)); + assertNotNull(scalar); + assertEquals(0, scalar.rank()); + assertEquals(1, scalar.length()); + assertEquals(DataType.LONG, scalar.dataType()); + assertEquals(1.0, scalar.getInt(0), 1e-5); + } + + @Test + public void testBasicCreation_5_1() { + val scalar = Nd4j.scalar(Double.valueOf(1)); + assertNotNull(scalar); + assertEquals(0, scalar.rank()); + assertEquals(1, scalar.length()); + assertEquals(DataType.DOUBLE, scalar.dataType()); + assertEquals(1.0, scalar.getDouble(0), 1e-5); + } + + @Test + public void testBasicCreation_5_2() { + val scalar = Nd4j.scalar(Float.valueOf(1)); + assertNotNull(scalar); + assertEquals(0, scalar.rank()); + assertEquals(1, scalar.length()); + assertEquals(DataType.FLOAT, scalar.dataType()); + assertEquals(1.0, scalar.getDouble(0), 1e-5); + } + + @Test + public void testBasicCreation_5_3() { + val scalar = Nd4j.scalar(Short.valueOf((short) 1)); + assertNotNull(scalar); + assertEquals(0, scalar.rank()); + assertEquals(1, scalar.length()); + assertEquals(DataType.SHORT, scalar.dataType()); + assertEquals(1.0, scalar.getDouble(0), 1e-5); + } + + @Test + public void testBasicCreation_5_4() { + val scalar = Nd4j.scalar(Byte.valueOf((byte) 1)); + assertNotNull(scalar); + assertEquals(0, scalar.rank()); + assertEquals(1, scalar.length()); + assertEquals(DataType.BYTE, scalar.dataType()); + assertEquals(1.0, scalar.getDouble(0), 1e-5); + } + @Test public void testBasicCreation_6() { val scalar = Nd4j.scalar(1);