[WIP] cpu ismax fix (#137)
* cpu ismax fix
  Signed-off-by: raver119 <raver119@gmail.com>
* bunch of smaller scalar tests
  Signed-off-by: raver119 <raver119@gmail.com>

master
parent
4211f3b4ce
commit
77805cb7fa
|
@ -28,43 +28,43 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
template <typename X, typename Z>
|
||||||
static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>& dimensions) {
|
static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>& dimensions) {
|
||||||
|
|
||||||
if (input->isVector()) {
|
if (input->isVector()) {
|
||||||
int dimensionsLength = dimensions.size();
|
int dimensionsLength = dimensions.size();
|
||||||
int length = input->lengthOf();
|
int length = input->lengthOf();
|
||||||
if ((input->shapeOf())[dimensions[0]] == 1) {
|
if (!dimensions.empty() && (input->shapeOf())[dimensions[0]] == 1) {
|
||||||
for (int i = 0; i < length; i++)
|
for (int i = 0; i < length; i++)
|
||||||
output->p<T>(i, 1.f);
|
output->p<Z>(i, 1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
int eleStride = shape::elementWiseStride(input->getShapeInfo());
|
int eleStride = shape::elementWiseStride(input->getShapeInfo());
|
||||||
if (eleStride == 1) {
|
if (eleStride == 1) {
|
||||||
int maxIdx = 0;
|
int maxIdx = 0;
|
||||||
T currMax = input->e<T>(0);
|
auto currMax = input->e<X>(0);
|
||||||
if (length < ELEMENT_THRESHOLD) {
|
if (length < ELEMENT_THRESHOLD) {
|
||||||
|
|
||||||
for (int i = 0; i < length; i++) {
|
for (int i = 0; i < length; i++) {
|
||||||
if (currMax < input->e<T>(i)) {
|
if (currMax < input->e<X>(i)) {
|
||||||
currMax = input->e<T>(i);
|
currMax = input->e<X>(i);
|
||||||
maxIdx = i;
|
maxIdx = i;
|
||||||
}
|
}
|
||||||
output->p<T>(i, 0.f);
|
output->p<Z>(i, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
{
|
{
|
||||||
int maxIdxLocal = maxIdx;
|
int maxIdxLocal = maxIdx;
|
||||||
T currMaxLocal = currMax;
|
auto currMaxLocal = currMax;
|
||||||
|
|
||||||
for (int i = 0; i < length; i++) {
|
for (int i = 0; i < length; i++) {
|
||||||
if (currMaxLocal < input->e<T>(i)) {
|
if (currMaxLocal < input->e<X>(i)) {
|
||||||
currMaxLocal = input->e<T>(i);
|
currMaxLocal = input->e<X>(i);
|
||||||
maxIdxLocal = i;
|
maxIdxLocal = i;
|
||||||
}
|
}
|
||||||
output->p<T>(i, 0.f);
|
output->p<Z>(i, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
PRAGMA_OMP_CRITICAL
|
||||||
|
@ -76,32 +76,32 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
output->p<T>(maxIdx, 1.f);
|
output->p<Z>(maxIdx, 1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
int maxIdx = 0;
|
int maxIdx = 0;
|
||||||
T currMax = input->e<T>(0);
|
auto currMax = input->e<X>(0);
|
||||||
if (length < ELEMENT_THRESHOLD) {
|
if (length < ELEMENT_THRESHOLD) {
|
||||||
|
|
||||||
for (int i = 0; i < length; i++) {
|
for (int i = 0; i < length; i++) {
|
||||||
if (currMax < input->e<T>(i*eleStride)) {
|
if (currMax < input->e<X>(i*eleStride)) {
|
||||||
currMax = input->e<T>(i*eleStride);
|
currMax = input->e<X>(i*eleStride);
|
||||||
maxIdx = i;
|
maxIdx = i;
|
||||||
}
|
}
|
||||||
output->p<T>(i, 0.f);
|
output->p<Z>(i, 0.f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
{
|
{
|
||||||
int maxIdxLocal = maxIdx;
|
int maxIdxLocal = maxIdx;
|
||||||
T currMaxLocal = currMax;
|
auto currMaxLocal = currMax;
|
||||||
for (int i = 0; i < length; i++) {
|
for (int i = 0; i < length; i++) {
|
||||||
if (currMaxLocal < input->e<T>(i*eleStride)) {
|
if (currMaxLocal < input->e<X>(i*eleStride)) {
|
||||||
currMaxLocal = input->e<T>(i*eleStride);
|
currMaxLocal = input->e<X>(i*eleStride);
|
||||||
maxIdxLocal = i;
|
maxIdxLocal = i;
|
||||||
}
|
}
|
||||||
output->p<T>(i, 0.f);
|
output->p<Z>(i, 0.f);
|
||||||
}
|
}
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
PRAGMA_OMP_CRITICAL
|
||||||
|
@ -113,7 +113,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
output->p<T>(maxIdx, 1.f);
|
output->p<Z>(maxIdx, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -150,10 +150,10 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
||||||
|
|
||||||
for (int r = start; r < end; r++) {
|
for (int r = start; r < end; r++) {
|
||||||
if (tadEWS > 0 && zEWS > 0 && dimensionsLength == 1) {
|
if (tadEWS > 0 && zEWS > 0 && dimensionsLength == 1) {
|
||||||
T *rX = const_cast<NDArray*>(input)->bufferAsT<T>() + tadOffsets[r];
|
auto rX = const_cast<NDArray*>(input)->bufferAsT<X>() + tadOffsets[r];
|
||||||
T *rZ = output->bufferAsT<T>() + tadOffsets[r];
|
auto rZ = output->bufferAsT<Z>() + tadOffsets[r];
|
||||||
|
|
||||||
T maxValue = rX[0];
|
auto maxValue = rX[0];
|
||||||
int maxIdx = 0;
|
int maxIdx = 0;
|
||||||
if (tadEWS == 1 && zEWS == 1) {
|
if (tadEWS == 1 && zEWS == 1) {
|
||||||
for (int i = 0; i < tadLength; i++) {
|
for (int i = 0; i < tadLength; i++) {
|
||||||
|
@ -165,7 +165,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int i = 0; i < tadLength; i++) {
|
for (int i = 0; i < tadLength; i++) {
|
||||||
rZ[i] = maxIdx == i ? (T) 1.0 : (T) 0.0;
|
rZ[i] = maxIdx == i ? (Z) 1 : (Z) 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -178,7 +178,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int i = 0; i < tadLength; i++) {
|
for (int i = 0; i < tadLength; i++) {
|
||||||
rZ[i * zEWS] = maxIdx == i ? (T) 1.0 : (T) 0.0;
|
rZ[i * zEWS] = maxIdx == i ? (Z) 1 : (Z) 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -197,14 +197,14 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
||||||
Nd4jLong *xStride = shape::stride(tadShapeShapeInfo);
|
Nd4jLong *xStride = shape::stride(tadShapeShapeInfo);
|
||||||
Nd4jLong *resultStride = shape::stride(tadShapeShapeInfo);
|
Nd4jLong *resultStride = shape::stride(tadShapeShapeInfo);
|
||||||
int rank = shape::rank(tadShapeShapeInfo);
|
int rank = shape::rank(tadShapeShapeInfo);
|
||||||
T *xPointer = const_cast<NDArray*>(input)->bufferAsT<T>() + offset;
|
auto xPointer = const_cast<NDArray*>(input)->bufferAsT<X>() + offset;
|
||||||
T *resultPointer = output->bufferAsT<T>() + offset;
|
auto resultPointer = output->bufferAsT<Z>() + offset;
|
||||||
T maxValue = xPointer[0];
|
auto maxValue = xPointer[0];
|
||||||
|
|
||||||
T *maxCursor = resultPointer;
|
auto maxCursor = resultPointer;
|
||||||
Nd4jPointer maxCursorLong = reinterpret_cast<Nd4jPointer>(maxCursor);
|
Nd4jPointer maxCursorLong = reinterpret_cast<Nd4jPointer>(maxCursor);
|
||||||
|
|
||||||
if (PrepareTwoRawArrayIter<T>(rank, xShape, xPointer, xStride, resultPointer, resultStride, &rank, shapeIter, &xPointer, xStridesIter, &resultPointer, resultStridesIter) >= 0) {
|
if (PrepareTwoRawArrayIter<X,Z>(rank, xShape, xPointer, xStride, resultPointer, resultStride, &rank, shapeIter, &xPointer, xStridesIter, &resultPointer, resultStridesIter) >= 0) {
|
||||||
ND4J_RAW_ITER_START(dim, rank, coord, shapeIter);
|
ND4J_RAW_ITER_START(dim, rank, coord, shapeIter);
|
||||||
{
|
{
|
||||||
if (maxValue < xPointer[0]) {
|
if (maxValue < xPointer[0]) {
|
||||||
|
@ -212,11 +212,11 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
||||||
maxCursorLong = reinterpret_cast<Nd4jPointer>(resultPointer);
|
maxCursorLong = reinterpret_cast<Nd4jPointer>(resultPointer);
|
||||||
maxValue = xPointer[0];
|
maxValue = xPointer[0];
|
||||||
}
|
}
|
||||||
resultPointer[0] = 0.0;
|
resultPointer[0] = (Z) 0;
|
||||||
}
|
}
|
||||||
ND4J_RAW_ITER_TWO_NEXT(dim, rank, coord, shapeIter, xPointer, xStridesIter, resultPointer, resultStridesIter);
|
ND4J_RAW_ITER_TWO_NEXT(dim, rank, coord, shapeIter, xPointer, xStridesIter, resultPointer, resultStridesIter);
|
||||||
maxCursor = reinterpret_cast<T*>(maxCursorLong);
|
maxCursor = reinterpret_cast<Z*>(maxCursorLong);
|
||||||
maxCursor[0] = 1.0;
|
maxCursor[0] = (Z) 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -226,10 +226,9 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
||||||
|
|
||||||
|
|
||||||
void ismax(nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector<int>& dimensions) {
|
void ismax(nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector<int>& dimensions) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, (input, output, dimensions), LIBND4J_TYPES);
|
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), ismax_, (input, output, dimensions), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void ismax_, (const NDArray *input, NDArray *output, const std::vector<int>& dimensions), LIBND4J_TYPES);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2346,6 +2346,18 @@ TEST_F(DeclarableOpsTests1, IsMax3) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests1, IsMax4) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {6}, {0, 0, 0, 2, 2, 0});
|
||||||
|
auto z = NDArrayFactory::create<bool>('c', {6});
|
||||||
|
auto e = NDArrayFactory::create<bool>('c', {6}, {false, false, false, true, false, false});
|
||||||
|
|
||||||
|
nd4j::ops::ismax op;
|
||||||
|
auto result = op.execute({&x}, {&z}, {}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests1, CompactLaunchTests1) {
|
TEST_F(DeclarableOpsTests1, CompactLaunchTests1) {
|
||||||
|
|
|
@ -104,7 +104,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicCreation_5() {
|
public void testBasicCreation_5() {
|
||||||
val scalar = Nd4j.scalar(new Integer(1));
|
val scalar = Nd4j.scalar(Integer.valueOf(1));
|
||||||
assertNotNull(scalar);
|
assertNotNull(scalar);
|
||||||
assertEquals(0, scalar.rank());
|
assertEquals(0, scalar.rank());
|
||||||
assertEquals(1, scalar.length());
|
assertEquals(1, scalar.length());
|
||||||
|
@ -112,6 +112,56 @@ public class MixedDataTypesTests extends BaseNd4jTest {
|
||||||
assertEquals(1.0, scalar.getInt(0), 1e-5);
|
assertEquals(1.0, scalar.getInt(0), 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBasicCreation_5_0() {
|
||||||
|
val scalar = Nd4j.scalar(Long.valueOf(1));
|
||||||
|
assertNotNull(scalar);
|
||||||
|
assertEquals(0, scalar.rank());
|
||||||
|
assertEquals(1, scalar.length());
|
||||||
|
assertEquals(DataType.LONG, scalar.dataType());
|
||||||
|
assertEquals(1.0, scalar.getInt(0), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBasicCreation_5_1() {
|
||||||
|
val scalar = Nd4j.scalar(Double.valueOf(1));
|
||||||
|
assertNotNull(scalar);
|
||||||
|
assertEquals(0, scalar.rank());
|
||||||
|
assertEquals(1, scalar.length());
|
||||||
|
assertEquals(DataType.DOUBLE, scalar.dataType());
|
||||||
|
assertEquals(1.0, scalar.getDouble(0), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBasicCreation_5_2() {
|
||||||
|
val scalar = Nd4j.scalar(Float.valueOf(1));
|
||||||
|
assertNotNull(scalar);
|
||||||
|
assertEquals(0, scalar.rank());
|
||||||
|
assertEquals(1, scalar.length());
|
||||||
|
assertEquals(DataType.FLOAT, scalar.dataType());
|
||||||
|
assertEquals(1.0, scalar.getDouble(0), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBasicCreation_5_3() {
|
||||||
|
val scalar = Nd4j.scalar(Short.valueOf((short) 1));
|
||||||
|
assertNotNull(scalar);
|
||||||
|
assertEquals(0, scalar.rank());
|
||||||
|
assertEquals(1, scalar.length());
|
||||||
|
assertEquals(DataType.SHORT, scalar.dataType());
|
||||||
|
assertEquals(1.0, scalar.getDouble(0), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBasicCreation_5_4() {
|
||||||
|
val scalar = Nd4j.scalar(Byte.valueOf((byte) 1));
|
||||||
|
assertNotNull(scalar);
|
||||||
|
assertEquals(0, scalar.rank());
|
||||||
|
assertEquals(1, scalar.length());
|
||||||
|
assertEquals(DataType.BYTE, scalar.dataType());
|
||||||
|
assertEquals(1.0, scalar.getDouble(0), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicCreation_6() {
|
public void testBasicCreation_6() {
|
||||||
val scalar = Nd4j.scalar(1);
|
val scalar = Nd4j.scalar(1);
|
||||||
|
|
Loading…
Reference in New Issue