Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-07-08 15:35:21 +09:00 committed by Alex Black
parent 66b7c3d6e3
commit 595656d01e
2 changed files with 55 additions and 30 deletions

View File

@ -28,7 +28,8 @@ import java.util.Arrays;
import java.util.List;
/**
* Dot product
* Dot product.
*
* @author Adam Gibson
*/
public class Dot extends BaseReduce3Op {
@ -37,21 +38,42 @@ public class Dot extends BaseReduce3Op {
super(sameDiff, i_v, i_v2, dimensions);
}
public Dot() {}
public Dot() {
}
/**
* Full array dot product reduction, optionally along specified dimensions.<br>
* See <a href="https://en.wikipedia.org/wiki/Dot_product">wikipedia</a> for details.
*
* @param x input variable.
* @param y input variable.
* @param z (optional) place holder for the result. Must have the expected shape.
* @param dimensions (optional) Dimensions to reduce over. If dimensions are not specified, full array reduction is performed.
* @see org.nd4j.linalg.ops.transforms.Transforms#dot Transforms.dot(...) for a wrapper around the common use case of 2 INDArrays.
*/
public Dot(INDArray x, INDArray y, INDArray z, int... dimensions) {
this(x, y, z, true, false, dimensions);
}
public Dot(INDArray x, INDArray y, int... dimensions) {
/**
* @see #Dot(INDArray x, INDArray y, INDArray z, int...)
*/
public Dot(INDArray x, INDArray y, int... dimensions) {
this(x, y, null, dimensions);
}
/**
* @see #Dot(INDArray x, INDArray y, INDArray z, int...)
*/
public Dot(INDArray x, INDArray y, INDArray z) {
this(x, y, z, null);
}
public Dot(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){
/**
* @see #Dot(INDArray x, INDArray y, INDArray z, int...)
*/
public Dot(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions) {
super(x, y, z, keepDims, false, dimensions);
}

View File

@ -56,7 +56,8 @@ import java.util.List;
public class Transforms {
private Transforms() {}
private Transforms() {
}
/**
* Cosine similarity
@ -64,7 +65,6 @@ public class Transforms {
* @param d1 the first vector
* @param d2 the second vector
* @return the cosine similarities between the 2 arrays
*
*/
public static double cosineSim(@NonNull INDArray d1, @NonNull INDArray d2) {
return Nd4j.getExecutioner().exec(new CosineSimilarity(d1, d2)).getDouble(0);
@ -103,12 +103,21 @@ public class Transforms {
return Nd4j.getExecutioner().exec(new OldReverse(x, dup ? x.ulike() : x));
}
public static INDArray dot(INDArray x, INDArray y){
return Nd4j.getExecutioner().exec(new Dot(x,y));
/**
* Dot product, new INDArray instance will be returned.<br>
* Note that the Nd4J design is different from Numpy. Numpy dot on 2d arrays is matrix multiplication. Nd4J is
* full array dot product reduction.
*
* @param x the first vector
* @param y the second vector
* @return the dot product between the 2 arrays
*/
public static INDArray dot(INDArray x, INDArray y) {
return Nd4j.getExecutioner().exec(new Dot(x, y));
}
public static INDArray cross(INDArray x, INDArray y){
Cross c = new Cross(x,y,null);
public static INDArray cross(INDArray x, INDArray y) {
Cross c = new Cross(x, y, null);
List<LongShapeDescriptor> shape = c.calculateOutputShape();
INDArray out = Nd4j.create(shape.get(0));
c.addOutputArgument(out);
@ -117,7 +126,6 @@ public class Transforms {
}
/**
*
* @param d1
* @param d2
* @return
@ -139,7 +147,6 @@ public class Transforms {
}
/**
*
* @param d1
* @param d2
* @return
@ -188,7 +195,6 @@ public class Transforms {
}
/**
* Returns the negative of an ndarray
*
@ -224,6 +230,7 @@ public class Transforms {
/**
* Ceiling function
*
* @param ndArray
* @param copyOnOps
* @return
@ -244,7 +251,6 @@ public class Transforms {
/**
*
* @param ndArray
* @param k
* @return
@ -255,6 +261,7 @@ public class Transforms {
/**
* Sin function
*
* @param in
* @return
*/
@ -264,6 +271,7 @@ public class Transforms {
/**
* Sin function
*
* @param in
* @param copy
* @return
@ -275,6 +283,7 @@ public class Transforms {
/**
* Sin function
*
* @param in
* @return
*/
@ -284,6 +293,7 @@ public class Transforms {
/**
* Sin function
*
* @param in
* @param copy
* @return
@ -294,6 +304,7 @@ public class Transforms {
/**
* Sinh function
*
* @param in
* @return
*/
@ -303,6 +314,7 @@ public class Transforms {
/**
* Sinh function
*
* @param in
* @param copy
* @return
@ -312,7 +324,6 @@ public class Transforms {
}
/**
*
* @param in
* @return
*/
@ -321,7 +332,6 @@ public class Transforms {
}
/**
*
* @param in
* @param copy
* @return
@ -331,7 +341,6 @@ public class Transforms {
}
/**
*
* @param in
* @return
*/
@ -340,7 +349,6 @@ public class Transforms {
}
/**
*
* @param in
* @param copy
* @return
@ -435,7 +443,6 @@ public class Transforms {
}
public static INDArray leakyRelu(INDArray arr, double cutoff) {
return leakyRelu(arr, cutoff, true);
}
@ -455,7 +462,6 @@ public class Transforms {
}
public static INDArray softPlus(INDArray arr) {
return softPlus(arr, true);
}
@ -495,14 +501,12 @@ public class Transforms {
}
public static INDArray softmax(INDArray arr) {
return softmax(arr, true);
}
/**
*
* @param in
* @param copy
* @return
@ -513,11 +517,12 @@ public class Transforms {
/**
* out = in * (1-in)
*
* @param in Input array
* @param copy If true: copy. False: apply in-place
* @return
*/
public static INDArray timesOneMinus(INDArray in, boolean copy){
public static INDArray timesOneMinus(INDArray in, boolean copy) {
return Nd4j.getExecutioner().exec(new TimesOneMinus(in, (copy ? in.ulike() : in)));
}
@ -534,6 +539,7 @@ public class Transforms {
/**
* Run the exp operation
*
* @param ndArray
* @return
*/
@ -558,7 +564,7 @@ public class Transforms {
return exec(dup ? new HardTanh(ndArray, ndArray.ulike()) : new HardTanh(ndArray));
}
public static INDArray hardSigmoid(INDArray arr, boolean dup){
public static INDArray hardSigmoid(INDArray arr, boolean dup) {
return Nd4j.getExecutioner().exec(new HardSigmoid(arr, (dup ? arr.ulike() : arr)));
}
@ -580,9 +586,7 @@ public class Transforms {
}
/**
*
* @param ndArray
* @return
*/
@ -1018,7 +1022,6 @@ public class Transforms {
}
/**
* Sqrt function
*
@ -1104,7 +1107,6 @@ public class Transforms {
}
/**
* Apply the given elementwise op
*
@ -1131,8 +1133,9 @@ public class Transforms {
* repeated squarings to minimize the number of mmul() operations needed
* <p>If <i>n</i> is zero, the identity matrix is returned.</p>
* <p>If <i>n</i> is negative, the matrix is inverted and raised to the abs(n) power.</p>
* @param in A square matrix to raise to an integer power, which will be changed if dup is false.
* @param n The integer power to raise the matrix to.
*
* @param in A square matrix to raise to an integer power, which will be changed if dup is false.
* @param n The integer power to raise the matrix to.
* @param dup If dup is true, the original input is unchanged.
* @return The result of raising <i>in</i> to the <i>n</i>th power.
*/