From cc6063402ebac85ae9ca70d9b5b141df7de2376f Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 27 Jun 2019 17:39:07 -0700 Subject: [PATCH] added keepDims to INDArray methods (#33) Signed-off-by: Ryan Nett --- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 136 +++++++++++++++--- .../linalg/api/ndarray/BaseSparseNDArray.java | 52 +++++++ .../org/nd4j/linalg/api/ndarray/INDArray.java | 122 ++++++++++++++-- .../linalg/api/ops/BaseReduceFloatOp.java | 5 + .../org/nd4j/linalg/api/ops/BaseReduceOp.java | 5 + .../nd4j/linalg/api/ops/BaseReduceSameOp.java | 4 + .../api/ops/impl/reduce/floating/Mean.java | 4 + .../api/ops/impl/reduce/floating/Norm1.java | 4 + .../api/ops/impl/reduce/floating/Norm2.java | 4 + .../api/ops/impl/reduce/floating/NormMax.java | 4 + .../linalg/api/ops/impl/reduce/same/Max.java | 3 + .../linalg/api/ops/impl/reduce/same/Min.java | 4 + .../linalg/api/ops/impl/reduce/same/Prod.java | 5 + .../impl/summarystats/StandardDeviation.java | 5 + 14 files changed, 328 insertions(+), 29 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 8924aa2bf..44cf0c233 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -3879,6 +3879,19 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } + /** + * Returns the normmax along the specified dimension + * + * @param dimension the dimension to getScalar the norm1 along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the norm1 along the specified dimension + */ + @Override + public INDArray normmax(boolean keepDims, int... dimension) { + validateNumericalArray("normmax", false); + return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension)); + } + /** * Returns the normmax along the specified dimension * @@ -3887,8 +3900,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public INDArray normmax(int... dimension) { - validateNumericalArray("normmax", false); - return Nd4j.getExecutioner().exec(new NormMax(this, dimension)); + return normmax(false, dimension); } /** @@ -4608,6 +4620,19 @@ public abstract class BaseNDArray implements INDArray, Iterable { return reshape(Nd4j.order(), shape); } + /** + * Returns the product along a given dimension + * + * @param dimension the dimension to getScalar the product along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the product along the specified dimension + */ + @Override + public INDArray prod(boolean keepDims, int... dimension) { + validateNumericalArray("prod", false); + return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension)); + } + /** * Returns the product along a given dimension * @@ -4616,8 +4641,20 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public INDArray prod(int... dimension) { - validateNumericalArray("prod", false); - return Nd4j.getExecutioner().exec(new Prod(this, dimension)); + return prod(false, dimension); + } + + /** + * Returns the overall mean of this ndarray + * + * @param dimension the dimension to getScalar the mean along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the mean along the specified dimension of this ndarray + */ + @Override + public INDArray mean(boolean keepDims, int... dimension) { + validateNumericalArray("mean", false); + return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension)); } /** @@ -4628,8 +4665,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public INDArray mean(int... dimension) { - validateNumericalArray("mean", false); - return Nd4j.getExecutioner().exec(new Mean(this, dimension)); + return mean(false, dimension); } @Override @@ -4639,9 +4675,14 @@ public abstract class BaseNDArray implements INDArray, Iterable { } @Override - public INDArray mean(@NonNull INDArray result, int... dimension) { + public INDArray mean(@NonNull INDArray result, boolean keepDims, int... dimension) { validateNumericalArray("mean", false); - return Nd4j.getExecutioner().exec(new Mean(this, result, dimension)); + return Nd4j.getExecutioner().exec(new Mean(this, result, keepDims, dimension)); + } + + @Override + public INDArray mean(@NonNull INDArray result, int... dimension) { + return mean(result, false, dimension); } /** @@ -4669,6 +4710,19 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension)); } + /** + * Returns the overall max of this ndarray + * + * @param dimension the dimension to getScalar the mean along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the mean along the specified dimension of this ndarray + */ + @Override + public INDArray max(boolean keepDims, int... dimension) { + validateNumericalArray("max", false); + return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension)); + } + /** * Returns the overall max of this ndarray * @@ -4677,8 +4731,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public INDArray max(int... dimension) { - validateNumericalArray("max", false); - return Nd4j.getExecutioner().exec(new Max(this, dimension)); + return max(false, dimension); } @Override @@ -4687,6 +4740,19 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().exec(new AMax(this, dimension)); } + /** + * Returns the overall min of this ndarray + * + * @param dimension the dimension to getScalar the mean along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the mean along the specified dimension of this ndarray + */ + @Override + public INDArray min(boolean keepDims, int... dimension) { + validateNumericalArray("min", false); + return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension)); + } + /** * Returns the overall min of this ndarray * @@ -4695,8 +4761,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public INDArray min(int... dimension) { - validateNumericalArray("min", false); - return Nd4j.getExecutioner().exec(new Min(this, dimension)); + return min(false, dimension); } @Override @@ -4764,9 +4829,14 @@ public abstract class BaseNDArray implements INDArray, Iterable { } @Override - public INDArray sum(@NonNull INDArray result, int... dimension) { + public INDArray sum(@NonNull INDArray result, boolean keepDims, int... dimension) { validateNumericalArray("sum", true); - return Nd4j.getExecutioner().exec(new Sum(this, result, dimension)); + return Nd4j.getExecutioner().exec(new Sum(this, result, keepDims, dimension)); + } + + @Override + public INDArray sum(@NonNull INDArray result, int... dimension) { + return sum(result, false, dimension); } @@ -4778,8 +4848,21 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public INDArray norm1(int... dimension) { + return norm1(false, dimension); + } + + + /** + * Returns the norm1 along the specified dimension + * + * @param dimension the dimension to getScalar the norm1 along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the norm1 along the specified dimension + */ + @Override + public INDArray norm1(boolean keepDims, int... dimension) { validateNumericalArray("norm1", false); - return Nd4j.getExecutioner().exec(new Norm1(this, dimension)); + return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension)); } @@ -4796,8 +4879,13 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public INDArray std(boolean biasCorrected, int... dimension) { + return std(biasCorrected, false, dimension); + } + + @Override + public INDArray std(boolean biasCorrected, boolean keepDims, int... dimension) { validateNumericalArray("std", false); - return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected, dimension)); + return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected, keepDims, dimension)); } @Override @@ -4806,6 +4894,19 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0); } + /** + * Returns the norm2 along the specified dimension + * + * @param dimension the dimension to getScalar the norm2 along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the norm2 along the specified dimension + */ + @Override + public INDArray norm2(boolean keepDims, int... dimension) { + validateNumericalArray("norm2", false); + return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension)); + } + /** * Returns the norm2 along the specified dimension * @@ -4814,8 +4915,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public INDArray norm2(int... dimension) { - validateNumericalArray("norm2", false); - return Nd4j.getExecutioner().exec(new Norm2(this, dimension)); + return norm2(false, dimension); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java index 10a3fc5c7..8dadaaa82 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java @@ -1250,6 +1250,58 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { return null; } + + + @Override + public INDArray normmax(boolean keepDims, int... dimension) { + return null; + } + + @Override + public INDArray norm2(boolean keepDims, int... dimension) { + return null; + } + + @Override + public INDArray norm1(boolean keepDims, int... dimension) { + return null; + } + + @Override + public INDArray std(boolean biasCorrected, boolean keepDims, int... dimension) { + return null; + } + + @Override + public INDArray prod(boolean keepDims, int... dimension) { + return null; + } + + @Override + public INDArray mean(boolean keepDims, int... dimension) { + return null; + } + + @Override + public INDArray mean(INDArray result, boolean keepDims, int... dimension) { + return null; + } + + @Override + public INDArray max(boolean keepDims, int... dimension) { + return null; + } + + @Override + public INDArray min(boolean keepDims, int... dimension) { + return null; + } + + @Override + public INDArray sum(INDArray result, boolean keepDims, int... dimension) { + return null; + } + @Override public void setShapeAndStride(int[] shape, int[] stride) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index a6a307e52..351b4659e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -1518,6 +1518,16 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray normmax(int... dimension); + + /** + * Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s) + * + * @param dimension the dimension to the max norm along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return Max norm along the specified dimension + */ + INDArray normmax(boolean keepDims, int... dimension); + /** * Return the max norm (aka infinity norm, equal to the maximum absolute value) for the entire array * @@ -1533,6 +1543,15 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray norm2(int... dimension); + /** + * Returns the norm2 (L2 norm, sqrt(sum(x_i^2), also known as Euclidean norm) along the specified dimension(s) + * + * @param dimension the dimension to getScalar the norm2 along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the norm2 along the specified dimension + */ + INDArray norm2(boolean keepDims, int... dimension); + /** * Return the norm2 (L2 norm, sqrt(sum(x_i^2), also known as Euclidean norm) for the entire array * @@ -1549,6 +1568,16 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray norm1(int... dimension); + /** + * Returns the norm1 (L1 norm, i.e., sum of absolute values; also known as Taxicab or Manhattan norm) along the + * specified dimension + * + * @param dimension the dimension to getScalar the norm1 along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the norm1 along the specified dimension + */ + INDArray norm1(boolean keepDims, int... dimension); + /** * Calculate and return norm1 (L1 norm, i.e., sum of absolute values; also known as Taxicab or Manhattan norm) for * the entire array @@ -1580,6 +1609,15 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray std(boolean biasCorrected, int... dimension); + /** + * Standard deviation of an ndarray along a dimension + * + * @param dimension the dimension to getScalar the std along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the standard deviation along a particular dimension + */ + INDArray std(boolean biasCorrected, boolean keepDims, int... dimension); + /** * Calculate the standard deviation for the entire array, specifying whether it is bias corrected or not * @@ -1596,6 +1634,15 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray prod(int... dimension); + /** + * Returns the product along a given dimension + * + * @param dimension the dimension to getScalar the product along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the product along the specified dimension + */ + INDArray prod(boolean keepDims, int... dimension); + /** * Calculate the product of all values in the array * @@ -1619,11 +1666,29 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray mean(INDArray result, int... dimension); + /** + * Returns the overall mean of this ndarray + * + * @param dimension the dimension to getScalar the mean along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the mean along the specified dimension of this ndarray + */ + INDArray mean(boolean keepDims, int... dimension); + + /** + * Returns the overall mean of this ndarray + * + * @param dimension the dimension to getScalar the mean along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the mean along the specified dimension of this ndarray + */ + INDArray mean(INDArray result, boolean keepDims, int... dimension); + /** * Returns the absolute overall mean of this ndarray * * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray + * @return the absolute mean along the specified dimension of this ndarray */ INDArray amean(int... dimension); @@ -1644,8 +1709,8 @@ public interface INDArray extends Serializable, AutoCloseable { /** * Returns the overall variance of this ndarray * - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray + * @param dimension the dimension to getScalar the variance along + * @return the variance along the specified dimension of this ndarray */ INDArray var(int... dimension); @@ -1653,8 +1718,8 @@ public interface INDArray extends Serializable, AutoCloseable { * Returns the overall variance of this ndarray * * @param biasCorrected boolean on whether to apply corrected bias - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray + * @param dimension the dimension to getScalar the variance along + * @return the variance along the specified dimension of this ndarray */ INDArray var(boolean biasCorrected, int... dimension); @@ -1668,16 +1733,25 @@ public interface INDArray extends Serializable, AutoCloseable { /** * Returns the overall max of this ndarray along given dimensions * - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray + * @param dimension the dimension to getScalar the max along + * @return the max along the specified dimension of this ndarray */ INDArray max(int... dimension); + /** + * Returns the overall max of this ndarray along given dimensions + * + * @param dimension the dimension to getScalar the max along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the max along the specified dimension of this ndarray + */ + INDArray max(boolean keepDims, int... dimension); + /** * Returns the absolute overall max of this ndarray along given dimensions * - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray + * @param dimension the dimension to getScalar the amax along + * @return the amax along the specified dimension of this ndarray */ INDArray amax(int... dimension); @@ -1696,11 +1770,20 @@ public interface INDArray extends Serializable, AutoCloseable { /** * Returns the overall min of this ndarray * - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray + * @param dimension the dimension to getScalar the min along + * @return the min along the specified dimension of this ndarray */ INDArray min(int... dimension); + /** + * Returns the overall min of this ndarray + * + * @param dimension the dimension to getScalar the min along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the min along the specified dimension of this ndarray + */ + INDArray min(boolean keepDims, int... dimension); + /** * Returns minimum (absolute) value in this INDArray, along the specified dimensions * @@ -1729,6 +1812,13 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray sum(int... dimension); + /** + * Returns the sum along the last dimension of this ndarray + * + * @param dimension the dimension to getScalar the sum along + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @return the sum along the specified dimension of this ndarray + */ INDArray sum(boolean keepDims, int... dimension); /** @@ -1748,6 +1838,16 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray sum(INDArray result, int... dimension); + /** + * Returns the sum along the last dimension of this ndarray + * + * @param result result of this operation will be stored here + * @param keepDims whether to keep reduced dimensions as dimensions of size 1 + * @param dimension the dimension to getScalar the sum along + * @return the sum along the specified dimension of this ndarray + */ + INDArray sum(INDArray result, boolean keepDims, int... dimension); + /** * Sum the entire array * @return Sum of array diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java index fbebe0c9e..7158d9001 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java @@ -69,6 +69,11 @@ public abstract class BaseReduceFloatOp extends BaseReduceOp implements ReduceFl } + public BaseReduceFloatOp(INDArray x, boolean keepDims, int... dimensions) { + super(x, keepDims, dimensions); + } + + public BaseReduceFloatOp(INDArray x, int... dimensions) { super(x, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java index 039e91b53..25a04125d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java @@ -144,6 +144,11 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { this(x, null, dimensions); } + public BaseReduceOp(INDArray x, boolean keepDims, int... dimensions) { + this(x, null, dimensions); + this.keepDims = keepDims; + } + public BaseReduceOp(INDArray x, INDArray y, int... dimensions) { this(x, y, null, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java index ab6c8b377..41d4a2446 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java @@ -58,6 +58,10 @@ public abstract class BaseReduceSameOp extends BaseReduceOp implements ReduceSam super(x, dimensions); } + public BaseReduceSameOp(INDArray x, boolean keepDims, int... dimensions) { + super(x, keepDims, dimensions); + } + protected BaseReduceSameOp() { super(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java index 853ccae24..bf15f94d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java @@ -45,6 +45,10 @@ public class Mean extends BaseReduceFloatOp { super(x, dimensions); } + public Mean(INDArray x, boolean keepDims, int... dimensions) { + super(x, keepDims, dimensions); + } + public Mean(INDArray x, INDArray z, boolean keepDims, int... dimensions) { super(x, z, keepDims, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java index b0c9a0a14..a2ba88927 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java @@ -47,6 +47,10 @@ public class Norm1 extends BaseReduceFloatOp { super(x, dimensions); } + public Norm1(INDArray x, boolean keepDims, int... dimensions) { + super(x, keepDims, dimensions); + } + @Override public INDArray noOp() { return Transforms.abs(x()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java index f498e5277..00bd2dbd1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java @@ -47,6 +47,10 @@ public class Norm2 extends BaseReduceFloatOp { super(x, dimensions); } + public Norm2(INDArray x, boolean keepDims, int... dimensions) { + super(x, keepDims, dimensions); + } + @Override public INDArray noOp() { return Transforms.abs(x()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java index d5886e1f5..ece542857 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java @@ -52,6 +52,10 @@ public class NormMax extends BaseReduceFloatOp { super(x, dimensions); } + public NormMax(INDArray x, boolean keepDims, int... dimensions) { + super(x, keepDims, dimensions); + } + @Override public INDArray noOp() { return Transforms.abs(x()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java index 509a1fa4e..a29384a42 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java @@ -54,6 +54,9 @@ public class Max extends BaseReduceSameOp { public Max(INDArray x, int... axis) { super(x, null, null, axis); } + public Max(INDArray x, boolean keepDims, int... axis) { + super(x, keepDims, axis); + } public Max(INDArray x, INDArray z, int... axis) { super(x, null, z, axis); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java index 8169b0438..99c1e038b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java @@ -41,6 +41,10 @@ public class Min extends BaseReduceSameOp { super(x, dimensions); } + public Min(INDArray x, boolean keepDims, int... dimensions) { + super(x, keepDims, dimensions); + } + public Min(INDArray x, INDArray z, int... dimensions) { super(x, null, z, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java index b735b15cf..b5073d0f9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce.same; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.BaseNDArray; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; @@ -53,6 +54,10 @@ public class Prod extends BaseReduceSameOp { super(x, z, keepDims, dimensions); } + public Prod(INDArray x, boolean keepDims, int... dimensions) { + super(x, keepDims, dimensions); + } + @Override public int opNum() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java index 567f6cae7..211fec834 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java @@ -39,6 +39,11 @@ public class StandardDeviation extends Variance { super(sameDiff, i_v, biasCorrected, keepDims, dimensions ); } + public StandardDeviation(INDArray x, boolean biasCorrected, boolean keepDims, int... dimension) { + super(x, biasCorrected, dimension); + this.keepDims = keepDims; + } + public StandardDeviation(INDArray x, boolean biasCorrected, int... dimension) { super(x, biasCorrected, dimension); }