Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-09-03 13:06:42 +09:00 committed by Alex Black
parent 364a6e1a2a
commit c64b340975
3 changed files with 15 additions and 249 deletions

View File

@ -1195,7 +1195,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
}
return this;
}
@ -3089,12 +3088,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return mmuli(other, result);
}
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @return the result of the divide
*/
@Override
public INDArray div(INDArray other) {
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
@ -3104,25 +3097,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
}
/**
* copy (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @param result the result ndarray
* @return the result of the divide
*/
@Override
public INDArray div(INDArray other, INDArray result) {
validateNumericalArray("div", true);
return divi(other, result);
}
/**
* copy (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @return the result of the addition
*/
@Override
public INDArray mul(INDArray other) {
validateNumericalArray("mul", false);
@ -3134,24 +3114,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
}
/**
* copy (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @param result the result ndarray
* @return the result of the multiplication
*/
@Override
public INDArray mul(INDArray other, INDArray result) {
return muli(other, result);
}
/**
* copy subtraction of two matrices
*
* @param other the second ndarray to subtract
* @return the result of the addition
*/
@Override
public INDArray sub(INDArray other) {
validateNumericalArray("sub", false);
@ -3162,24 +3129,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
}
/**
* copy subtraction of two matrices
*
* @param other the second ndarray to subtract
* @param result the result ndarray
* @return the result of the subtraction
*/
@Override
public INDArray sub(INDArray other, INDArray result) {
return subi(other, result);
}
/**
* copy addition of two matrices
*
* @param other the second ndarray to add
* @return the result of the addition
*/
@Override
public INDArray add(INDArray other) {
validateNumericalArray("add", false);
@ -3190,65 +3144,29 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
}
/**
* copy addition of two matrices
*
* @param other the second ndarray to add
* @param result the result ndarray
* @return the result of the addition
*/
@Override
public INDArray add(INDArray other, INDArray result) {
validateNumericalArray("add", false);
return addi(other, result);
}
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param transpose the transpose status of each ndarray
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmuli(INDArray other, MMulTranspose transpose) {
validateNumericalArray("mmuli", false);
return dup().mmuli(other, this,transpose);
}
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmuli(INDArray other) {
validateNumericalArray("mmuli", false);
return dup().mmuli(other, this);
}
/**
* Perform an in place matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) {
return transpose.exec(this, other, result);
}
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmuli(INDArray other, INDArray result) {
validateNumericalArray("mmuli", false);
@ -3347,24 +3265,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.create(shape, stride);
}
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @return the result of the divide
*/
@Override
public INDArray divi(INDArray other) {
return divi(other, this);
}
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @param result the result ndarray
* @return the result of the divide
*/
@Override
public INDArray divi(INDArray other, INDArray result) {
validateNumericalArray("divi", false);
@ -3373,24 +3278,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result;
}
/**
* in place (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @return the result of the multiplication
*/
@Override
public INDArray muli(INDArray other) {
return muli(other, this);
}
/**
* in place (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @param result the result ndarray
* @return the result of the multiplication
*/
@Override
public INDArray muli(INDArray other, INDArray result) {
validateNumericalArray("muli", false);
@ -3399,12 +3291,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result;
}
/**
* in place subtraction of two matrices
*
* @param other the second ndarray to subtract
* @return the result of the addition
*/
@Override
public INDArray subi(INDArray other) {
return subi(other, this);
@ -3425,24 +3311,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result;
}
/**
* in place addition of two matrices
*
* @param other the second ndarray to add
* @return the result of the addition
*/
@Override
public INDArray addi(INDArray other) {
return addi(other, this);
}
/**
* in place addition of two matrices
*
* @param other the second ndarray to add
* @param result the result ndarray
* @return the result of the addition
*/
@Override
public INDArray addi(INDArray other, INDArray result) {
validateNumericalArray("addi", false);
@ -3451,25 +3324,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result;
}
/**
* Returns the normmax along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm1 along the specified dimension
*/
@Override
public INDArray normmax(boolean keepDims, int... dimension) {
validateNumericalArray("normmax", false);
return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension));
}
/**
* Returns the normmax along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @return the norm1 along the specified dimension
*/
@Override
public INDArray normmax(int... dimension) {
return normmax(false, dimension);
@ -4071,49 +3931,23 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return reshape(Nd4j.order(), shape);
}
/**
* Returns the product along a given dimension
*
* @param dimension the dimension to getScalar the product along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the product along the specified dimension
*/
@Override
public INDArray prod(boolean keepDims, int... dimension) {
validateNumericalArray("prod", false);
return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension));
}
/**
* Returns the product along a given dimension
*
* @param dimension the dimension to getScalar the product along
* @return the product along the specified dimension
*/
@Override
public INDArray prod(int... dimension) {
return prod(false, dimension);
}
/**
* Returns the overall mean of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray mean(boolean keepDims, int... dimension) {
validateNumericalArray("mean", false);
return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension));
}
/**
* Returns the overall mean of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray mean(int... dimension) {
return mean(false, dimension);
@ -4136,50 +3970,24 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return mean(result, false, dimension);
}
/**
* Returns the overall variance of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray var(int... dimension) {
validateNumericalArray("var", false);
return Nd4j.getExecutioner().exec(new Variance(this, dimension));
}
/**
* Returns the overall variance of this ndarray
*
* @param biasCorrected boolean on whether to apply corrected bias
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray var(boolean biasCorrected, int... dimension) {
validateNumericalArray("var", false);
return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension));
}
/**
* Returns the overall max of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray max(boolean keepDims, int... dimension) {
validateNumericalArray("max", false);
return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension));
}
/**
* Returns the overall max of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray max(int... dimension) {
return max(false, dimension);
@ -4191,25 +3999,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new AMax(this, dimension));
}
/**
* Returns the overall min of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray min(boolean keepDims, int... dimension) {
validateNumericalArray("min", false);
return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension));
}
/**
* Returns the overall min of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray min(int... dimension) {
return min(false, dimension);
@ -4290,39 +4085,17 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return sum(result, false, dimension);
}
/**
* Returns the norm1 along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @return the norm1 along the specified dimension
*/
@Override
public INDArray norm1(int... dimension) {
return norm1(false, dimension);
}
/**
* Returns the norm1 along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm1 along the specified dimension
*/
@Override
public INDArray norm1(boolean keepDims, int... dimension) {
validateNumericalArray("norm1", false);
return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension));
}
/**
* Standard deviation of an ndarray along a dimension
*
* @param dimension the dimension to getScalar the std along
* @return the standard deviation along a particular dimension
*/
@Override
public INDArray std(int... dimension) {
return std(true, dimension);
@ -4345,32 +4118,17 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0);
}
/**
* Returns the norm2 along the specified dimension
*
* @param dimension the dimension to getScalar the norm2 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm2 along the specified dimension
*/
@Override
public INDArray norm2(boolean keepDims, int... dimension) {
validateNumericalArray("norm2", false);
return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension));
}
/**
* Returns the norm2 along the specified dimension
*
* @param dimension the dimension to getScalar the norm2 along
* @return the norm2 along the specified dimension
*/
@Override
public INDArray norm2(int... dimension) {
return norm2(false, dimension);
}
/**
* Number of columns (shape[1]), throws an exception when
* called when not 2d

View File

@ -1232,8 +1232,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
return null;
}
@Override
public INDArray normmax(boolean keepDims, int... dimension) {
return null;

View File

@ -1404,7 +1404,13 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray add(INDArray other, INDArray result);
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param transpose the transpose status of each ndarray
* @return the result of the matrix multiplication
*/
INDArray mmuli(INDArray other, MMulTranspose transpose);
/**
@ -1415,7 +1421,13 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray mmuli(INDArray other);
/**
* Perform an in place matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose);
/**
@ -1497,7 +1509,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray addi(INDArray other, INDArray result);
/**
* Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s)
*
@ -1506,7 +1517,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray normmax(int... dimension);
/**
* Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s)
*
@ -1585,7 +1595,7 @@ public interface INDArray extends Serializable, AutoCloseable {
/**
* Calculate the standard deviation for the entire array
*
* @return
* @return standard deviation
*/
Number stdNumber();