From 595656d01e9ec4be4efa03326b54facc27f549ec Mon Sep 17 00:00:00 2001 From: Robert Altena Date: Mon, 8 Jul 2019 15:35:21 +0900 Subject: [PATCH] fix #7947 (#7985) Signed-off-by: Robert Altena --- .../nd4j/linalg/api/ops/impl/reduce3/Dot.java | 30 ++++++++-- .../linalg/ops/transforms/Transforms.java | 55 ++++++++++--------- 2 files changed, 55 insertions(+), 30 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java index 7c00c7933..acfe9a0a3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java @@ -28,7 +28,8 @@ import java.util.Arrays; import java.util.List; /** - * Dot product + * Dot product. + * * @author Adam Gibson */ public class Dot extends BaseReduce3Op { @@ -37,21 +38,42 @@ public class Dot extends BaseReduce3Op { super(sameDiff, i_v, i_v2, dimensions); } - public Dot() {} + public Dot() { + } + /** + * Full array dot product reduction, optionally along specified dimensions.
+ * See wikipedia for details. + * + * @param x input variable. + * @param y input variable. + * @param z (optional) place holder for the result. Must have the expected shape. + * @param dimensions (optional) Dimensions to reduce over. If dimensions are not specified, full array reduction is performed. + * @see org.nd4j.linalg.ops.transforms.Transforms#dot Transforms.dot(...) for a wrapper around the common use case of 2 INDArrays. + */ public Dot(INDArray x, INDArray y, INDArray z, int... dimensions) { this(x, y, z, true, false, dimensions); } - public Dot(INDArray x, INDArray y, int... dimensions) { + + /** + * @see #Dot(INDArray x, INDArray y, INDArray z, int...) + */ + public Dot(INDArray x, INDArray y, int... dimensions) { this(x, y, null, dimensions); } + /** + * @see #Dot(INDArray x, INDArray y, INDArray z, int...) + */ public Dot(INDArray x, INDArray y, INDArray z) { this(x, y, z, null); } - public Dot(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){ + /** + * @see #Dot(INDArray x, INDArray y, INDArray z, int...) + */ + public Dot(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions) { super(x, y, z, keepDims, false, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java index 421376f94..43df4a430 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java @@ -56,7 +56,8 @@ import java.util.List; public class Transforms { - private Transforms() {} + private Transforms() { + } /** * Cosine similarity @@ -64,7 +65,6 @@ public class Transforms { * @param d1 the first vector * @param d2 the second vector * @return the cosine similarities between the 2 arrays - * */ public static double cosineSim(@NonNull INDArray d1, @NonNull INDArray d2) { return Nd4j.getExecutioner().exec(new CosineSimilarity(d1, d2)).getDouble(0); @@ -103,12 +103,21 @@ public class Transforms { return Nd4j.getExecutioner().exec(new OldReverse(x, dup ? x.ulike() : x)); } - public static INDArray dot(INDArray x, INDArray y){ - return Nd4j.getExecutioner().exec(new Dot(x,y)); + /** + * Dot product, new INDArray instance will be returned.
+ * Note that the Nd4J design is different from Numpy. Numpy dot on 2d arrays is matrix multiplication. Nd4J is + * full array dot product reduction. + * + * @param x the first vector + * @param y the second vector + * @return the dot product between the 2 arrays + */ + public static INDArray dot(INDArray x, INDArray y) { + return Nd4j.getExecutioner().exec(new Dot(x, y)); } - public static INDArray cross(INDArray x, INDArray y){ - Cross c = new Cross(x,y,null); + public static INDArray cross(INDArray x, INDArray y) { + Cross c = new Cross(x, y, null); List shape = c.calculateOutputShape(); INDArray out = Nd4j.create(shape.get(0)); c.addOutputArgument(out); @@ -117,7 +126,6 @@ public class Transforms { } /** - * * @param d1 * @param d2 * @return @@ -139,7 +147,6 @@ public class Transforms { } /** - * * @param d1 * @param d2 * @return @@ -188,7 +195,6 @@ public class Transforms { } - /** * Returns the negative of an ndarray * @@ -224,6 +230,7 @@ public class Transforms { /** * Ceiling function + * * @param ndArray * @param copyOnOps * @return @@ -244,7 +251,6 @@ public class Transforms { /** - * * @param ndArray * @param k * @return @@ -255,6 +261,7 @@ public class Transforms { /** * Sin function + * * @param in * @return */ @@ -264,6 +271,7 @@ public class Transforms { /** * Sin function + * * @param in * @param copy * @return @@ -275,6 +283,7 @@ public class Transforms { /** * Sin function + * * @param in * @return */ @@ -284,6 +293,7 @@ public class Transforms { /** * Sin function + * * @param in * @param copy * @return @@ -294,6 +304,7 @@ public class Transforms { /** * Sinh function + * * @param in * @return */ @@ -303,6 +314,7 @@ public class Transforms { /** * Sinh function + * * @param in * @param copy * @return @@ -312,7 +324,6 @@ public class Transforms { } /** - * * @param in * @return */ @@ -321,7 +332,6 @@ public class Transforms { } /** - * * @param in * @param copy * @return @@ -331,7 +341,6 @@ public class Transforms { } /** - * * @param in * @return */ @@ -340,7 +349,6 @@ public class Transforms { } /** - * * @param in * @param copy * @return @@ -435,7 +443,6 @@ public class Transforms { } - public static INDArray leakyRelu(INDArray arr, double cutoff) { return leakyRelu(arr, cutoff, true); } @@ -455,7 +462,6 @@ public class Transforms { } - public static INDArray softPlus(INDArray arr) { return softPlus(arr, true); } @@ -495,14 +501,12 @@ public class Transforms { } - public static INDArray softmax(INDArray arr) { return softmax(arr, true); } /** - * * @param in * @param copy * @return @@ -513,11 +517,12 @@ public class Transforms { /** * out = in * (1-in) + * * @param in Input array * @param copy If true: copy. False: apply in-place * @return */ - public static INDArray timesOneMinus(INDArray in, boolean copy){ + public static INDArray timesOneMinus(INDArray in, boolean copy) { return Nd4j.getExecutioner().exec(new TimesOneMinus(in, (copy ? in.ulike() : in))); } @@ -534,6 +539,7 @@ public class Transforms { /** * Run the exp operation + * * @param ndArray * @return */ @@ -558,7 +564,7 @@ public class Transforms { return exec(dup ? new HardTanh(ndArray, ndArray.ulike()) : new HardTanh(ndArray)); } - public static INDArray hardSigmoid(INDArray arr, boolean dup){ + public static INDArray hardSigmoid(INDArray arr, boolean dup) { return Nd4j.getExecutioner().exec(new HardSigmoid(arr, (dup ? arr.ulike() : arr))); } @@ -580,9 +586,7 @@ public class Transforms { } - /** - * * @param ndArray * @return */ @@ -1018,7 +1022,6 @@ public class Transforms { } - /** * Sqrt function * @@ -1104,7 +1107,6 @@ public class Transforms { } - /** * Apply the given elementwise op * @@ -1131,8 +1133,9 @@ public class Transforms { * repeated squarings to minimize the number of mmul() operations needed *

If n is zero, the identity matrix is returned.

*

If n is negative, the matrix is inverted and raised to the abs(n) power.

- * @param in A square matrix to raise to an integer power, which will be changed if dup is false. - * @param n The integer power to raise the matrix to. + * + * @param in A square matrix to raise to an integer power, which will be changed if dup is false. + * @param n The integer power to raise the matrix to. * @param dup If dup is true, the original input is unchanged. * @return The result of raising in to the nth power. */