parent
66b7c3d6e3
commit
595656d01e
|
@ -28,7 +28,8 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Dot product
|
||||
* Dot product.
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class Dot extends BaseReduce3Op {
|
||||
|
@ -37,20 +38,41 @@ public class Dot extends BaseReduce3Op {
|
|||
super(sameDiff, i_v, i_v2, dimensions);
|
||||
}
|
||||
|
||||
public Dot() {}
|
||||
public Dot() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Full array dot product reduction, optionally along specified dimensions.<br>
|
||||
* See <a href="https://en.wikipedia.org/wiki/Dot_product">wikipedia</a> for details.
|
||||
*
|
||||
* @param x input variable.
|
||||
* @param y input variable.
|
||||
* @param z (optional) place holder for the result. Must have the expected shape.
|
||||
* @param dimensions (optional) Dimensions to reduce over. If dimensions are not specified, full array reduction is performed.
|
||||
* @see org.nd4j.linalg.ops.transforms.Transforms#dot Transforms.dot(...) for a wrapper around the common use case of 2 INDArrays.
|
||||
*/
|
||||
public Dot(INDArray x, INDArray y, INDArray z, int... dimensions) {
|
||||
this(x, y, z, true, false, dimensions);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @see #Dot(INDArray x, INDArray y, INDArray z, int...)
|
||||
*/
|
||||
public Dot(INDArray x, INDArray y, int... dimensions) {
|
||||
this(x, y, null, dimensions);
|
||||
}
|
||||
|
||||
/**
|
||||
* @see #Dot(INDArray x, INDArray y, INDArray z, int...)
|
||||
*/
|
||||
public Dot(INDArray x, INDArray y, INDArray z) {
|
||||
this(x, y, z, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* @see #Dot(INDArray x, INDArray y, INDArray z, int...)
|
||||
*/
|
||||
public Dot(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions) {
|
||||
super(x, y, z, keepDims, false, dimensions);
|
||||
}
|
||||
|
|
|
@ -56,7 +56,8 @@ import java.util.List;
|
|||
public class Transforms {
|
||||
|
||||
|
||||
private Transforms() {}
|
||||
private Transforms() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Cosine similarity
|
||||
|
@ -64,7 +65,6 @@ public class Transforms {
|
|||
* @param d1 the first vector
|
||||
* @param d2 the second vector
|
||||
* @return the cosine similarities between the 2 arrays
|
||||
*
|
||||
*/
|
||||
public static double cosineSim(@NonNull INDArray d1, @NonNull INDArray d2) {
|
||||
return Nd4j.getExecutioner().exec(new CosineSimilarity(d1, d2)).getDouble(0);
|
||||
|
@ -103,6 +103,15 @@ public class Transforms {
|
|||
return Nd4j.getExecutioner().exec(new OldReverse(x, dup ? x.ulike() : x));
|
||||
}
|
||||
|
||||
/**
|
||||
* Dot product, new INDArray instance will be returned.<br>
|
||||
* Note that the Nd4J design is different from Numpy. Numpy dot on 2d arrays is matrix multiplication. Nd4J is
|
||||
* full array dot product reduction.
|
||||
*
|
||||
* @param x the first vector
|
||||
* @param y the second vector
|
||||
* @return the dot product between the 2 arrays
|
||||
*/
|
||||
public static INDArray dot(INDArray x, INDArray y) {
|
||||
return Nd4j.getExecutioner().exec(new Dot(x, y));
|
||||
}
|
||||
|
@ -117,7 +126,6 @@ public class Transforms {
|
|||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param d1
|
||||
* @param d2
|
||||
* @return
|
||||
|
@ -139,7 +147,6 @@ public class Transforms {
|
|||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param d1
|
||||
* @param d2
|
||||
* @return
|
||||
|
@ -188,7 +195,6 @@ public class Transforms {
|
|||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Returns the negative of an ndarray
|
||||
*
|
||||
|
@ -224,6 +230,7 @@ public class Transforms {
|
|||
|
||||
/**
|
||||
* Ceiling function
|
||||
*
|
||||
* @param ndArray
|
||||
* @param copyOnOps
|
||||
* @return
|
||||
|
@ -244,7 +251,6 @@ public class Transforms {
|
|||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param ndArray
|
||||
* @param k
|
||||
* @return
|
||||
|
@ -255,6 +261,7 @@ public class Transforms {
|
|||
|
||||
/**
|
||||
* Sin function
|
||||
*
|
||||
* @param in
|
||||
* @return
|
||||
*/
|
||||
|
@ -264,6 +271,7 @@ public class Transforms {
|
|||
|
||||
/**
|
||||
* Sin function
|
||||
*
|
||||
* @param in
|
||||
* @param copy
|
||||
* @return
|
||||
|
@ -275,6 +283,7 @@ public class Transforms {
|
|||
|
||||
/**
|
||||
* Sin function
|
||||
*
|
||||
* @param in
|
||||
* @return
|
||||
*/
|
||||
|
@ -284,6 +293,7 @@ public class Transforms {
|
|||
|
||||
/**
|
||||
* Sin function
|
||||
*
|
||||
* @param in
|
||||
* @param copy
|
||||
* @return
|
||||
|
@ -294,6 +304,7 @@ public class Transforms {
|
|||
|
||||
/**
|
||||
* Sinh function
|
||||
*
|
||||
* @param in
|
||||
* @return
|
||||
*/
|
||||
|
@ -303,6 +314,7 @@ public class Transforms {
|
|||
|
||||
/**
|
||||
* Sinh function
|
||||
*
|
||||
* @param in
|
||||
* @param copy
|
||||
* @return
|
||||
|
@ -312,7 +324,6 @@ public class Transforms {
|
|||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param in
|
||||
* @return
|
||||
*/
|
||||
|
@ -321,7 +332,6 @@ public class Transforms {
|
|||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param in
|
||||
* @param copy
|
||||
* @return
|
||||
|
@ -331,7 +341,6 @@ public class Transforms {
|
|||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param in
|
||||
* @return
|
||||
*/
|
||||
|
@ -340,7 +349,6 @@ public class Transforms {
|
|||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param in
|
||||
* @param copy
|
||||
* @return
|
||||
|
@ -435,7 +443,6 @@ public class Transforms {
|
|||
}
|
||||
|
||||
|
||||
|
||||
public static INDArray leakyRelu(INDArray arr, double cutoff) {
|
||||
return leakyRelu(arr, cutoff, true);
|
||||
}
|
||||
|
@ -455,7 +462,6 @@ public class Transforms {
|
|||
}
|
||||
|
||||
|
||||
|
||||
public static INDArray softPlus(INDArray arr) {
|
||||
return softPlus(arr, true);
|
||||
}
|
||||
|
@ -495,14 +501,12 @@ public class Transforms {
|
|||
}
|
||||
|
||||
|
||||
|
||||
public static INDArray softmax(INDArray arr) {
|
||||
return softmax(arr, true);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param in
|
||||
* @param copy
|
||||
* @return
|
||||
|
@ -513,6 +517,7 @@ public class Transforms {
|
|||
|
||||
/**
|
||||
* out = in * (1-in)
|
||||
*
|
||||
* @param in Input array
|
||||
* @param copy If true: copy. False: apply in-place
|
||||
* @return
|
||||
|
@ -534,6 +539,7 @@ public class Transforms {
|
|||
|
||||
/**
|
||||
* Run the exp operation
|
||||
*
|
||||
* @param ndArray
|
||||
* @return
|
||||
*/
|
||||
|
@ -580,9 +586,7 @@ public class Transforms {
|
|||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param ndArray
|
||||
* @return
|
||||
*/
|
||||
|
@ -1018,7 +1022,6 @@ public class Transforms {
|
|||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Sqrt function
|
||||
*
|
||||
|
@ -1104,7 +1107,6 @@ public class Transforms {
|
|||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Apply the given elementwise op
|
||||
*
|
||||
|
@ -1131,6 +1133,7 @@ public class Transforms {
|
|||
* repeated squarings to minimize the number of mmul() operations needed
|
||||
* <p>If <i>n</i> is zero, the identity matrix is returned.</p>
|
||||
* <p>If <i>n</i> is negative, the matrix is inverted and raised to the abs(n) power.</p>
|
||||
*
|
||||
* @param in A square matrix to raise to an integer power, which will be changed if dup is false.
|
||||
* @param n The integer power to raise the matrix to.
|
||||
* @param dup If dup is true, the original input is unchanged.
|
||||
|
|
Loading…
Reference in New Issue