parent
							
								
									66b7c3d6e3
								
							
						
					
					
						commit
						595656d01e
					
				| @ -28,7 +28,8 @@ import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| /** | ||||
|  * Dot product | ||||
|  * Dot product. | ||||
|  * | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class Dot extends BaseReduce3Op { | ||||
| @ -37,20 +38,41 @@ public class Dot extends BaseReduce3Op { | ||||
|         super(sameDiff, i_v, i_v2, dimensions); | ||||
|     } | ||||
| 
 | ||||
|     public Dot() {} | ||||
|     public Dot() { | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Full array dot product reduction, optionally along specified dimensions.<br> | ||||
|      * See <a href="https://en.wikipedia.org/wiki/Dot_product">wikipedia</a> for details. | ||||
|      * | ||||
|      * @param x          input variable. | ||||
|      * @param y          input variable. | ||||
|      * @param z          (optional) place holder for the result. Must have the expected shape. | ||||
|      * @param dimensions (optional) Dimensions to reduce over. If dimensions are not specified, full array reduction is performed. | ||||
|      * @see org.nd4j.linalg.ops.transforms.Transforms#dot Transforms.dot(...) for a wrapper around the common use case of 2 INDArrays. | ||||
|      */ | ||||
|     public Dot(INDArray x, INDArray y, INDArray z, int... dimensions) { | ||||
|         this(x, y, z, true, false, dimensions); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * @see #Dot(INDArray x, INDArray y, INDArray z, int...) | ||||
|      */ | ||||
|     public Dot(INDArray x, INDArray y, int... dimensions) { | ||||
|         this(x, y, null, dimensions); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @see #Dot(INDArray x, INDArray y, INDArray z, int...) | ||||
|      */ | ||||
|     public Dot(INDArray x, INDArray y, INDArray z) { | ||||
|         this(x, y, z, null); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @see #Dot(INDArray x, INDArray y, INDArray z, int...) | ||||
|      */ | ||||
|     public Dot(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions) { | ||||
|         super(x, y, z, keepDims, false, dimensions); | ||||
|     } | ||||
|  | ||||
| @ -56,7 +56,8 @@ import java.util.List; | ||||
| public class Transforms { | ||||
| 
 | ||||
| 
 | ||||
|     private Transforms() {} | ||||
|     private Transforms() { | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Cosine similarity | ||||
| @ -64,7 +65,6 @@ public class Transforms { | ||||
|      * @param d1 the first vector | ||||
|      * @param d2 the second vector | ||||
|      * @return the cosine similarities between the 2 arrays | ||||
|      * | ||||
|      */ | ||||
|     public static double cosineSim(@NonNull INDArray d1, @NonNull INDArray d2) { | ||||
|         return Nd4j.getExecutioner().exec(new CosineSimilarity(d1, d2)).getDouble(0); | ||||
| @ -103,6 +103,15 @@ public class Transforms { | ||||
|         return Nd4j.getExecutioner().exec(new OldReverse(x, dup ? x.ulike() : x)); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Dot product, new INDArray instance will be returned.<br> | ||||
|      * Note that the Nd4J design is different from Numpy. Numpy dot on 2d arrays is matrix multiplication. Nd4J is | ||||
|      * full array dot product reduction. | ||||
|      * | ||||
|      * @param x the first vector | ||||
|      * @param y the second vector | ||||
|      * @return the dot product between the 2 arrays | ||||
|      */ | ||||
|     public static INDArray dot(INDArray x, INDArray y) { | ||||
|         return Nd4j.getExecutioner().exec(new Dot(x, y)); | ||||
|     } | ||||
| @ -117,7 +126,6 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param d1 | ||||
|      * @param d2 | ||||
|      * @return | ||||
| @ -139,7 +147,6 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param d1 | ||||
|      * @param d2 | ||||
|      * @return | ||||
| @ -188,7 +195,6 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the negative of an ndarray | ||||
|      * | ||||
| @ -224,6 +230,7 @@ public class Transforms { | ||||
| 
 | ||||
|     /** | ||||
|      * Ceiling function | ||||
|      * | ||||
|      * @param ndArray | ||||
|      * @param copyOnOps | ||||
|      * @return | ||||
| @ -244,7 +251,6 @@ public class Transforms { | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param ndArray | ||||
|      * @param k | ||||
|      * @return | ||||
| @ -255,6 +261,7 @@ public class Transforms { | ||||
| 
 | ||||
|     /** | ||||
|      * Sin function | ||||
|      * | ||||
|      * @param in | ||||
|      * @return | ||||
|      */ | ||||
| @ -264,6 +271,7 @@ public class Transforms { | ||||
| 
 | ||||
|     /** | ||||
|      * Sin function | ||||
|      * | ||||
|      * @param in | ||||
|      * @param copy | ||||
|      * @return | ||||
| @ -275,6 +283,7 @@ public class Transforms { | ||||
| 
 | ||||
|     /** | ||||
|      * Sin function | ||||
|      * | ||||
|      * @param in | ||||
|      * @return | ||||
|      */ | ||||
| @ -284,6 +293,7 @@ public class Transforms { | ||||
| 
 | ||||
|     /** | ||||
|      * Sin function | ||||
|      * | ||||
|      * @param in | ||||
|      * @param copy | ||||
|      * @return | ||||
| @ -294,6 +304,7 @@ public class Transforms { | ||||
| 
 | ||||
|     /** | ||||
|      * Sinh function | ||||
|      * | ||||
|      * @param in | ||||
|      * @return | ||||
|      */ | ||||
| @ -303,6 +314,7 @@ public class Transforms { | ||||
| 
 | ||||
|     /** | ||||
|      * Sinh function | ||||
|      * | ||||
|      * @param in | ||||
|      * @param copy | ||||
|      * @return | ||||
| @ -312,7 +324,6 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param in | ||||
|      * @return | ||||
|      */ | ||||
| @ -321,7 +332,6 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param in | ||||
|      * @param copy | ||||
|      * @return | ||||
| @ -331,7 +341,6 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param in | ||||
|      * @return | ||||
|      */ | ||||
| @ -340,7 +349,6 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param in | ||||
|      * @param copy | ||||
|      * @return | ||||
| @ -435,7 +443,6 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     public static INDArray leakyRelu(INDArray arr, double cutoff) { | ||||
|         return leakyRelu(arr, cutoff, true); | ||||
|     } | ||||
| @ -455,7 +462,6 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     public static INDArray softPlus(INDArray arr) { | ||||
|         return softPlus(arr, true); | ||||
|     } | ||||
| @ -495,14 +501,12 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     public static INDArray softmax(INDArray arr) { | ||||
|         return softmax(arr, true); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param in | ||||
|      * @param copy | ||||
|      * @return | ||||
| @ -513,6 +517,7 @@ public class Transforms { | ||||
| 
 | ||||
|     /** | ||||
|      * out = in * (1-in) | ||||
|      * | ||||
|      * @param in   Input array | ||||
|      * @param copy If true: copy. False: apply in-place | ||||
|      * @return | ||||
| @ -534,6 +539,7 @@ public class Transforms { | ||||
| 
 | ||||
|     /** | ||||
|      * Run the exp operation | ||||
|      * | ||||
|      * @param ndArray | ||||
|      * @return | ||||
|      */ | ||||
| @ -580,9 +586,7 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param ndArray | ||||
|      * @return | ||||
|      */ | ||||
| @ -1018,7 +1022,6 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Sqrt function | ||||
|      * | ||||
| @ -1104,7 +1107,6 @@ public class Transforms { | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Apply the given elementwise op | ||||
|      * | ||||
| @ -1131,6 +1133,7 @@ public class Transforms { | ||||
|      * repeated squarings to minimize the number of mmul() operations needed | ||||
|      * <p>If <i>n</i> is zero, the identity matrix is returned.</p> | ||||
|      * <p>If <i>n</i> is negative, the matrix is inverted and raised to the abs(n) power.</p> | ||||
|      * | ||||
|      * @param in  A square matrix to raise to an integer power, which will be changed if dup is false. | ||||
|      * @param n   The integer power to raise the matrix to. | ||||
|      * @param dup If dup is true, the original input is unchanged. | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user