reduce bool changes (#118)

* reduce bool changes

Signed-off-by: raver119 <raver119@gmail.com>

* reduce bool tweaks

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-12-09 20:08:59 +03:00 committed by GitHub
parent 425c747330
commit a5f5ac72b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 89 additions and 5 deletions

View File

@ -237,6 +237,8 @@ template <typename X, typename Z>
template<typename OpType>
__host__ void ReduceBoolFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
nd4j_printf("Step A%i\n", -1);
if(shape::isEmpty(hXShapeInfo)) {
if(shape::isEmpty(hZShapeInfo))
@ -251,7 +253,8 @@ __host__ void ReduceBoolFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStrea
auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer();
// scalar assign
functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr);
functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hZShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr);
nd4j::DebugHelper::checkErrorCode(stream, "reduceBoolDim empty(...) failed");
}
else {
simpleReduce<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets);
@ -274,6 +277,9 @@ __host__ void ReduceBoolFunction<X,Z>::intermediateScalar(dim3 launchDims, cudaS
auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream);
if (res != 0)
throw nd4j::cuda_exception::build("ReduceBoolFunction<X,Z>::intermediateScalar: failed to copy resulting scalar", res);
nd4j::DebugHelper::checkErrorCode(stream, "reduceBoolScalar empty(...) failed");
}
else {
simpleScalar<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo);

View File

@ -249,7 +249,7 @@ __host__ void ReduceFloatFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStre
auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer();
// scalar assign
functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShape, hXShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr);
functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShape, hZShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr);
}
else {
simpleReduce<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShape, extraParams, z, zShape, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets);

View File

@ -102,4 +102,7 @@ public abstract class BaseReduceBoolOp extends BaseReduceOp implements ReduceBoo
"with 2 inputs, second input (axis) must be an integer datatype for %s, got %s", getClass(), dataTypes);
return Collections.singletonList(DataType.BOOL);
}
public abstract boolean emptyValue();
}

View File

@ -41,6 +41,10 @@ public class All extends BaseReduceBoolOp {
super(x);
}
public All(INDArray x, int... axis) {
super(x, axis);
}
@Override
public int opNum() {
return 1;
@ -65,4 +69,9 @@ public class All extends BaseReduceBoolOp {
public String tensorflowName() {
return "All";
}
@Override
public boolean emptyValue() {
return true;
}
}

View File

@ -65,4 +65,9 @@ public class Any extends BaseReduceBoolOp {
public String tensorflowName() {
return "Any";
}
@Override
public boolean emptyValue() {
return false;
}
}

View File

@ -71,4 +71,8 @@ public class IsInf extends BaseReduceBoolOp {
return Collections.singletonList(f().zerosLike(arg()));
}
@Override
public boolean emptyValue() {
return false;
}
}

View File

@ -71,4 +71,8 @@ public class IsNaN extends BaseReduceBoolOp {
return Collections.singletonList(f().zerosLike(arg()));
}
@Override
public boolean emptyValue() {
return false;
}
}

View File

@ -935,6 +935,18 @@ public class CudaExecutioner extends DefaultOpExecutioner {
}
}
// FIXME: this should be moved down to C++ on per-op basis
// reduce to scalar case, ReduceBool ops require special treatment
if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) {
if (op.z() == null) {
op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue()));
} else {
op.z().assign(((BaseReduceBoolOp) op).emptyValue());
}
return context;
}
long st = profilingConfigurableHookIn(op);
checkForCompression(op);
@ -994,9 +1006,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
}
}
if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
return null;
}
//if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
// return null;
//}
val dataType = op.resultType();

View File

@ -265,7 +265,18 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
}
}
// FIXME: this should be moved down to C++ on per-op basis
val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector());
// reduce to scalar case, ReduceBool ops require special treatment
if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) {
if (op.z() == null) {
op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue()));
} else {
op.z().assign(((BaseReduceBoolOp) op).emptyValue());
}
return op.z();
}
//validateDataType(Nd4j.dataType(), op);

View File

@ -8134,6 +8134,36 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6), hStack);
}
@Test
public void testReduceAll_1() {
val x = Nd4j.empty(DataType.FLOAT);
val e = Nd4j.scalar(true);
val z = Nd4j.exec(new All(x));
assertEquals(e, z);
}
@Test
public void testReduceAll_2() {
val x = Nd4j.ones(DataType.FLOAT, 0);
val e = Nd4j.scalar(true);
val z = Nd4j.exec(new All(x));
assertEquals(e, z);
}
@Test
public void testReduceAll_3() {
val x = Nd4j.create(DataType.FLOAT, 0);
assertEquals(1, x.rank());
val e = Nd4j.scalar(true);
val z = Nd4j.exec(new All(x, 0));
assertEquals(e, z);
}
@Override
public char ordering() {
return 'c';