[WIP] cpu ismax fix (#137)

* cpu ismax fix

Signed-off-by: raver119 <raver119@gmail.com>

* bunch of smaller scalar tests

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-21 10:12:11 +03:00 committed by GitHub
parent 4211f3b4ce
commit 77805cb7fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 98 additions and 37 deletions

View File

@ -28,43 +28,43 @@ namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
template <typename X, typename Z>
static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>& dimensions) {
if (input->isVector()) {
int dimensionsLength = dimensions.size();
int length = input->lengthOf();
if ((input->shapeOf())[dimensions[0]] == 1) {
if (!dimensions.empty() && (input->shapeOf())[dimensions[0]] == 1) {
for (int i = 0; i < length; i++)
output->p<T>(i, 1.f);
output->p<Z>(i, 1);
}
else {
int eleStride = shape::elementWiseStride(input->getShapeInfo());
if (eleStride == 1) {
int maxIdx = 0;
T currMax = input->e<T>(0);
auto currMax = input->e<X>(0);
if (length < ELEMENT_THRESHOLD) {
for (int i = 0; i < length; i++) {
if (currMax < input->e<T>(i)) {
currMax = input->e<T>(i);
if (currMax < input->e<X>(i)) {
currMax = input->e<X>(i);
maxIdx = i;
}
output->p<T>(i, 0.f);
output->p<Z>(i, 0);
}
}
else {
{
int maxIdxLocal = maxIdx;
T currMaxLocal = currMax;
auto currMaxLocal = currMax;
for (int i = 0; i < length; i++) {
if (currMaxLocal < input->e<T>(i)) {
currMaxLocal = input->e<T>(i);
if (currMaxLocal < input->e<X>(i)) {
currMaxLocal = input->e<X>(i);
maxIdxLocal = i;
}
output->p<T>(i, 0.f);
output->p<Z>(i, 0);
}
PRAGMA_OMP_CRITICAL
@ -76,32 +76,32 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
}
}
}
output->p<T>(maxIdx, 1.f);
output->p<Z>(maxIdx, 1);
}
else {
int maxIdx = 0;
T currMax = input->e<T>(0);
auto currMax = input->e<X>(0);
if (length < ELEMENT_THRESHOLD) {
for (int i = 0; i < length; i++) {
if (currMax < input->e<T>(i*eleStride)) {
currMax = input->e<T>(i*eleStride);
if (currMax < input->e<X>(i*eleStride)) {
currMax = input->e<X>(i*eleStride);
maxIdx = i;
}
output->p<T>(i, 0.f);
output->p<Z>(i, 0.f);
}
}
else {
{
int maxIdxLocal = maxIdx;
T currMaxLocal = currMax;
auto currMaxLocal = currMax;
for (int i = 0; i < length; i++) {
if (currMaxLocal < input->e<T>(i*eleStride)) {
currMaxLocal = input->e<T>(i*eleStride);
if (currMaxLocal < input->e<X>(i*eleStride)) {
currMaxLocal = input->e<X>(i*eleStride);
maxIdxLocal = i;
}
output->p<T>(i, 0.f);
output->p<Z>(i, 0.f);
}
PRAGMA_OMP_CRITICAL
@ -113,7 +113,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
}
}
}
output->p<T>(maxIdx, 1.f);
output->p<Z>(maxIdx, 1);
}
}
}
@ -150,10 +150,10 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
for (int r = start; r < end; r++) {
if (tadEWS > 0 && zEWS > 0 && dimensionsLength == 1) {
T *rX = const_cast<NDArray*>(input)->bufferAsT<T>() + tadOffsets[r];
T *rZ = output->bufferAsT<T>() + tadOffsets[r];
auto rX = const_cast<NDArray*>(input)->bufferAsT<X>() + tadOffsets[r];
auto rZ = output->bufferAsT<Z>() + tadOffsets[r];
T maxValue = rX[0];
auto maxValue = rX[0];
int maxIdx = 0;
if (tadEWS == 1 && zEWS == 1) {
for (int i = 0; i < tadLength; i++) {
@ -165,7 +165,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
PRAGMA_OMP_SIMD
for (int i = 0; i < tadLength; i++) {
rZ[i] = maxIdx == i ? (T) 1.0 : (T) 0.0;
rZ[i] = maxIdx == i ? (Z) 1 : (Z) 0;
}
}
else {
@ -178,7 +178,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
PRAGMA_OMP_SIMD
for (int i = 0; i < tadLength; i++) {
rZ[i * zEWS] = maxIdx == i ? (T) 1.0 : (T) 0.0;
rZ[i * zEWS] = maxIdx == i ? (Z) 1 : (Z) 0;
}
}
}
@ -197,14 +197,14 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
Nd4jLong *xStride = shape::stride(tadShapeShapeInfo);
Nd4jLong *resultStride = shape::stride(tadShapeShapeInfo);
int rank = shape::rank(tadShapeShapeInfo);
T *xPointer = const_cast<NDArray*>(input)->bufferAsT<T>() + offset;
T *resultPointer = output->bufferAsT<T>() + offset;
T maxValue = xPointer[0];
auto xPointer = const_cast<NDArray*>(input)->bufferAsT<X>() + offset;
auto resultPointer = output->bufferAsT<Z>() + offset;
auto maxValue = xPointer[0];
T *maxCursor = resultPointer;
auto maxCursor = resultPointer;
Nd4jPointer maxCursorLong = reinterpret_cast<Nd4jPointer>(maxCursor);
if (PrepareTwoRawArrayIter<T>(rank, xShape, xPointer, xStride, resultPointer, resultStride, &rank, shapeIter, &xPointer, xStridesIter, &resultPointer, resultStridesIter) >= 0) {
if (PrepareTwoRawArrayIter<X,Z>(rank, xShape, xPointer, xStride, resultPointer, resultStride, &rank, shapeIter, &xPointer, xStridesIter, &resultPointer, resultStridesIter) >= 0) {
ND4J_RAW_ITER_START(dim, rank, coord, shapeIter);
{
if (maxValue < xPointer[0]) {
@ -212,11 +212,11 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
maxCursorLong = reinterpret_cast<Nd4jPointer>(resultPointer);
maxValue = xPointer[0];
}
resultPointer[0] = 0.0;
resultPointer[0] = (Z) 0;
}
ND4J_RAW_ITER_TWO_NEXT(dim, rank, coord, shapeIter, xPointer, xStridesIter, resultPointer, resultStridesIter);
maxCursor = reinterpret_cast<T*>(maxCursorLong);
maxCursor[0] = 1.0;
maxCursor = reinterpret_cast<Z*>(maxCursorLong);
maxCursor[0] = (Z) 1;
}
}
}
@ -226,10 +226,9 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
void ismax(nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector<int>& dimensions) {
BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, (input, output, dimensions), LIBND4J_TYPES);
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), ismax_, (input, output, dimensions), LIBND4J_TYPES, LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void ismax_, (const NDArray *input, NDArray *output, const std::vector<int>& dimensions), LIBND4J_TYPES);
}
}

View File

@ -2346,6 +2346,18 @@ TEST_F(DeclarableOpsTests1, IsMax3) {
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, IsMax4) {
auto x = NDArrayFactory::create<double>('c', {6}, {0, 0, 0, 2, 2, 0});
auto z = NDArrayFactory::create<bool>('c', {6});
auto e = NDArrayFactory::create<bool>('c', {6}, {false, false, false, true, false, false});
nd4j::ops::ismax op;
auto result = op.execute({&x}, {&z}, {}, {}, {});
ASSERT_EQ(Status::OK(), result);
ASSERT_EQ(e, z);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, CompactLaunchTests1) {

View File

@ -104,7 +104,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
@Test
public void testBasicCreation_5() {
val scalar = Nd4j.scalar(new Integer(1));
val scalar = Nd4j.scalar(Integer.valueOf(1));
assertNotNull(scalar);
assertEquals(0, scalar.rank());
assertEquals(1, scalar.length());
@ -112,6 +112,56 @@ public class MixedDataTypesTests extends BaseNd4jTest {
assertEquals(1.0, scalar.getInt(0), 1e-5);
}
@Test
public void testBasicCreation_5_0() {
val scalar = Nd4j.scalar(Long.valueOf(1));
assertNotNull(scalar);
assertEquals(0, scalar.rank());
assertEquals(1, scalar.length());
assertEquals(DataType.LONG, scalar.dataType());
assertEquals(1.0, scalar.getInt(0), 1e-5);
}
@Test
public void testBasicCreation_5_1() {
val scalar = Nd4j.scalar(Double.valueOf(1));
assertNotNull(scalar);
assertEquals(0, scalar.rank());
assertEquals(1, scalar.length());
assertEquals(DataType.DOUBLE, scalar.dataType());
assertEquals(1.0, scalar.getDouble(0), 1e-5);
}
@Test
public void testBasicCreation_5_2() {
val scalar = Nd4j.scalar(Float.valueOf(1));
assertNotNull(scalar);
assertEquals(0, scalar.rank());
assertEquals(1, scalar.length());
assertEquals(DataType.FLOAT, scalar.dataType());
assertEquals(1.0, scalar.getDouble(0), 1e-5);
}
@Test
public void testBasicCreation_5_3() {
val scalar = Nd4j.scalar(Short.valueOf((short) 1));
assertNotNull(scalar);
assertEquals(0, scalar.rank());
assertEquals(1, scalar.length());
assertEquals(DataType.SHORT, scalar.dataType());
assertEquals(1.0, scalar.getDouble(0), 1e-5);
}
@Test
public void testBasicCreation_5_4() {
val scalar = Nd4j.scalar(Byte.valueOf((byte) 1));
assertNotNull(scalar);
assertEquals(0, scalar.rank());
assertEquals(1, scalar.length());
assertEquals(DataType.BYTE, scalar.dataType());
assertEquals(1.0, scalar.getDouble(0), 1e-5);
}
@Test
public void testBasicCreation_6() {
val scalar = Nd4j.scalar(1);