reduce bool changes (#118)

* reduce bool changes

Signed-off-by: raver119 <raver119@gmail.com>

* reduce bool tweaks

Signed-off-by: raver119 <raver119@gmail.com>

branch master
parent 425c747330
commit a5f5ac72b1
@@ -237,6 +237,8 @@ template <typename X, typename Z>
 template<typename OpType>
 __host__ void ReduceBoolFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
+
+    nd4j_printf("Step A%i\n", -1);
 if(shape::isEmpty(hXShapeInfo)) {
 if(shape::isEmpty(hZShapeInfo))
@@ -251,7 +253,8 @@ __host__ void ReduceBoolFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStrea
 auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer();

 // scalar assign
-functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr);
+functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hZShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr);
+nd4j::DebugHelper::checkErrorCode(stream, "reduceBoolDim empty(...) failed");
 }
 else {
 simpleReduce<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets);
@@ -274,6 +277,9 @@ __host__ void ReduceBoolFunction<X,Z>::intermediateScalar(dim3 launchDims, cudaS
 auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream);
 if (res != 0)
 throw nd4j::cuda_exception::build("ReduceBoolFunction<X,Z>::intermediateScalar: failed to copy resulting scalar", res);
+
+nd4j::DebugHelper::checkErrorCode(stream, "reduceBoolScalar empty(...) failed");
+
 }
 else {
 simpleScalar<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo);
@@ -249,7 +249,7 @@ __host__ void ReduceFloatFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStre
 auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer();

 // scalar assign
-functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShape, hXShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr);
+functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShape, hZShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr);
 }
 else {
 simpleReduce<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShape, extraParams, z, zShape, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets);
@@ -102,4 +102,7 @@ public abstract class BaseReduceBoolOp extends BaseReduceOp implements ReduceBoo
 "with 2 inputs, second input (axis) must be an integer datatype for %s, got %s", getClass(), dataTypes);
 return Collections.singletonList(DataType.BOOL);
 }

+public abstract boolean emptyValue();
+
 }
@@ -41,6 +41,10 @@ public class All extends BaseReduceBoolOp {
 super(x);
 }

+public All(INDArray x, int... axis) {
+    super(x, axis);
+}
+
 @Override
 public int opNum() {
 return 1;
@@ -65,4 +69,9 @@ public class All extends BaseReduceBoolOp {
 public String tensorflowName() {
 return "All";
 }

+@Override
+public boolean emptyValue() {
+    return true;
+}
 }
@@ -65,4 +65,9 @@ public class Any extends BaseReduceBoolOp {
 public String tensorflowName() {
 return "Any";
 }

+@Override
+public boolean emptyValue() {
+    return false;
+}
 }
@@ -71,4 +71,8 @@ public class IsInf extends BaseReduceBoolOp {
 return Collections.singletonList(f().zerosLike(arg()));
 }

+@Override
+public boolean emptyValue() {
+    return false;
+}
 }
@@ -71,4 +71,8 @@ public class IsNaN extends BaseReduceBoolOp {
 return Collections.singletonList(f().zerosLike(arg()));
 }

+@Override
+public boolean emptyValue() {
+    return false;
+}
 }
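
The values above follow the usual identity-element convention for boolean folds: AND over zero elements is vacuously true, OR over zero elements is false. A small standalone Java illustration of that reasoning (plain JDK streams, no ND4J types; the class name is only for this sketch and is not part of the commit):

    import java.util.stream.Stream;

    public class EmptyReductionIdentities {
        public static void main(String[] args) {
            Double[] empty = new Double[0];

            // All-style reduction: AND over no elements yields its identity, true.
            boolean all = Stream.of(empty).allMatch(v -> v > 0);   // true

            // Any-style reduction: OR over no elements yields its identity, false.
            boolean any = Stream.of(empty).anyMatch(v -> v > 0);   // false

            System.out.println("all over empty = " + all);
            System.out.println("any over empty = " + any);
        }
    }

IsInf and IsNaN declare false for empty inputs as well, i.e. they pick the Any-style side of this convention.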
@@ -935,6 +935,18 @@ public class CudaExecutioner extends DefaultOpExecutioner {
 }
 }

+// FIXME: this should be moved down to C++ on per-op basis
+// reduce to scalar case, ReduceBool ops require special treatment
+if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) {
+    if (op.z() == null) {
+        op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue()));
+    } else {
+        op.z().assign(((BaseReduceBoolOp) op).emptyValue());
+    }
+
+    return context;
+}
+
 long st = profilingConfigurableHookIn(op);

 checkForCompression(op);
@@ -994,9 +1006,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
 }
 }

-if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
-return null;
-}
+//if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
+// return null;
+//}

 val dataType = op.resultType();
@@ -265,7 +265,18 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
 }
 }

+// FIXME: this should be moved down to C++ on per-op basis
 val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector());
+
+// reduce to scalar case, ReduceBool ops require special treatment
+if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) {
+    if (op.z() == null) {
+        op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue()));
+    } else {
+        op.z().assign(((BaseReduceBoolOp) op).emptyValue());
+    }
+
+    return op.z();
+}

 //validateDataType(Nd4j.dataType(), op);
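
Both executioners now take the same shortcut: when the op is a boolean reduction, its input is empty, and the requested axes mean a full reduction to a scalar (no axes, or the single Integer.MAX_VALUE sentinel), the result is simply the op's emptyValue() and no kernel is launched. A condensed standalone sketch of that decision logic, using hypothetical helper names (isFullReduction, emptyReduceResult) rather than the real executioner methods:

    import java.util.Arrays;

    public class EmptyReduceShortcut {

        // Mirrors the dimension check in the diff: a reduction collapses to a scalar
        // when no axes are given, or the single Integer.MAX_VALUE sentinel is passed.
        static boolean isFullReduction(int[] dimension) {
            return dimension == null
                    || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE);
        }

        // Hypothetical helper: the scalar an executioner can hand back for an empty
        // input, given the op's declared emptyValue().
        static boolean emptyReduceResult(boolean inputIsEmpty, int[] dimension, boolean emptyValue) {
            if (inputIsEmpty && isFullReduction(dimension))
                return emptyValue;   // e.g. true for All, false for Any/IsInf/IsNaN
            throw new IllegalStateException("not an empty full reduction: "
                    + Arrays.toString(dimension));
        }

        public static void main(String[] args) {
            System.out.println(emptyReduceResult(true, null, true));                          // All -> true
            System.out.println(emptyReduceResult(true, new int[]{Integer.MAX_VALUE}, false)); // Any -> false
        }
    }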
@@ -8134,6 +8134,36 @@ public class Nd4jTestsC extends BaseNd4jTest {
 assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6), hStack);
 }

+@Test
+public void testReduceAll_1() {
+    val x = Nd4j.empty(DataType.FLOAT);
+    val e = Nd4j.scalar(true);
+    val z = Nd4j.exec(new All(x));
+
+    assertEquals(e, z);
+}
+
+@Test
+public void testReduceAll_2() {
+    val x = Nd4j.ones(DataType.FLOAT, 0);
+    val e = Nd4j.scalar(true);
+    val z = Nd4j.exec(new All(x));
+
+    assertEquals(e, z);
+}
+
+@Test
+public void testReduceAll_3() {
+    val x = Nd4j.create(DataType.FLOAT, 0);
+    assertEquals(1, x.rank());
+
+    val e = Nd4j.scalar(true);
+    val z = Nd4j.exec(new All(x, 0));
+
+    assertEquals(e, z);
+}
+
 @Override
 public char ordering() {
 return 'c';
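
As a usage note, the new tests only cover All; a counterpart for Any would be expected to assert the opposite scalar, since Any.emptyValue() returns false. A hypothetical sketch in the same style (not part of this commit; it assumes Nd4jTestsC's existing imports and that Any exposes the same single-array constructor as All):

    @Test
    public void testReduceAny_empty() {
        // Hypothetical counterpart to testReduceAll_1: Any over an empty input
        // should collapse to Any.emptyValue(), i.e. a false scalar.
        val x = Nd4j.empty(DataType.FLOAT);
        val e = Nd4j.scalar(false);
        val z = Nd4j.exec(new Any(x));

        assertEquals(e, z);
    }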