added keepDims to INDArray methods (#33)

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-06-27 17:39:07 -07:00 committed by AlexDBlack
parent 0ef373fe45
commit cc6063402e
14 changed files with 328 additions and 29 deletions

View File

@ -3879,6 +3879,19 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
}
/**
* Returns the normmax along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm1 along the specified dimension
*/
@Override
public INDArray normmax(boolean keepDims, int... dimension) {
validateNumericalArray("normmax", false);
return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension));
}
/**
* Returns the normmax along the specified dimension
*
@ -3887,8 +3900,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
*/
@Override
public INDArray normmax(int... dimension) {
validateNumericalArray("normmax", false);
return Nd4j.getExecutioner().exec(new NormMax(this, dimension));
return normmax(false, dimension);
}
/**
@ -4608,6 +4620,19 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return reshape(Nd4j.order(), shape);
}
/**
* Returns the product along a given dimension
*
* @param dimension the dimension to getScalar the product along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the product along the specified dimension
*/
@Override
public INDArray prod(boolean keepDims, int... dimension) {
validateNumericalArray("prod", false);
return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension));
}
/**
* Returns the product along a given dimension
*
@ -4616,8 +4641,20 @@ public abstract class BaseNDArray implements INDArray, Iterable {
*/
@Override
public INDArray prod(int... dimension) {
validateNumericalArray("prod", false);
return Nd4j.getExecutioner().exec(new Prod(this, dimension));
return prod(false, dimension);
}
/**
* Returns the overall mean of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray mean(boolean keepDims, int... dimension) {
validateNumericalArray("mean", false);
return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension));
}
/**
@ -4628,8 +4665,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
*/
@Override
public INDArray mean(int... dimension) {
validateNumericalArray("mean", false);
return Nd4j.getExecutioner().exec(new Mean(this, dimension));
return mean(false, dimension);
}
@Override
@ -4639,9 +4675,14 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
@Override
public INDArray mean(@NonNull INDArray result, int... dimension) {
public INDArray mean(@NonNull INDArray result, boolean keepDims, int... dimension) {
validateNumericalArray("mean", false);
return Nd4j.getExecutioner().exec(new Mean(this, result, dimension));
return Nd4j.getExecutioner().exec(new Mean(this, result, keepDims, dimension));
}
@Override
public INDArray mean(@NonNull INDArray result, int... dimension) {
return mean(result, false, dimension);
}
/**
@ -4669,6 +4710,19 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension));
}
/**
* Returns the overall max of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray max(boolean keepDims, int... dimension) {
validateNumericalArray("max", false);
return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension));
}
/**
* Returns the overall max of this ndarray
*
@ -4677,8 +4731,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
*/
@Override
public INDArray max(int... dimension) {
validateNumericalArray("max", false);
return Nd4j.getExecutioner().exec(new Max(this, dimension));
return max(false, dimension);
}
@Override
@ -4687,6 +4740,19 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new AMax(this, dimension));
}
/**
* Returns the overall min of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray min(boolean keepDims, int... dimension) {
validateNumericalArray("min", false);
return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension));
}
/**
* Returns the overall min of this ndarray
*
@ -4695,8 +4761,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
*/
@Override
public INDArray min(int... dimension) {
validateNumericalArray("min", false);
return Nd4j.getExecutioner().exec(new Min(this, dimension));
return min(false, dimension);
}
@Override
@ -4764,9 +4829,14 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
@Override
public INDArray sum(@NonNull INDArray result, int... dimension) {
public INDArray sum(@NonNull INDArray result, boolean keepDims, int... dimension) {
validateNumericalArray("sum", true);
return Nd4j.getExecutioner().exec(new Sum(this, result, dimension));
return Nd4j.getExecutioner().exec(new Sum(this, result, keepDims, dimension));
}
@Override
public INDArray sum(@NonNull INDArray result, int... dimension) {
return sum(result, false, dimension);
}
@ -4778,8 +4848,21 @@ public abstract class BaseNDArray implements INDArray, Iterable {
*/
@Override
public INDArray norm1(int... dimension) {
return norm1(false, dimension);
}
/**
* Returns the norm1 along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm1 along the specified dimension
*/
@Override
public INDArray norm1(boolean keepDims, int... dimension) {
validateNumericalArray("norm1", false);
return Nd4j.getExecutioner().exec(new Norm1(this, dimension));
return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension));
}
@ -4796,8 +4879,13 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override
public INDArray std(boolean biasCorrected, int... dimension) {
return std(biasCorrected, false, dimension);
}
@Override
public INDArray std(boolean biasCorrected, boolean keepDims, int... dimension) {
validateNumericalArray("std", false);
return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected, dimension));
return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected, keepDims, dimension));
}
@Override
@ -4806,6 +4894,19 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0);
}
/**
* Returns the norm2 along the specified dimension
*
* @param dimension the dimension to getScalar the norm2 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm2 along the specified dimension
*/
@Override
public INDArray norm2(boolean keepDims, int... dimension) {
validateNumericalArray("norm2", false);
return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension));
}
/**
* Returns the norm2 along the specified dimension
*
@ -4814,8 +4915,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
*/
@Override
public INDArray norm2(int... dimension) {
validateNumericalArray("norm2", false);
return Nd4j.getExecutioner().exec(new Norm2(this, dimension));
return norm2(false, dimension);
}

View File

@ -1250,6 +1250,58 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
return null;
}
@Override
public INDArray normmax(boolean keepDims, int... dimension) {
return null;
}
@Override
public INDArray norm2(boolean keepDims, int... dimension) {
return null;
}
@Override
public INDArray norm1(boolean keepDims, int... dimension) {
return null;
}
@Override
public INDArray std(boolean biasCorrected, boolean keepDims, int... dimension) {
return null;
}
@Override
public INDArray prod(boolean keepDims, int... dimension) {
return null;
}
@Override
public INDArray mean(boolean keepDims, int... dimension) {
return null;
}
@Override
public INDArray mean(INDArray result, boolean keepDims, int... dimension) {
return null;
}
@Override
public INDArray max(boolean keepDims, int... dimension) {
return null;
}
@Override
public INDArray min(boolean keepDims, int... dimension) {
return null;
}
@Override
public INDArray sum(INDArray result, boolean keepDims, int... dimension) {
return null;
}
@Override
public void setShapeAndStride(int[] shape, int[] stride) {

View File

@ -1518,6 +1518,16 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray normmax(int... dimension);
/**
* Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s)
*
* @param dimension the dimension to the max norm along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return Max norm along the specified dimension
*/
INDArray normmax(boolean keepDims, int... dimension);
/**
* Return the max norm (aka infinity norm, equal to the maximum absolute value) for the entire array
*
@ -1533,6 +1543,15 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray norm2(int... dimension);
/**
* Returns the norm2 (L2 norm, sqrt(sum(x_i^2), also known as Euclidean norm) along the specified dimension(s)
*
* @param dimension the dimension to getScalar the norm2 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm2 along the specified dimension
*/
INDArray norm2(boolean keepDims, int... dimension);
/**
* Return the norm2 (L2 norm, sqrt(sum(x_i^2), also known as Euclidean norm) for the entire array
*
@ -1549,6 +1568,16 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray norm1(int... dimension);
/**
* Returns the norm1 (L1 norm, i.e., sum of absolute values; also known as Taxicab or Manhattan norm) along the
* specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm1 along the specified dimension
*/
INDArray norm1(boolean keepDims, int... dimension);
/**
* Calculate and return norm1 (L1 norm, i.e., sum of absolute values; also known as Taxicab or Manhattan norm) for
* the entire array
@ -1580,6 +1609,15 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray std(boolean biasCorrected, int... dimension);
/**
* Standard deviation of an ndarray along a dimension
*
* @param dimension the dimension to getScalar the std along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the standard deviation along a particular dimension
*/
INDArray std(boolean biasCorrected, boolean keepDims, int... dimension);
/**
* Calculate the standard deviation for the entire array, specifying whether it is bias corrected or not
*
@ -1596,6 +1634,15 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray prod(int... dimension);
/**
* Returns the product along a given dimension
*
* @param dimension the dimension to getScalar the product along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the product along the specified dimension
*/
INDArray prod(boolean keepDims, int... dimension);
/**
* Calculate the product of all values in the array
*
@ -1619,11 +1666,29 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray mean(INDArray result, int... dimension);
/**
* Returns the overall mean of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
INDArray mean(boolean keepDims, int... dimension);
/**
* Returns the overall mean of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
INDArray mean(INDArray result, boolean keepDims, int... dimension);
/**
* Returns the absolute overall mean of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
* @return the absolute mean along the specified dimension of this ndarray
*/
INDArray amean(int... dimension);
@ -1644,8 +1709,8 @@ public interface INDArray extends Serializable, AutoCloseable {
/**
* Returns the overall variance of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
* @param dimension the dimension to getScalar the variance along
* @return the variance along the specified dimension of this ndarray
*/
INDArray var(int... dimension);
@ -1653,8 +1718,8 @@ public interface INDArray extends Serializable, AutoCloseable {
* Returns the overall variance of this ndarray
*
* @param biasCorrected boolean on whether to apply corrected bias
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
* @param dimension the dimension to getScalar the variance along
* @return the variance along the specified dimension of this ndarray
*/
INDArray var(boolean biasCorrected, int... dimension);
@ -1668,16 +1733,25 @@ public interface INDArray extends Serializable, AutoCloseable {
/**
* Returns the overall max of this ndarray along given dimensions
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
* @param dimension the dimension to getScalar the max along
* @return the max along the specified dimension of this ndarray
*/
INDArray max(int... dimension);
/**
* Returns the overall max of this ndarray along given dimensions
*
* @param dimension the dimension to getScalar the max along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the max along the specified dimension of this ndarray
*/
INDArray max(boolean keepDims, int... dimension);
/**
* Returns the absolute overall max of this ndarray along given dimensions
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
* @param dimension the dimension to getScalar the amax along
* @return the amax along the specified dimension of this ndarray
*/
INDArray amax(int... dimension);
@ -1696,11 +1770,20 @@ public interface INDArray extends Serializable, AutoCloseable {
/**
* Returns the overall min of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
* @param dimension the dimension to getScalar the min along
* @return the min along the specified dimension of this ndarray
*/
INDArray min(int... dimension);
/**
* Returns the overall min of this ndarray
*
* @param dimension the dimension to getScalar the min along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the min along the specified dimension of this ndarray
*/
INDArray min(boolean keepDims, int... dimension);
/**
* Returns minimum (absolute) value in this INDArray, along the specified dimensions
*
@ -1729,6 +1812,13 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray sum(int... dimension);
/**
* Returns the sum along the last dimension of this ndarray
*
* @param dimension the dimension to getScalar the sum along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the sum along the specified dimension of this ndarray
*/
INDArray sum(boolean keepDims, int... dimension);
/**
@ -1748,6 +1838,16 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray sum(INDArray result, int... dimension);
/**
* Returns the sum along the last dimension of this ndarray
*
* @param result result of this operation will be stored here
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @param dimension the dimension to getScalar the sum along
* @return the sum along the specified dimension of this ndarray
*/
INDArray sum(INDArray result, boolean keepDims, int... dimension);
/**
* Sum the entire array
* @return Sum of array

View File

@ -69,6 +69,11 @@ public abstract class BaseReduceFloatOp extends BaseReduceOp implements ReduceFl
}
public BaseReduceFloatOp(INDArray x, boolean keepDims, int... dimensions) {
super(x, keepDims, dimensions);
}
public BaseReduceFloatOp(INDArray x, int... dimensions) {
super(x, dimensions);
}

View File

@ -144,6 +144,11 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
this(x, null, dimensions);
}
public BaseReduceOp(INDArray x, boolean keepDims, int... dimensions) {
this(x, null, dimensions);
this.keepDims = keepDims;
}
public BaseReduceOp(INDArray x, INDArray y, int... dimensions) {
this(x, y, null, dimensions);
}

View File

@ -58,6 +58,10 @@ public abstract class BaseReduceSameOp extends BaseReduceOp implements ReduceSam
super(x, dimensions);
}
public BaseReduceSameOp(INDArray x, boolean keepDims, int... dimensions) {
super(x, keepDims, dimensions);
}
protected BaseReduceSameOp() {
super();
}

View File

@ -45,6 +45,10 @@ public class Mean extends BaseReduceFloatOp {
super(x, dimensions);
}
public Mean(INDArray x, boolean keepDims, int... dimensions) {
super(x, keepDims, dimensions);
}
public Mean(INDArray x, INDArray z, boolean keepDims, int... dimensions) {
super(x, z, keepDims, dimensions);
}

View File

@ -47,6 +47,10 @@ public class Norm1 extends BaseReduceFloatOp {
super(x, dimensions);
}
public Norm1(INDArray x, boolean keepDims, int... dimensions) {
super(x, keepDims, dimensions);
}
@Override
public INDArray noOp() {
return Transforms.abs(x());

View File

@ -47,6 +47,10 @@ public class Norm2 extends BaseReduceFloatOp {
super(x, dimensions);
}
public Norm2(INDArray x, boolean keepDims, int... dimensions) {
super(x, keepDims, dimensions);
}
@Override
public INDArray noOp() {
return Transforms.abs(x());

View File

@ -52,6 +52,10 @@ public class NormMax extends BaseReduceFloatOp {
super(x, dimensions);
}
public NormMax(INDArray x, boolean keepDims, int... dimensions) {
super(x, keepDims, dimensions);
}
@Override
public INDArray noOp() {
return Transforms.abs(x());

View File

@ -54,6 +54,9 @@ public class Max extends BaseReduceSameOp {
public Max(INDArray x, int... axis) {
super(x, null, null, axis);
}
public Max(INDArray x, boolean keepDims, int... axis) {
super(x, keepDims, axis);
}
public Max(INDArray x, INDArray z, int... axis) {
super(x, null, z, axis);

View File

@ -41,6 +41,10 @@ public class Min extends BaseReduceSameOp {
super(x, dimensions);
}
public Min(INDArray x, boolean keepDims, int... dimensions) {
super(x, keepDims, dimensions);
}
public Min(INDArray x, INDArray z, int... dimensions) {
super(x, null, z, dimensions);
}

View File

@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce.same;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.BaseNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseReduceSameOp;
@ -53,6 +54,10 @@ public class Prod extends BaseReduceSameOp {
super(x, z, keepDims, dimensions);
}
public Prod(INDArray x, boolean keepDims, int... dimensions) {
super(x, keepDims, dimensions);
}
@Override
public int opNum() {

View File

@ -39,6 +39,11 @@ public class StandardDeviation extends Variance {
super(sameDiff, i_v, biasCorrected, keepDims, dimensions );
}
public StandardDeviation(INDArray x, boolean biasCorrected, boolean keepDims, int... dimension) {
super(x, biasCorrected, dimension);
this.keepDims = keepDims;
}
public StandardDeviation(INDArray x, boolean biasCorrected, int... dimension) {
super(x, biasCorrected, dimension);
}