Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-07-08 15:35:21 +09:00 committed by Alex Black
parent 66b7c3d6e3
commit 595656d01e
2 changed files with 55 additions and 30 deletions

View File

@ -28,7 +28,8 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
/** /**
* Dot product * Dot product.
*
* @author Adam Gibson * @author Adam Gibson
*/ */
public class Dot extends BaseReduce3Op { public class Dot extends BaseReduce3Op {
@ -37,21 +38,42 @@ public class Dot extends BaseReduce3Op {
super(sameDiff, i_v, i_v2, dimensions); super(sameDiff, i_v, i_v2, dimensions);
} }
public Dot() {} public Dot() {
}
/**
* Full array dot product reduction, optionally along specified dimensions.<br>
* See <a href="https://en.wikipedia.org/wiki/Dot_product">wikipedia</a> for details.
*
* @param x input variable.
* @param y input variable.
* @param z (optional) place holder for the result. Must have the expected shape.
* @param dimensions (optional) Dimensions to reduce over. If dimensions are not specified, full array reduction is performed.
* @see org.nd4j.linalg.ops.transforms.Transforms#dot Transforms.dot(...) for a wrapper around the common use case of 2 INDArrays.
*/
public Dot(INDArray x, INDArray y, INDArray z, int... dimensions) { public Dot(INDArray x, INDArray y, INDArray z, int... dimensions) {
this(x, y, z, true, false, dimensions); this(x, y, z, true, false, dimensions);
} }
public Dot(INDArray x, INDArray y, int... dimensions) {
/**
* @see #Dot(INDArray x, INDArray y, INDArray z, int...)
*/
public Dot(INDArray x, INDArray y, int... dimensions) {
this(x, y, null, dimensions); this(x, y, null, dimensions);
} }
/**
* @see #Dot(INDArray x, INDArray y, INDArray z, int...)
*/
public Dot(INDArray x, INDArray y, INDArray z) { public Dot(INDArray x, INDArray y, INDArray z) {
this(x, y, z, null); this(x, y, z, null);
} }
public Dot(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){ /**
* @see #Dot(INDArray x, INDArray y, INDArray z, int...)
*/
public Dot(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions) {
super(x, y, z, keepDims, false, dimensions); super(x, y, z, keepDims, false, dimensions);
} }

View File

@ -56,7 +56,8 @@ import java.util.List;
public class Transforms { public class Transforms {
private Transforms() {} private Transforms() {
}
/** /**
* Cosine similarity * Cosine similarity
@ -64,7 +65,6 @@ public class Transforms {
* @param d1 the first vector * @param d1 the first vector
* @param d2 the second vector * @param d2 the second vector
* @return the cosine similarities between the 2 arrays * @return the cosine similarities between the 2 arrays
*
*/ */
public static double cosineSim(@NonNull INDArray d1, @NonNull INDArray d2) { public static double cosineSim(@NonNull INDArray d1, @NonNull INDArray d2) {
return Nd4j.getExecutioner().exec(new CosineSimilarity(d1, d2)).getDouble(0); return Nd4j.getExecutioner().exec(new CosineSimilarity(d1, d2)).getDouble(0);
@ -103,12 +103,21 @@ public class Transforms {
return Nd4j.getExecutioner().exec(new OldReverse(x, dup ? x.ulike() : x)); return Nd4j.getExecutioner().exec(new OldReverse(x, dup ? x.ulike() : x));
} }
public static INDArray dot(INDArray x, INDArray y){ /**
return Nd4j.getExecutioner().exec(new Dot(x,y)); * Dot product, new INDArray instance will be returned.<br>
* Note that the Nd4J design is different from Numpy. Numpy dot on 2d arrays is matrix multiplication. Nd4J is
* full array dot product reduction.
*
* @param x the first vector
* @param y the second vector
* @return the dot product between the 2 arrays
*/
public static INDArray dot(INDArray x, INDArray y) {
return Nd4j.getExecutioner().exec(new Dot(x, y));
} }
public static INDArray cross(INDArray x, INDArray y){ public static INDArray cross(INDArray x, INDArray y) {
Cross c = new Cross(x,y,null); Cross c = new Cross(x, y, null);
List<LongShapeDescriptor> shape = c.calculateOutputShape(); List<LongShapeDescriptor> shape = c.calculateOutputShape();
INDArray out = Nd4j.create(shape.get(0)); INDArray out = Nd4j.create(shape.get(0));
c.addOutputArgument(out); c.addOutputArgument(out);
@ -117,7 +126,6 @@ public class Transforms {
} }
/** /**
*
* @param d1 * @param d1
* @param d2 * @param d2
* @return * @return
@ -139,7 +147,6 @@ public class Transforms {
} }
/** /**
*
* @param d1 * @param d1
* @param d2 * @param d2
* @return * @return
@ -188,7 +195,6 @@ public class Transforms {
} }
/** /**
* Returns the negative of an ndarray * Returns the negative of an ndarray
* *
@ -224,6 +230,7 @@ public class Transforms {
/** /**
* Ceiling function * Ceiling function
*
* @param ndArray * @param ndArray
* @param copyOnOps * @param copyOnOps
* @return * @return
@ -244,7 +251,6 @@ public class Transforms {
/** /**
*
* @param ndArray * @param ndArray
* @param k * @param k
* @return * @return
@ -255,6 +261,7 @@ public class Transforms {
/** /**
* Sin function * Sin function
*
* @param in * @param in
* @return * @return
*/ */
@ -264,6 +271,7 @@ public class Transforms {
/** /**
* Sin function * Sin function
*
* @param in * @param in
* @param copy * @param copy
* @return * @return
@ -275,6 +283,7 @@ public class Transforms {
/** /**
* Sin function * Sin function
*
* @param in * @param in
* @return * @return
*/ */
@ -284,6 +293,7 @@ public class Transforms {
/** /**
* Sin function * Sin function
*
* @param in * @param in
* @param copy * @param copy
* @return * @return
@ -294,6 +304,7 @@ public class Transforms {
/** /**
* Sinh function * Sinh function
*
* @param in * @param in
* @return * @return
*/ */
@ -303,6 +314,7 @@ public class Transforms {
/** /**
* Sinh function * Sinh function
*
* @param in * @param in
* @param copy * @param copy
* @return * @return
@ -312,7 +324,6 @@ public class Transforms {
} }
/** /**
*
* @param in * @param in
* @return * @return
*/ */
@ -321,7 +332,6 @@ public class Transforms {
} }
/** /**
*
* @param in * @param in
* @param copy * @param copy
* @return * @return
@ -331,7 +341,6 @@ public class Transforms {
} }
/** /**
*
* @param in * @param in
* @return * @return
*/ */
@ -340,7 +349,6 @@ public class Transforms {
} }
/** /**
*
* @param in * @param in
* @param copy * @param copy
* @return * @return
@ -435,7 +443,6 @@ public class Transforms {
} }
public static INDArray leakyRelu(INDArray arr, double cutoff) { public static INDArray leakyRelu(INDArray arr, double cutoff) {
return leakyRelu(arr, cutoff, true); return leakyRelu(arr, cutoff, true);
} }
@ -455,7 +462,6 @@ public class Transforms {
} }
public static INDArray softPlus(INDArray arr) { public static INDArray softPlus(INDArray arr) {
return softPlus(arr, true); return softPlus(arr, true);
} }
@ -495,14 +501,12 @@ public class Transforms {
} }
public static INDArray softmax(INDArray arr) { public static INDArray softmax(INDArray arr) {
return softmax(arr, true); return softmax(arr, true);
} }
/** /**
*
* @param in * @param in
* @param copy * @param copy
* @return * @return
@ -513,11 +517,12 @@ public class Transforms {
/** /**
* out = in * (1-in) * out = in * (1-in)
*
* @param in Input array * @param in Input array
* @param copy If true: copy. False: apply in-place * @param copy If true: copy. False: apply in-place
* @return * @return
*/ */
public static INDArray timesOneMinus(INDArray in, boolean copy){ public static INDArray timesOneMinus(INDArray in, boolean copy) {
return Nd4j.getExecutioner().exec(new TimesOneMinus(in, (copy ? in.ulike() : in))); return Nd4j.getExecutioner().exec(new TimesOneMinus(in, (copy ? in.ulike() : in)));
} }
@ -534,6 +539,7 @@ public class Transforms {
/** /**
* Run the exp operation * Run the exp operation
*
* @param ndArray * @param ndArray
* @return * @return
*/ */
@ -558,7 +564,7 @@ public class Transforms {
return exec(dup ? new HardTanh(ndArray, ndArray.ulike()) : new HardTanh(ndArray)); return exec(dup ? new HardTanh(ndArray, ndArray.ulike()) : new HardTanh(ndArray));
} }
public static INDArray hardSigmoid(INDArray arr, boolean dup){ public static INDArray hardSigmoid(INDArray arr, boolean dup) {
return Nd4j.getExecutioner().exec(new HardSigmoid(arr, (dup ? arr.ulike() : arr))); return Nd4j.getExecutioner().exec(new HardSigmoid(arr, (dup ? arr.ulike() : arr)));
} }
@ -580,9 +586,7 @@ public class Transforms {
} }
/** /**
*
* @param ndArray * @param ndArray
* @return * @return
*/ */
@ -1018,7 +1022,6 @@ public class Transforms {
} }
/** /**
* Sqrt function * Sqrt function
* *
@ -1104,7 +1107,6 @@ public class Transforms {
} }
/** /**
* Apply the given elementwise op * Apply the given elementwise op
* *
@ -1131,8 +1133,9 @@ public class Transforms {
* repeated squarings to minimize the number of mmul() operations needed * repeated squarings to minimize the number of mmul() operations needed
* <p>If <i>n</i> is zero, the identity matrix is returned.</p> * <p>If <i>n</i> is zero, the identity matrix is returned.</p>
* <p>If <i>n</i> is negative, the matrix is inverted and raised to the abs(n) power.</p> * <p>If <i>n</i> is negative, the matrix is inverted and raised to the abs(n) power.</p>
* @param in A square matrix to raise to an integer power, which will be changed if dup is false. *
* @param n The integer power to raise the matrix to. * @param in A square matrix to raise to an integer power, which will be changed if dup is false.
* @param n The integer power to raise the matrix to.
* @param dup If dup is true, the original input is unchanged. * @param dup If dup is true, the original input is unchanged.
* @return The result of raising <i>in</i> to the <i>n</i>th power. * @return The result of raising <i>in</i> to the <i>n</i>th power.
*/ */