parent
66b7c3d6e3
commit
595656d01e
|
@ -28,7 +28,8 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dot product
|
* Dot product.
|
||||||
|
*
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
public class Dot extends BaseReduce3Op {
|
public class Dot extends BaseReduce3Op {
|
||||||
|
@ -37,21 +38,42 @@ public class Dot extends BaseReduce3Op {
|
||||||
super(sameDiff, i_v, i_v2, dimensions);
|
super(sameDiff, i_v, i_v2, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Dot() {}
|
public Dot() {
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Full array dot product reduction, optionally along specified dimensions.<br>
|
||||||
|
* See <a href="https://en.wikipedia.org/wiki/Dot_product">wikipedia</a> for details.
|
||||||
|
*
|
||||||
|
* @param x input variable.
|
||||||
|
* @param y input variable.
|
||||||
|
* @param z (optional) place holder for the result. Must have the expected shape.
|
||||||
|
* @param dimensions (optional) Dimensions to reduce over. If dimensions are not specified, full array reduction is performed.
|
||||||
|
* @see org.nd4j.linalg.ops.transforms.Transforms#dot Transforms.dot(...) for a wrapper around the common use case of 2 INDArrays.
|
||||||
|
*/
|
||||||
public Dot(INDArray x, INDArray y, INDArray z, int... dimensions) {
|
public Dot(INDArray x, INDArray y, INDArray z, int... dimensions) {
|
||||||
this(x, y, z, true, false, dimensions);
|
this(x, y, z, true, false, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @see #Dot(INDArray x, INDArray y, INDArray z, int...)
|
||||||
|
*/
|
||||||
public Dot(INDArray x, INDArray y, int... dimensions) {
|
public Dot(INDArray x, INDArray y, int... dimensions) {
|
||||||
this(x, y, null, dimensions);
|
this(x, y, null, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @see #Dot(INDArray x, INDArray y, INDArray z, int...)
|
||||||
|
*/
|
||||||
public Dot(INDArray x, INDArray y, INDArray z) {
|
public Dot(INDArray x, INDArray y, INDArray z) {
|
||||||
this(x, y, z, null);
|
this(x, y, z, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Dot(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){
|
/**
|
||||||
|
* @see #Dot(INDArray x, INDArray y, INDArray z, int...)
|
||||||
|
*/
|
||||||
|
public Dot(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions) {
|
||||||
super(x, y, z, keepDims, false, dimensions);
|
super(x, y, z, keepDims, false, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,8 @@ import java.util.List;
|
||||||
public class Transforms {
|
public class Transforms {
|
||||||
|
|
||||||
|
|
||||||
private Transforms() {}
|
private Transforms() {
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Cosine similarity
|
* Cosine similarity
|
||||||
|
@ -64,7 +65,6 @@ public class Transforms {
|
||||||
* @param d1 the first vector
|
* @param d1 the first vector
|
||||||
* @param d2 the second vector
|
* @param d2 the second vector
|
||||||
* @return the cosine similarities between the 2 arrays
|
* @return the cosine similarities between the 2 arrays
|
||||||
*
|
|
||||||
*/
|
*/
|
||||||
public static double cosineSim(@NonNull INDArray d1, @NonNull INDArray d2) {
|
public static double cosineSim(@NonNull INDArray d1, @NonNull INDArray d2) {
|
||||||
return Nd4j.getExecutioner().exec(new CosineSimilarity(d1, d2)).getDouble(0);
|
return Nd4j.getExecutioner().exec(new CosineSimilarity(d1, d2)).getDouble(0);
|
||||||
|
@ -103,12 +103,21 @@ public class Transforms {
|
||||||
return Nd4j.getExecutioner().exec(new OldReverse(x, dup ? x.ulike() : x));
|
return Nd4j.getExecutioner().exec(new OldReverse(x, dup ? x.ulike() : x));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static INDArray dot(INDArray x, INDArray y){
|
/**
|
||||||
return Nd4j.getExecutioner().exec(new Dot(x,y));
|
* Dot product, new INDArray instance will be returned.<br>
|
||||||
|
* Note that the Nd4J design is different from Numpy. Numpy dot on 2d arrays is matrix multiplication. Nd4J is
|
||||||
|
* full array dot product reduction.
|
||||||
|
*
|
||||||
|
* @param x the first vector
|
||||||
|
* @param y the second vector
|
||||||
|
* @return the dot product between the 2 arrays
|
||||||
|
*/
|
||||||
|
public static INDArray dot(INDArray x, INDArray y) {
|
||||||
|
return Nd4j.getExecutioner().exec(new Dot(x, y));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static INDArray cross(INDArray x, INDArray y){
|
public static INDArray cross(INDArray x, INDArray y) {
|
||||||
Cross c = new Cross(x,y,null);
|
Cross c = new Cross(x, y, null);
|
||||||
List<LongShapeDescriptor> shape = c.calculateOutputShape();
|
List<LongShapeDescriptor> shape = c.calculateOutputShape();
|
||||||
INDArray out = Nd4j.create(shape.get(0));
|
INDArray out = Nd4j.create(shape.get(0));
|
||||||
c.addOutputArgument(out);
|
c.addOutputArgument(out);
|
||||||
|
@ -117,7 +126,6 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param d1
|
* @param d1
|
||||||
* @param d2
|
* @param d2
|
||||||
* @return
|
* @return
|
||||||
|
@ -139,7 +147,6 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param d1
|
* @param d1
|
||||||
* @param d2
|
* @param d2
|
||||||
* @return
|
* @return
|
||||||
|
@ -188,7 +195,6 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the negative of an ndarray
|
* Returns the negative of an ndarray
|
||||||
*
|
*
|
||||||
|
@ -224,6 +230,7 @@ public class Transforms {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Ceiling function
|
* Ceiling function
|
||||||
|
*
|
||||||
* @param ndArray
|
* @param ndArray
|
||||||
* @param copyOnOps
|
* @param copyOnOps
|
||||||
* @return
|
* @return
|
||||||
|
@ -244,7 +251,6 @@ public class Transforms {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param ndArray
|
* @param ndArray
|
||||||
* @param k
|
* @param k
|
||||||
* @return
|
* @return
|
||||||
|
@ -255,6 +261,7 @@ public class Transforms {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sin function
|
* Sin function
|
||||||
|
*
|
||||||
* @param in
|
* @param in
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@ -264,6 +271,7 @@ public class Transforms {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sin function
|
* Sin function
|
||||||
|
*
|
||||||
* @param in
|
* @param in
|
||||||
* @param copy
|
* @param copy
|
||||||
* @return
|
* @return
|
||||||
|
@ -275,6 +283,7 @@ public class Transforms {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sin function
|
* Sin function
|
||||||
|
*
|
||||||
* @param in
|
* @param in
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@ -284,6 +293,7 @@ public class Transforms {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sin function
|
* Sin function
|
||||||
|
*
|
||||||
* @param in
|
* @param in
|
||||||
* @param copy
|
* @param copy
|
||||||
* @return
|
* @return
|
||||||
|
@ -294,6 +304,7 @@ public class Transforms {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sinh function
|
* Sinh function
|
||||||
|
*
|
||||||
* @param in
|
* @param in
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@ -303,6 +314,7 @@ public class Transforms {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sinh function
|
* Sinh function
|
||||||
|
*
|
||||||
* @param in
|
* @param in
|
||||||
* @param copy
|
* @param copy
|
||||||
* @return
|
* @return
|
||||||
|
@ -312,7 +324,6 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param in
|
* @param in
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@ -321,7 +332,6 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param in
|
* @param in
|
||||||
* @param copy
|
* @param copy
|
||||||
* @return
|
* @return
|
||||||
|
@ -331,7 +341,6 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param in
|
* @param in
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@ -340,7 +349,6 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param in
|
* @param in
|
||||||
* @param copy
|
* @param copy
|
||||||
* @return
|
* @return
|
||||||
|
@ -435,7 +443,6 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public static INDArray leakyRelu(INDArray arr, double cutoff) {
|
public static INDArray leakyRelu(INDArray arr, double cutoff) {
|
||||||
return leakyRelu(arr, cutoff, true);
|
return leakyRelu(arr, cutoff, true);
|
||||||
}
|
}
|
||||||
|
@ -455,7 +462,6 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public static INDArray softPlus(INDArray arr) {
|
public static INDArray softPlus(INDArray arr) {
|
||||||
return softPlus(arr, true);
|
return softPlus(arr, true);
|
||||||
}
|
}
|
||||||
|
@ -495,14 +501,12 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public static INDArray softmax(INDArray arr) {
|
public static INDArray softmax(INDArray arr) {
|
||||||
return softmax(arr, true);
|
return softmax(arr, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param in
|
* @param in
|
||||||
* @param copy
|
* @param copy
|
||||||
* @return
|
* @return
|
||||||
|
@ -513,11 +517,12 @@ public class Transforms {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* out = in * (1-in)
|
* out = in * (1-in)
|
||||||
|
*
|
||||||
* @param in Input array
|
* @param in Input array
|
||||||
* @param copy If true: copy. False: apply in-place
|
* @param copy If true: copy. False: apply in-place
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static INDArray timesOneMinus(INDArray in, boolean copy){
|
public static INDArray timesOneMinus(INDArray in, boolean copy) {
|
||||||
return Nd4j.getExecutioner().exec(new TimesOneMinus(in, (copy ? in.ulike() : in)));
|
return Nd4j.getExecutioner().exec(new TimesOneMinus(in, (copy ? in.ulike() : in)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -534,6 +539,7 @@ public class Transforms {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Run the exp operation
|
* Run the exp operation
|
||||||
|
*
|
||||||
* @param ndArray
|
* @param ndArray
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@ -558,7 +564,7 @@ public class Transforms {
|
||||||
return exec(dup ? new HardTanh(ndArray, ndArray.ulike()) : new HardTanh(ndArray));
|
return exec(dup ? new HardTanh(ndArray, ndArray.ulike()) : new HardTanh(ndArray));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static INDArray hardSigmoid(INDArray arr, boolean dup){
|
public static INDArray hardSigmoid(INDArray arr, boolean dup) {
|
||||||
return Nd4j.getExecutioner().exec(new HardSigmoid(arr, (dup ? arr.ulike() : arr)));
|
return Nd4j.getExecutioner().exec(new HardSigmoid(arr, (dup ? arr.ulike() : arr)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -580,9 +586,7 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param ndArray
|
* @param ndArray
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@ -1018,7 +1022,6 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sqrt function
|
* Sqrt function
|
||||||
*
|
*
|
||||||
|
@ -1104,7 +1107,6 @@ public class Transforms {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Apply the given elementwise op
|
* Apply the given elementwise op
|
||||||
*
|
*
|
||||||
|
@ -1131,6 +1133,7 @@ public class Transforms {
|
||||||
* repeated squarings to minimize the number of mmul() operations needed
|
* repeated squarings to minimize the number of mmul() operations needed
|
||||||
* <p>If <i>n</i> is zero, the identity matrix is returned.</p>
|
* <p>If <i>n</i> is zero, the identity matrix is returned.</p>
|
||||||
* <p>If <i>n</i> is negative, the matrix is inverted and raised to the abs(n) power.</p>
|
* <p>If <i>n</i> is negative, the matrix is inverted and raised to the abs(n) power.</p>
|
||||||
|
*
|
||||||
* @param in A square matrix to raise to an integer power, which will be changed if dup is false.
|
* @param in A square matrix to raise to an integer power, which will be changed if dup is false.
|
||||||
* @param n The integer power to raise the matrix to.
|
* @param n The integer power to raise the matrix to.
|
||||||
* @param dup If dup is true, the original input is unchanged.
|
* @param dup If dup is true, the original input is unchanged.
|
||||||
|
|
Loading…
Reference in New Issue