Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-09-03 13:06:42 +09:00 committed by Alex Black
parent 364a6e1a2a
commit c64b340975
3 changed files with 15 additions and 249 deletions

View File

@ -1195,7 +1195,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
return this; return this;
} }
@ -3089,12 +3088,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return mmuli(other, result); return mmuli(other, result);
} }
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @return the result of the divide
*/
@Override @Override
public INDArray div(INDArray other) { public INDArray div(INDArray other) {
if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
@ -3104,25 +3097,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
/**
* copy (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @param result the result ndarray
* @return the result of the divide
*/
@Override @Override
public INDArray div(INDArray other, INDArray result) { public INDArray div(INDArray other, INDArray result) {
validateNumericalArray("div", true); validateNumericalArray("div", true);
return divi(other, result); return divi(other, result);
} }
/**
* copy (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @return the result of the addition
*/
@Override @Override
public INDArray mul(INDArray other) { public INDArray mul(INDArray other) {
validateNumericalArray("mul", false); validateNumericalArray("mul", false);
@ -3134,24 +3114,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
/**
* copy (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @param result the result ndarray
* @return the result of the multiplication
*/
@Override @Override
public INDArray mul(INDArray other, INDArray result) { public INDArray mul(INDArray other, INDArray result) {
return muli(other, result); return muli(other, result);
} }
/**
* copy subtraction of two matrices
*
* @param other the second ndarray to subtract
* @return the result of the addition
*/
@Override @Override
public INDArray sub(INDArray other) { public INDArray sub(INDArray other) {
validateNumericalArray("sub", false); validateNumericalArray("sub", false);
@ -3162,24 +3129,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
/**
* copy subtraction of two matrices
*
* @param other the second ndarray to subtract
* @param result the result ndarray
* @return the result of the subtraction
*/
@Override @Override
public INDArray sub(INDArray other, INDArray result) { public INDArray sub(INDArray other, INDArray result) {
return subi(other, result); return subi(other, result);
} }
/**
* copy addition of two matrices
*
* @param other the second ndarray to add
* @return the result of the addition
*/
@Override @Override
public INDArray add(INDArray other) { public INDArray add(INDArray other) {
validateNumericalArray("add", false); validateNumericalArray("add", false);
@ -3190,65 +3144,29 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
/**
* copy addition of two matrices
*
* @param other the second ndarray to add
* @param result the result ndarray
* @return the result of the addition
*/
@Override @Override
public INDArray add(INDArray other, INDArray result) { public INDArray add(INDArray other, INDArray result) {
validateNumericalArray("add", false); validateNumericalArray("add", false);
return addi(other, result); return addi(other, result);
} }
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param transpose the transpose status of each ndarray
* @return the result of the matrix multiplication
*/
@Override @Override
public INDArray mmuli(INDArray other, MMulTranspose transpose) { public INDArray mmuli(INDArray other, MMulTranspose transpose) {
validateNumericalArray("mmuli", false); validateNumericalArray("mmuli", false);
return dup().mmuli(other, this,transpose); return dup().mmuli(other, this,transpose);
} }
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
@Override @Override
public INDArray mmuli(INDArray other) { public INDArray mmuli(INDArray other) {
validateNumericalArray("mmuli", false); validateNumericalArray("mmuli", false);
return dup().mmuli(other, this); return dup().mmuli(other, this);
} }
/**
* Perform an in place matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
@Override @Override
public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) { public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) {
return transpose.exec(this, other, result); return transpose.exec(this, other, result);
} }
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
@Override @Override
public INDArray mmuli(INDArray other, INDArray result) { public INDArray mmuli(INDArray other, INDArray result) {
validateNumericalArray("mmuli", false); validateNumericalArray("mmuli", false);
@ -3347,24 +3265,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.create(shape, stride); return Nd4j.create(shape, stride);
} }
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @return the result of the divide
*/
@Override @Override
public INDArray divi(INDArray other) { public INDArray divi(INDArray other) {
return divi(other, this); return divi(other, this);
} }
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @param result the result ndarray
* @return the result of the divide
*/
@Override @Override
public INDArray divi(INDArray other, INDArray result) { public INDArray divi(INDArray other, INDArray result) {
validateNumericalArray("divi", false); validateNumericalArray("divi", false);
@ -3373,24 +3278,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result; return result;
} }
/**
* in place (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @return the result of the multiplication
*/
@Override @Override
public INDArray muli(INDArray other) { public INDArray muli(INDArray other) {
return muli(other, this); return muli(other, this);
} }
/**
* in place (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @param result the result ndarray
* @return the result of the multiplication
*/
@Override @Override
public INDArray muli(INDArray other, INDArray result) { public INDArray muli(INDArray other, INDArray result) {
validateNumericalArray("muli", false); validateNumericalArray("muli", false);
@ -3399,12 +3291,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result; return result;
} }
/**
* in place subtraction of two matrices
*
* @param other the second ndarray to subtract
* @return the result of the addition
*/
@Override @Override
public INDArray subi(INDArray other) { public INDArray subi(INDArray other) {
return subi(other, this); return subi(other, this);
@ -3425,24 +3311,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result; return result;
} }
/**
* in place addition of two matrices
*
* @param other the second ndarray to add
* @return the result of the addition
*/
@Override @Override
public INDArray addi(INDArray other) { public INDArray addi(INDArray other) {
return addi(other, this); return addi(other, this);
} }
/**
* in place addition of two matrices
*
* @param other the second ndarray to add
* @param result the result ndarray
* @return the result of the addition
*/
@Override @Override
public INDArray addi(INDArray other, INDArray result) { public INDArray addi(INDArray other, INDArray result) {
validateNumericalArray("addi", false); validateNumericalArray("addi", false);
@ -3451,25 +3324,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result; return result;
} }
/**
* Returns the normmax along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm1 along the specified dimension
*/
@Override @Override
public INDArray normmax(boolean keepDims, int... dimension) { public INDArray normmax(boolean keepDims, int... dimension) {
validateNumericalArray("normmax", false); validateNumericalArray("normmax", false);
return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension));
} }
/**
* Returns the normmax along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @return the norm1 along the specified dimension
*/
@Override @Override
public INDArray normmax(int... dimension) { public INDArray normmax(int... dimension) {
return normmax(false, dimension); return normmax(false, dimension);
@ -4071,49 +3931,23 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return reshape(Nd4j.order(), shape); return reshape(Nd4j.order(), shape);
} }
/**
* Returns the product along a given dimension
*
* @param dimension the dimension to getScalar the product along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the product along the specified dimension
*/
@Override @Override
public INDArray prod(boolean keepDims, int... dimension) { public INDArray prod(boolean keepDims, int... dimension) {
validateNumericalArray("prod", false); validateNumericalArray("prod", false);
return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension));
} }
/**
* Returns the product along a given dimension
*
* @param dimension the dimension to getScalar the product along
* @return the product along the specified dimension
*/
@Override @Override
public INDArray prod(int... dimension) { public INDArray prod(int... dimension) {
return prod(false, dimension); return prod(false, dimension);
} }
/**
* Returns the overall mean of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray mean(boolean keepDims, int... dimension) { public INDArray mean(boolean keepDims, int... dimension) {
validateNumericalArray("mean", false); validateNumericalArray("mean", false);
return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension));
} }
/**
* Returns the overall mean of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray mean(int... dimension) { public INDArray mean(int... dimension) {
return mean(false, dimension); return mean(false, dimension);
@ -4136,50 +3970,24 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return mean(result, false, dimension); return mean(result, false, dimension);
} }
/**
* Returns the overall variance of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray var(int... dimension) { public INDArray var(int... dimension) {
validateNumericalArray("var", false); validateNumericalArray("var", false);
return Nd4j.getExecutioner().exec(new Variance(this, dimension)); return Nd4j.getExecutioner().exec(new Variance(this, dimension));
} }
/**
* Returns the overall variance of this ndarray
*
* @param biasCorrected boolean on whether to apply corrected bias
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray var(boolean biasCorrected, int... dimension) { public INDArray var(boolean biasCorrected, int... dimension) {
validateNumericalArray("var", false); validateNumericalArray("var", false);
return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension)); return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension));
} }
/**
* Returns the overall max of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray max(boolean keepDims, int... dimension) { public INDArray max(boolean keepDims, int... dimension) {
validateNumericalArray("max", false); validateNumericalArray("max", false);
return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension));
} }
/**
* Returns the overall max of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray max(int... dimension) { public INDArray max(int... dimension) {
return max(false, dimension); return max(false, dimension);
@ -4191,25 +3999,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new AMax(this, dimension)); return Nd4j.getExecutioner().exec(new AMax(this, dimension));
} }
/**
* Returns the overall min of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray min(boolean keepDims, int... dimension) { public INDArray min(boolean keepDims, int... dimension) {
validateNumericalArray("min", false); validateNumericalArray("min", false);
return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension));
} }
/**
* Returns the overall min of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override @Override
public INDArray min(int... dimension) { public INDArray min(int... dimension) {
return min(false, dimension); return min(false, dimension);
@ -4290,39 +4085,17 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return sum(result, false, dimension); return sum(result, false, dimension);
} }
/**
* Returns the norm1 along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @return the norm1 along the specified dimension
*/
@Override @Override
public INDArray norm1(int... dimension) { public INDArray norm1(int... dimension) {
return norm1(false, dimension); return norm1(false, dimension);
} }
/**
* Returns the norm1 along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm1 along the specified dimension
*/
@Override @Override
public INDArray norm1(boolean keepDims, int... dimension) { public INDArray norm1(boolean keepDims, int... dimension) {
validateNumericalArray("norm1", false); validateNumericalArray("norm1", false);
return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension));
} }
/**
* Standard deviation of an ndarray along a dimension
*
* @param dimension the dimension to getScalar the std along
* @return the standard deviation along a particular dimension
*/
@Override @Override
public INDArray std(int... dimension) { public INDArray std(int... dimension) {
return std(true, dimension); return std(true, dimension);
@ -4345,32 +4118,17 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0); return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0);
} }
/**
* Returns the norm2 along the specified dimension
*
* @param dimension the dimension to getScalar the norm2 along
* @param keepDims whether to keep reduced dimensions as dimensions of size 1
* @return the norm2 along the specified dimension
*/
@Override @Override
public INDArray norm2(boolean keepDims, int... dimension) { public INDArray norm2(boolean keepDims, int... dimension) {
validateNumericalArray("norm2", false); validateNumericalArray("norm2", false);
return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension)); return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension));
} }
/**
* Returns the norm2 along the specified dimension
*
* @param dimension the dimension to getScalar the norm2 along
* @return the norm2 along the specified dimension
*/
@Override @Override
public INDArray norm2(int... dimension) { public INDArray norm2(int... dimension) {
return norm2(false, dimension); return norm2(false, dimension);
} }
/** /**
* Number of columns (shape[1]), throws an exception when * Number of columns (shape[1]), throws an exception when
* called when not 2d * called when not 2d

View File

@ -1232,8 +1232,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
return null; return null;
} }
@Override @Override
public INDArray normmax(boolean keepDims, int... dimension) { public INDArray normmax(boolean keepDims, int... dimension) {
return null; return null;

View File

@ -1404,7 +1404,13 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray add(INDArray other, INDArray result); INDArray add(INDArray other, INDArray result);
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param transpose the transpose status of each ndarray
* @return the result of the matrix multiplication
*/
INDArray mmuli(INDArray other, MMulTranspose transpose); INDArray mmuli(INDArray other, MMulTranspose transpose);
/** /**
@ -1415,7 +1421,13 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray mmuli(INDArray other); INDArray mmuli(INDArray other);
/**
* Perform an in place matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose); INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose);
/** /**
@ -1497,7 +1509,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray addi(INDArray other, INDArray result); INDArray addi(INDArray other, INDArray result);
/** /**
* Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s) * Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s)
* *
@ -1506,7 +1517,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray normmax(int... dimension); INDArray normmax(int... dimension);
/** /**
* Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s) * Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s)
* *
@ -1585,7 +1595,7 @@ public interface INDArray extends Serializable, AutoCloseable {
/** /**
* Calculate the standard deviation for the entire array * Calculate the standard deviation for the entire array
* *
* @return * @return standard deviation
*/ */
Number stdNumber(); Number stdNumber();