[WIP] cpu ismax fix (#137)
* cpu ismax fix Signed-off-by: raver119 <raver119@gmail.com> * bunch of smaller scalar tests Signed-off-by: raver119 <raver119@gmail.com>master
parent
4211f3b4ce
commit
77805cb7fa
|
@ -28,43 +28,43 @@ namespace nd4j {
|
|||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
template <typename T>
|
||||
template <typename X, typename Z>
|
||||
static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>& dimensions) {
|
||||
|
||||
if (input->isVector()) {
|
||||
int dimensionsLength = dimensions.size();
|
||||
int length = input->lengthOf();
|
||||
if ((input->shapeOf())[dimensions[0]] == 1) {
|
||||
if (!dimensions.empty() && (input->shapeOf())[dimensions[0]] == 1) {
|
||||
for (int i = 0; i < length; i++)
|
||||
output->p<T>(i, 1.f);
|
||||
output->p<Z>(i, 1);
|
||||
}
|
||||
else {
|
||||
int eleStride = shape::elementWiseStride(input->getShapeInfo());
|
||||
if (eleStride == 1) {
|
||||
int maxIdx = 0;
|
||||
T currMax = input->e<T>(0);
|
||||
auto currMax = input->e<X>(0);
|
||||
if (length < ELEMENT_THRESHOLD) {
|
||||
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (currMax < input->e<T>(i)) {
|
||||
currMax = input->e<T>(i);
|
||||
if (currMax < input->e<X>(i)) {
|
||||
currMax = input->e<X>(i);
|
||||
maxIdx = i;
|
||||
}
|
||||
output->p<T>(i, 0.f);
|
||||
output->p<Z>(i, 0);
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
{
|
||||
int maxIdxLocal = maxIdx;
|
||||
T currMaxLocal = currMax;
|
||||
auto currMaxLocal = currMax;
|
||||
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (currMaxLocal < input->e<T>(i)) {
|
||||
currMaxLocal = input->e<T>(i);
|
||||
if (currMaxLocal < input->e<X>(i)) {
|
||||
currMaxLocal = input->e<X>(i);
|
||||
maxIdxLocal = i;
|
||||
}
|
||||
output->p<T>(i, 0.f);
|
||||
output->p<Z>(i, 0);
|
||||
}
|
||||
|
||||
PRAGMA_OMP_CRITICAL
|
||||
|
@ -76,32 +76,32 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
|||
}
|
||||
}
|
||||
}
|
||||
output->p<T>(maxIdx, 1.f);
|
||||
output->p<Z>(maxIdx, 1);
|
||||
}
|
||||
else {
|
||||
int maxIdx = 0;
|
||||
T currMax = input->e<T>(0);
|
||||
auto currMax = input->e<X>(0);
|
||||
if (length < ELEMENT_THRESHOLD) {
|
||||
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (currMax < input->e<T>(i*eleStride)) {
|
||||
currMax = input->e<T>(i*eleStride);
|
||||
if (currMax < input->e<X>(i*eleStride)) {
|
||||
currMax = input->e<X>(i*eleStride);
|
||||
maxIdx = i;
|
||||
}
|
||||
output->p<T>(i, 0.f);
|
||||
output->p<Z>(i, 0.f);
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
{
|
||||
int maxIdxLocal = maxIdx;
|
||||
T currMaxLocal = currMax;
|
||||
auto currMaxLocal = currMax;
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (currMaxLocal < input->e<T>(i*eleStride)) {
|
||||
currMaxLocal = input->e<T>(i*eleStride);
|
||||
if (currMaxLocal < input->e<X>(i*eleStride)) {
|
||||
currMaxLocal = input->e<X>(i*eleStride);
|
||||
maxIdxLocal = i;
|
||||
}
|
||||
output->p<T>(i, 0.f);
|
||||
output->p<Z>(i, 0.f);
|
||||
}
|
||||
|
||||
PRAGMA_OMP_CRITICAL
|
||||
|
@ -113,7 +113,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
|||
}
|
||||
}
|
||||
}
|
||||
output->p<T>(maxIdx, 1.f);
|
||||
output->p<Z>(maxIdx, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -150,10 +150,10 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
|||
|
||||
for (int r = start; r < end; r++) {
|
||||
if (tadEWS > 0 && zEWS > 0 && dimensionsLength == 1) {
|
||||
T *rX = const_cast<NDArray*>(input)->bufferAsT<T>() + tadOffsets[r];
|
||||
T *rZ = output->bufferAsT<T>() + tadOffsets[r];
|
||||
auto rX = const_cast<NDArray*>(input)->bufferAsT<X>() + tadOffsets[r];
|
||||
auto rZ = output->bufferAsT<Z>() + tadOffsets[r];
|
||||
|
||||
T maxValue = rX[0];
|
||||
auto maxValue = rX[0];
|
||||
int maxIdx = 0;
|
||||
if (tadEWS == 1 && zEWS == 1) {
|
||||
for (int i = 0; i < tadLength; i++) {
|
||||
|
@ -165,7 +165,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
|||
|
||||
PRAGMA_OMP_SIMD
|
||||
for (int i = 0; i < tadLength; i++) {
|
||||
rZ[i] = maxIdx == i ? (T) 1.0 : (T) 0.0;
|
||||
rZ[i] = maxIdx == i ? (Z) 1 : (Z) 0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
@ -178,7 +178,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
|||
|
||||
PRAGMA_OMP_SIMD
|
||||
for (int i = 0; i < tadLength; i++) {
|
||||
rZ[i * zEWS] = maxIdx == i ? (T) 1.0 : (T) 0.0;
|
||||
rZ[i * zEWS] = maxIdx == i ? (Z) 1 : (Z) 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -197,14 +197,14 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
|||
Nd4jLong *xStride = shape::stride(tadShapeShapeInfo);
|
||||
Nd4jLong *resultStride = shape::stride(tadShapeShapeInfo);
|
||||
int rank = shape::rank(tadShapeShapeInfo);
|
||||
T *xPointer = const_cast<NDArray*>(input)->bufferAsT<T>() + offset;
|
||||
T *resultPointer = output->bufferAsT<T>() + offset;
|
||||
T maxValue = xPointer[0];
|
||||
auto xPointer = const_cast<NDArray*>(input)->bufferAsT<X>() + offset;
|
||||
auto resultPointer = output->bufferAsT<Z>() + offset;
|
||||
auto maxValue = xPointer[0];
|
||||
|
||||
T *maxCursor = resultPointer;
|
||||
auto maxCursor = resultPointer;
|
||||
Nd4jPointer maxCursorLong = reinterpret_cast<Nd4jPointer>(maxCursor);
|
||||
|
||||
if (PrepareTwoRawArrayIter<T>(rank, xShape, xPointer, xStride, resultPointer, resultStride, &rank, shapeIter, &xPointer, xStridesIter, &resultPointer, resultStridesIter) >= 0) {
|
||||
if (PrepareTwoRawArrayIter<X,Z>(rank, xShape, xPointer, xStride, resultPointer, resultStride, &rank, shapeIter, &xPointer, xStridesIter, &resultPointer, resultStridesIter) >= 0) {
|
||||
ND4J_RAW_ITER_START(dim, rank, coord, shapeIter);
|
||||
{
|
||||
if (maxValue < xPointer[0]) {
|
||||
|
@ -212,11 +212,11 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
|||
maxCursorLong = reinterpret_cast<Nd4jPointer>(resultPointer);
|
||||
maxValue = xPointer[0];
|
||||
}
|
||||
resultPointer[0] = 0.0;
|
||||
resultPointer[0] = (Z) 0;
|
||||
}
|
||||
ND4J_RAW_ITER_TWO_NEXT(dim, rank, coord, shapeIter, xPointer, xStridesIter, resultPointer, resultStridesIter);
|
||||
maxCursor = reinterpret_cast<T*>(maxCursorLong);
|
||||
maxCursor[0] = 1.0;
|
||||
maxCursor = reinterpret_cast<Z*>(maxCursorLong);
|
||||
maxCursor[0] = (Z) 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -226,10 +226,9 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
|||
|
||||
|
||||
void ismax(nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector<int>& dimensions) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, (input, output, dimensions), LIBND4J_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), ismax_, (input, output, dimensions), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void ismax_, (const NDArray *input, NDArray *output, const std::vector<int>& dimensions), LIBND4J_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2346,6 +2346,18 @@ TEST_F(DeclarableOpsTests1, IsMax3) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, IsMax4) {
|
||||
auto x = NDArrayFactory::create<double>('c', {6}, {0, 0, 0, 2, 2, 0});
|
||||
auto z = NDArrayFactory::create<bool>('c', {6});
|
||||
auto e = NDArrayFactory::create<bool>('c', {6}, {false, false, false, true, false, false});
|
||||
|
||||
nd4j::ops::ismax op;
|
||||
auto result = op.execute({&x}, {&z}, {}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, CompactLaunchTests1) {
|
||||
|
|
|
@ -104,7 +104,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
public void testBasicCreation_5() {
|
||||
val scalar = Nd4j.scalar(new Integer(1));
|
||||
val scalar = Nd4j.scalar(Integer.valueOf(1));
|
||||
assertNotNull(scalar);
|
||||
assertEquals(0, scalar.rank());
|
||||
assertEquals(1, scalar.length());
|
||||
|
@ -112,6 +112,56 @@ public class MixedDataTypesTests extends BaseNd4jTest {
|
|||
assertEquals(1.0, scalar.getInt(0), 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicCreation_5_0() {
|
||||
val scalar = Nd4j.scalar(Long.valueOf(1));
|
||||
assertNotNull(scalar);
|
||||
assertEquals(0, scalar.rank());
|
||||
assertEquals(1, scalar.length());
|
||||
assertEquals(DataType.LONG, scalar.dataType());
|
||||
assertEquals(1.0, scalar.getInt(0), 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicCreation_5_1() {
|
||||
val scalar = Nd4j.scalar(Double.valueOf(1));
|
||||
assertNotNull(scalar);
|
||||
assertEquals(0, scalar.rank());
|
||||
assertEquals(1, scalar.length());
|
||||
assertEquals(DataType.DOUBLE, scalar.dataType());
|
||||
assertEquals(1.0, scalar.getDouble(0), 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicCreation_5_2() {
|
||||
val scalar = Nd4j.scalar(Float.valueOf(1));
|
||||
assertNotNull(scalar);
|
||||
assertEquals(0, scalar.rank());
|
||||
assertEquals(1, scalar.length());
|
||||
assertEquals(DataType.FLOAT, scalar.dataType());
|
||||
assertEquals(1.0, scalar.getDouble(0), 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicCreation_5_3() {
|
||||
val scalar = Nd4j.scalar(Short.valueOf((short) 1));
|
||||
assertNotNull(scalar);
|
||||
assertEquals(0, scalar.rank());
|
||||
assertEquals(1, scalar.length());
|
||||
assertEquals(DataType.SHORT, scalar.dataType());
|
||||
assertEquals(1.0, scalar.getDouble(0), 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicCreation_5_4() {
|
||||
val scalar = Nd4j.scalar(Byte.valueOf((byte) 1));
|
||||
assertNotNull(scalar);
|
||||
assertEquals(0, scalar.rank());
|
||||
assertEquals(1, scalar.length());
|
||||
assertEquals(DataType.BYTE, scalar.dataType());
|
||||
assertEquals(1.0, scalar.getDouble(0), 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicCreation_6() {
|
||||
val scalar = Nd4j.scalar(1);
|
||||
|
|
Loading…
Reference in New Issue