Fix javadoc and cleanup

master
Brian Rosenberger 2022-10-21 15:19:32 +02:00
parent 66ed10a5e3
commit 1c2ca75308
67 changed files with 9896 additions and 8781 deletions

View File

@ -56,6 +56,7 @@ configurations.all {
} }
allprojects { Project proj -> allprojects { Project proj ->
apply plugin: 'com.google.osdetector' apply plugin: 'com.google.osdetector'
@ -162,3 +163,21 @@ allprojects { Project proj ->
} }
} }
} }
task aggregatedJavadocs(type: Javadoc, description: 'Generate javadocs from all child projects as if it was a single project', group: 'Documentation') {
subprojects.each { proj ->
proj.tasks.withType(Javadoc).each { javadocTask ->
logger.quiet("Adding javadoc for project " + proj.name)
source += javadocTask.source
classpath += javadocTask.classpath
excludes += javadocTask.excludes
includes += javadocTask.includes
}
}
destinationDir = file("$buildDir/docs/javadoc")
title = "$project.name $version API"
options.author true
options.links 'http://docs.oracle.com/javase/8/docs/api/'
options.addStringOption('Xdoclint:none', '-quiet')
}

View File

@ -142,3 +142,7 @@ groupId:artifactId:packaging:classifier:version
In your case it should work with In your case it should work with
edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0 edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0
Native cpu code under linux needs libc6-dev
/lib/x86_64-linux-gnu/libm.so.6: version `GLIBC_2.29' not found

View File

@ -90,7 +90,7 @@ public class Operands {
/** /**
* This method returns array identified its numeric id * This method returns array identified its numeric id
* @param name * @param id
* @return * @return
*/ */
public INDArray getById(int id) { public INDArray getById(int id) {
@ -99,7 +99,8 @@ public class Operands {
/** /**
* This method returns array identified its numeric id and index * This method returns array identified its numeric id and index
* @param name * @param id
* @param index
* @return * @return
*/ */
public INDArray getById(int id, int index) { public INDArray getById(int id, int index) {
@ -121,7 +122,7 @@ public class Operands {
} }
/** /**
* This method returns contents of this entity as collection of key->value pairs * This method returns contents of this entity as collection of key->value pairs
* @return * @return
*/ */
public Collection<Pair<NodeDescriptor, INDArray>> asCollection() { public Collection<Pair<NodeDescriptor, INDArray>> asCollection() {

View File

@ -50,7 +50,7 @@ public class ExecDebuggingListener extends BaseListener {
/** /**
* @param printMode Print mode, see {@link PrintMode} * @param printMode Print mode, see {@link PrintMode}
* @param maxIterations Maximum number of iterations to print. <= 0 for "all iterations" * @param maxIterations Maximum number of iterations to print. &lt;= 0 for "all iterations"
* @param logIter If true: prefix iteration/epoch, such as "(iter=1,epoch=0,op=3)" to the output * @param logIter If true: prefix iteration/epoch, such as "(iter=1,epoch=0,op=3)" to the output
*/ */
public ExecDebuggingListener(PrintMode printMode, int maxIterations, boolean logIter){ public ExecDebuggingListener(PrintMode printMode, int maxIterations, boolean logIter){

View File

@ -573,7 +573,7 @@ public class SameDiff extends SDBaseOps {
} }
/** /**
* Get the function by the {@link DifferentialFunction#getOwnName()} * Get the function by the {@link org.nd4j.autodiff.functions.DifferentialFunction#getOwnName()}
* *
* @param id the id of the function * @param id the id of the function
* @return the function for the given id if it exists * @return the function for the given id if it exists
@ -1348,9 +1348,9 @@ public class SameDiff extends SDBaseOps {
/** /**
* Get the names of variables (if any) that have been marked as loss variables to be minimized.<br> * Get the names of variables (if any) that have been marked as loss variables to be minimized.<br>
* Variables can be marked as loss variables in a few different ways:<br> * Variables can be marked as loss variables in a few different ways:<br>
* (a) Losses are automatically added when creating loss functions via {@link #sd()}<br> * (a) Losses are automatically added when creating loss functions via {@link SameDiff#sd}<br>
* (b) Via {@link #setLossVariables(String...)}, @link #addLossVariable(String)} or {@link SDVariable#markAsLoss()}<br> * (b) Via {@link #setLossVariables(String...)}, @link #addLossVariable(String)} or {@link SDVariable#markAsLoss()}<br>
* (c) Via {@link TrainingConfig#setLossVariables(List)}<br> * (c) Via {@link org.nd4j.autodiff.samediff.TrainingConfig#setLossVariables(List)}<br>
*/ */
public List<String> getLossVariables() { public List<String> getLossVariables() {
return Collections.unmodifiableList(this.lossVariables); return Collections.unmodifiableList(this.lossVariables);

View File

@ -54,12 +54,12 @@ public abstract class AbstractDependencyTracker<T, D> {
} }
/** /**
* @return A new map where the dependents (i.e., Y in "X -> Y") are the key * @return A new map where the dependents (i.e., Y in "X -&gt; Y") are the key
*/ */
protected abstract Map<T, ?> newTMap(); protected abstract Map<T, ?> newTMap();
/** /**
* @return A new set where the dependents (i.e., Y in "X -> Y") are the key * @return A new set where the dependents (i.e., Y in "X -&gt; Y") are the key
*/ */
protected abstract Set<T> newTSet(); protected abstract Set<T> newTSet();
@ -103,7 +103,7 @@ public abstract class AbstractDependencyTracker<T, D> {
/** /**
* Mark the specified value as satisfied. * Mark the specified value as satisfied.
* For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true) * For example, if two dependencies have been previously added (X -&gt; Y) and (X -&gt; A) then after the markSatisfied(X, true)
* call, both of these dependencies are considered satisfied. * call, both of these dependencies are considered satisfied.
* *
* @param x Value to mark * @param x Value to mark
@ -191,7 +191,7 @@ public abstract class AbstractDependencyTracker<T, D> {
} }
/** /**
* Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)} * Check whether any dependencies x -&gt; y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)}
* or {@link #addOrDependency(Object, Object, Object)} * or {@link #addOrDependency(Object, Object, Object)}
* *
* @param y Dependent to check * @param y Dependent to check
@ -207,7 +207,7 @@ public abstract class AbstractDependencyTracker<T, D> {
} }
/** /**
* Get all dependencies x, for x -> y, and (x1 or x2) -> y * Get all dependencies x, for x -&gt; y, and (x1 or x2) -&gt; y
* *
* @param y Dependent to get dependencies for * @param y Dependent to get dependencies for
* @return List of dependencies * @return List of dependencies
@ -223,7 +223,7 @@ public abstract class AbstractDependencyTracker<T, D> {
} }
/** /**
* Add a dependency: y depends on x, as in x -> y * Add a dependency: y depends on x, as in x -&gt; y
* *
* @param y The dependent * @param y The dependent
* @param x The dependee that is required for Y * @param x The dependee that is required for Y
@ -302,7 +302,7 @@ public abstract class AbstractDependencyTracker<T, D> {
/** /**
* Remove a dependency (x -> y) * Remove a dependency (x -&gt; y)
* *
* @param y The dependent that currently requires X * @param y The dependent that currently requires X
* @param x The dependee that is no longer required for Y * @param x The dependee that is no longer required for Y
@ -357,7 +357,7 @@ public abstract class AbstractDependencyTracker<T, D> {
} }
/** /**
* Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y<br> * Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -&gt; Y<br>
* If either x1 or x2 (or both) are marked satisfied via {@link #markSatisfied(Object, boolean)} then the * If either x1 or x2 (or both) are marked satisfied via {@link #markSatisfied(Object, boolean)} then the
* dependency is considered satisfied * dependency is considered satisfied
* *
@ -382,16 +382,16 @@ public abstract class AbstractDependencyTracker<T, D> {
} }
/** /**
* @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y) * @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X-&gt;Y)
*/ */
public boolean hasNewAllSatisfied() { public boolean hasNewAllSatisfied() {
return !allSatisfiedQueue.isEmpty(); return !allSatisfiedQueue.isEmpty();
} }
/** /**
* Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)} * Returns the next new dependent (Y in X-&gt;Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)}
* Throws an exception if {@link #hasNewAllSatisfied()} returns false.<br> * Throws an exception if {@link #hasNewAllSatisfied()} returns false.<br>
* Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value; * Note that once a value has been retrieved from here, no new dependencies of the form (X -&gt; Y) can be added for this value;
* the value is considered "processed" at this point. * the value is considered "processed" at this point.
* *
* @return The next new "all satisfied dependent" * @return The next new "all satisfied dependent"

View File

@ -487,7 +487,7 @@ public abstract class AbstractSession<T, O> {
} }
/** /**
* Add the control dependency from Op -> variable * Add the control dependency from Op -&gt; variable
* *
* @param es Execution step for the variable * @param es Execution step for the variable
* @param v Variable * @param v Variable
@ -542,7 +542,7 @@ public abstract class AbstractSession<T, O> {
/** /**
* Update the descendant dependencies * Update the descendant dependencies
* So if the graph structure is X -> A, then add all (X,Y,Z,...) -> A to the dependency tracker * So if the graph structure is X -&gt; A, then add all (X,Y,Z,...) -&gt; A to the dependency tracker
* This is for a specific frame and iteration, for both sides of the dependency (in and out) * This is for a specific frame and iteration, for both sides of the dependency (in and out)
* *
* @param justExecuted The execution step that has just completed * @param justExecuted The execution step that has just completed
@ -621,7 +621,7 @@ public abstract class AbstractSession<T, O> {
/** /**
* Suppose operation X has just been executed. * Suppose operation X has just been executed.
* For X -> someOp, add all dependencies for someOp, i.e., all Z -> someOp * For X -&gt; someOp, add all dependencies for someOp, i.e., all Z -&gt; someOp
* (which includes X, but may not only be X) * (which includes X, but may not only be X)
* *
* @param opName Name of the op * @param opName Name of the op

View File

@ -28,15 +28,15 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
public class SDBitwise extends SDOps { public class SDBitwise extends SDOps {
public SDBitwise(SameDiff sameDiff) { public SDBitwise(SameDiff sameDiff) {
super(sameDiff); super(sameDiff);
} }
/** /**
* Bitwise AND operation. Supports broadcasting.<br> * Bitwise AND operation. Supports broadcasting.<br>
* * <p>
* Inputs must satisfy the following constraints: <br> * Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
* Must be same types: isSameType(x, y)<br>
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br> * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
* *
* @param x First input array (INT type) * @param x First input array (INT type)
@ -47,147 +47,155 @@ public class SDBitwise extends SDOps {
SDValidation.validateInteger("and", "x", x); SDValidation.validateInteger("and", "x", x);
SDValidation.validateInteger("and", "y", y); SDValidation.validateInteger("and", "y", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types"); Preconditions.checkArgument(isSameType(x, y), "Must be same types");
return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd, x, y).outputVariable();
} }
/** /**
* Bitwise AND operation. Supports broadcasting.<br> * Bitwise AND operation. Supports broadcasting.<br>
* * <p>
* Inputs must satisfy the following constraints: <br> * Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
* Must be same types: isSameType(x, y)<br>
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br> * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param x First input array (INT type) * @param x First input array (INT type)
* @param y Second input array (INT type) * @param y Second input array (INT type)
* @return output Bitwise AND array (INT type) * @return output Bitwise AND array (INT type)
*/ */
public SDVariable and(String name, SDVariable x, SDVariable y) { public SDVariable and(String name, SDVariable x, SDVariable y) {
SDValidation.validateInteger("and", "x", x); SDValidation.validateInteger("and", "x", x);
SDValidation.validateInteger("and", "y", y); SDValidation.validateInteger("and", "y", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types"); Preconditions.checkArgument(isSameType(x, y), "Must be same types");
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable(); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd, x,
y).outputVariable();
return sd.updateVariableNameAndReference(out, name); return sd.updateVariableNameAndReference(out, name);
} }
/** /**
* Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)<br> * Roll integer bits to the left, i.e. {@code var << 4 | var >> (32 - 4)}<br>
* *
* @param x Input 1 (INT type) * @param x Input 1 (INT type)
* @param shift Number of bits to shift. (INT type) * @param shift Number of bits to shift. (INT type)
* @return output SDVariable with shifted bits (INT type) * @return output SDVariable with shifted bits (INT type)
*/ */
public SDVariable bitRotl(SDVariable x, SDVariable shift) { public SDVariable bitRotl(SDVariable x, SDVariable shift) {
SDValidation.validateInteger("bitRotl", "x", x); SDValidation.validateInteger("bitRotl", "x", x);
SDValidation.validateInteger("bitRotl", "shift", shift); SDValidation.validateInteger("bitRotl", "shift", shift);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x,
shift).outputVariable();
} }
/** /**
* Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)<br> * Roll integer bits to the left, i.e. {@code var << 4 | var >> (32 - 4)}<br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param x Input 1 (INT type) * @param x Input 1 (INT type)
* @param shift Number of bits to shift. (INT type) * @param shift Number of bits to shift. (INT type)
* @return output SDVariable with shifted bits (INT type) * @return output SDVariable with shifted bits (INT type)
*/ */
public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) { public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) {
SDValidation.validateInteger("bitRotl", "x", x); SDValidation.validateInteger("bitRotl", "x", x);
SDValidation.validateInteger("bitRotl", "shift", shift); SDValidation.validateInteger("bitRotl", "shift", shift);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x,
shift).outputVariable();
return sd.updateVariableNameAndReference(out, name); return sd.updateVariableNameAndReference(out, name);
} }
/** /**
* Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)<br> * Roll integer bits to the right, i.e. {@code var >> 4 | var << (32 - 4)}<br>
* *
* @param x Input 1 (INT type) * @param x Input 1 (INT type)
* @param shift Number of bits to shift. (INT type) * @param shift Number of bits to shift. (INT type)
* @return output SDVariable with shifted bits (INT type) * @return output SDVariable with shifted bits (INT type)
*/ */
public SDVariable bitRotr(SDVariable x, SDVariable shift) { public SDVariable bitRotr(SDVariable x, SDVariable shift) {
SDValidation.validateInteger("bitRotr", "x", x); SDValidation.validateInteger("bitRotr", "x", x);
SDValidation.validateInteger("bitRotr", "shift", shift); SDValidation.validateInteger("bitRotr", "shift", shift);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x,
shift).outputVariable();
} }
/** /**
* Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)<br> * Roll integer bits to the right, i.e. {@code var >> 4 | var << (32 - 4)}<br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param x Input 1 (INT type) * @param x Input 1 (INT type)
* @param shift Number of bits to shift. (INT type) * @param shift Number of bits to shift. (INT type)
* @return output SDVariable with shifted bits (INT type) * @return output SDVariable with shifted bits (INT type)
*/ */
public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) { public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) {
SDValidation.validateInteger("bitRotr", "x", x); SDValidation.validateInteger("bitRotr", "x", x);
SDValidation.validateInteger("bitRotr", "shift", shift); SDValidation.validateInteger("bitRotr", "shift", shift);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x,
shift).outputVariable();
return sd.updateVariableNameAndReference(out, name); return sd.updateVariableNameAndReference(out, name);
} }
/** /**
* Shift integer bits to the left, i.e. var << 4<br> * Shift integer bits to the left, i.e. {@code var << 4}<br>
* *
* @param x Input 1 (INT type) * @param x Input 1 (INT type)
* @param shift Number of bits to shift. (INT type) * @param shift Number of bits to shift. (INT type)
* @return output SDVariable with shifted bits (INT type) * @return output SDVariable with shifted bits (INT type)
*/ */
public SDVariable bitShift(SDVariable x, SDVariable shift) { public SDVariable bitShift(SDVariable x, SDVariable shift) {
SDValidation.validateInteger("bitShift", "x", x); SDValidation.validateInteger("bitShift", "x", x);
SDValidation.validateInteger("bitShift", "shift", shift); SDValidation.validateInteger("bitShift", "shift", shift);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x,
shift).outputVariable();
} }
/** /**
* Shift integer bits to the left, i.e. var << 4<br> * Shift integer bits to the left, i.e. {@code var << 4}<br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param x Input 1 (INT type) * @param x Input 1 (INT type)
* @param shift Number of bits to shift. (INT type) * @param shift Number of bits to shift. (INT type)
* @return output SDVariable with shifted bits (INT type) * @return output SDVariable with shifted bits (INT type)
*/ */
public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { public SDVariable bitShift(String name, SDVariable x, SDVariable shift) {
SDValidation.validateInteger("bitShift", "x", x); SDValidation.validateInteger("bitShift", "x", x);
SDValidation.validateInteger("bitShift", "shift", shift); SDValidation.validateInteger("bitShift", "shift", shift);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x,
shift).outputVariable();
return sd.updateVariableNameAndReference(out, name); return sd.updateVariableNameAndReference(out, name);
} }
/** /**
* Shift integer bits to the right, i.e. var >> 4<br> * Shift integer bits to the right, i.e. {@code var >> 4}<br>
* *
* @param x Input 1 (INT type) * @param x Input 1 (INT type)
* @param shift Number of bits to shift. (INT type) * @param shift Number of bits to shift. (INT type)
* @return output SDVariable with shifted bits (INT type) * @return output SDVariable with shifted bits (INT type)
*/ */
public SDVariable bitShiftRight(SDVariable x, SDVariable shift) { public SDVariable bitShiftRight(SDVariable x, SDVariable shift) {
SDValidation.validateInteger("bitShiftRight", "x", x); SDValidation.validateInteger("bitShiftRight", "x", x);
SDValidation.validateInteger("bitShiftRight", "shift", shift); SDValidation.validateInteger("bitShiftRight", "shift", shift);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x,
shift).outputVariable();
} }
/** /**
* Shift integer bits to the right, i.e. var >> 4<br> * Shift integer bits to the right, i.e. {@code var >> 4}<br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param x Input 1 (INT type) * @param x Input 1 (INT type)
* @param shift Number of bits to shift. (INT type) * @param shift Number of bits to shift. (INT type)
* @return output SDVariable with shifted bits (INT type) * @return output SDVariable with shifted bits (INT type)
*/ */
public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) {
SDValidation.validateInteger("bitShiftRight", "x", x); SDValidation.validateInteger("bitShiftRight", "x", x);
SDValidation.validateInteger("bitShiftRight", "shift", shift); SDValidation.validateInteger("bitShiftRight", "shift", shift);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x,
shift).outputVariable();
return sd.updateVariableNameAndReference(out, name); return sd.updateVariableNameAndReference(out, name);
} }
/** /**
* Bitwise Hamming distance reduction over all elements of both input arrays.<br> * Bitwise Hamming distance reduction over all elements of both input arrays.<br> For example, if
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)<br> * x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at
* * positions 0 and 1)<br>
* Inputs must satisfy the following constraints: <br> * <p>
* Must be same types: isSameType(x, y)<br> * Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
* *
* @param x First input array. (INT type) * @param x First input array. (INT type)
* @param y Second input array. (INT type) * @param y Second input array. (INT type)
@ -197,26 +205,28 @@ public class SDBitwise extends SDOps {
SDValidation.validateInteger("bitsHammingDistance", "x", x); SDValidation.validateInteger("bitsHammingDistance", "x", x);
SDValidation.validateInteger("bitsHammingDistance", "y", y); SDValidation.validateInteger("bitsHammingDistance", "y", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types"); Preconditions.checkArgument(isSameType(x, y), "Must be same types");
return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd, x,
y).outputVariable();
} }
/** /**
* Bitwise Hamming distance reduction over all elements of both input arrays.<br> * Bitwise Hamming distance reduction over all elements of both input arrays.<br> For example, if
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)<br> * x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at
* * positions 0 and 1)<br>
* Inputs must satisfy the following constraints: <br> * <p>
* Must be same types: isSameType(x, y)<br> * Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param x First input array. (INT type) * @param x First input array. (INT type)
* @param y Second input array. (INT type) * @param y Second input array. (INT type)
* @return output bitwise Hamming distance (INT type) * @return output bitwise Hamming distance (INT type)
*/ */
public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y) { public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y) {
SDValidation.validateInteger("bitsHammingDistance", "x", x); SDValidation.validateInteger("bitsHammingDistance", "x", x);
SDValidation.validateInteger("bitsHammingDistance", "y", y); SDValidation.validateInteger("bitsHammingDistance", "y", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types"); Preconditions.checkArgument(isSameType(x, y), "Must be same types");
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable(); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd, x,
y).outputVariable();
return sd.updateVariableNameAndReference(out, name); return sd.updateVariableNameAndReference(out, name);
} }
@ -230,27 +240,28 @@ public class SDBitwise extends SDOps {
public SDVariable leftShift(SDVariable x, SDVariable y) { public SDVariable leftShift(SDVariable x, SDVariable y) {
SDValidation.validateInteger("leftShift", "x", x); SDValidation.validateInteger("leftShift", "x", x);
SDValidation.validateInteger("leftShift", "y", y); SDValidation.validateInteger("leftShift", "y", y);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, y).outputVariable();
} }
/** /**
* Bitwise left shift operation. Supports broadcasting.<br> * Bitwise left shift operation. Supports broadcasting.<br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param x Input to be bit shifted (INT type) * @param x Input to be bit shifted (INT type)
* @param y Amount to shift elements of x array (INT type) * @param y Amount to shift elements of x array (INT type)
* @return output Bitwise shifted input x (INT type) * @return output Bitwise shifted input x (INT type)
*/ */
public SDVariable leftShift(String name, SDVariable x, SDVariable y) { public SDVariable leftShift(String name, SDVariable x, SDVariable y) {
SDValidation.validateInteger("leftShift", "x", x); SDValidation.validateInteger("leftShift", "x", x);
SDValidation.validateInteger("leftShift", "y", y); SDValidation.validateInteger("leftShift", "y", y);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable(); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x,
y).outputVariable();
return sd.updateVariableNameAndReference(out, name); return sd.updateVariableNameAndReference(out, name);
} }
/** /**
* Bitwise left cyclical shift operation. Supports broadcasting.<br> * Bitwise left cyclical shift operation. Supports broadcasting.<br> Unlike
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br> * {@link SDBitwise#leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br> * {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
* *
* @param x Input to be bit shifted (INT type) * @param x Input to be bit shifted (INT type)
@ -260,31 +271,32 @@ public class SDBitwise extends SDOps {
public SDVariable leftShiftCyclic(SDVariable x, SDVariable y) { public SDVariable leftShiftCyclic(SDVariable x, SDVariable y) {
SDValidation.validateInteger("leftShiftCyclic", "x", x); SDValidation.validateInteger("leftShiftCyclic", "x", x);
SDValidation.validateInteger("leftShiftCyclic", "y", y); SDValidation.validateInteger("leftShiftCyclic", "y", y);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x,
y).outputVariable();
} }
/** /**
* Bitwise left cyclical shift operation. Supports broadcasting.<br> * Bitwise left cyclical shift operation. Supports broadcasting.<br> Unlike
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br> * {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br> * {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param x Input to be bit shifted (INT type) * @param x Input to be bit shifted (INT type)
* @param y Amount to shift elements of x array (INT type) * @param y Amount to shift elements of x array (INT type)
* @return output Bitwise cyclic shifted input x (INT type) * @return output Bitwise cyclic shifted input x (INT type)
*/ */
public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y) { public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y) {
SDValidation.validateInteger("leftShiftCyclic", "x", x); SDValidation.validateInteger("leftShiftCyclic", "x", x);
SDValidation.validateInteger("leftShiftCyclic", "y", y); SDValidation.validateInteger("leftShiftCyclic", "y", y);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable(); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x,
y).outputVariable();
return sd.updateVariableNameAndReference(out, name); return sd.updateVariableNameAndReference(out, name);
} }
/** /**
* Bitwise OR operation. Supports broadcasting.<br> * Bitwise OR operation. Supports broadcasting.<br>
* * <p>
* Inputs must satisfy the following constraints: <br> * Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
* Must be same types: isSameType(x, y)<br>
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br> * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
* *
* @param x First input array (INT type) * @param x First input array (INT type)
@ -295,26 +307,26 @@ public class SDBitwise extends SDOps {
SDValidation.validateInteger("or", "x", x); SDValidation.validateInteger("or", "x", x);
SDValidation.validateInteger("or", "y", y); SDValidation.validateInteger("or", "y", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types"); Preconditions.checkArgument(isSameType(x, y), "Must be same types");
return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd, x, y).outputVariable();
} }
/** /**
* Bitwise OR operation. Supports broadcasting.<br> * Bitwise OR operation. Supports broadcasting.<br>
* * <p>
* Inputs must satisfy the following constraints: <br> * Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
* Must be same types: isSameType(x, y)<br>
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br> * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param x First input array (INT type) * @param x First input array (INT type)
* @param y First input array (INT type) * @param y First input array (INT type)
* @return output Bitwise OR array (INT type) * @return output Bitwise OR array (INT type)
*/ */
public SDVariable or(String name, SDVariable x, SDVariable y) { public SDVariable or(String name, SDVariable x, SDVariable y) {
SDValidation.validateInteger("or", "x", x); SDValidation.validateInteger("or", "x", x);
SDValidation.validateInteger("or", "y", y); SDValidation.validateInteger("or", "y", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types"); Preconditions.checkArgument(isSameType(x, y), "Must be same types");
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable(); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd, x,
y).outputVariable();
return sd.updateVariableNameAndReference(out, name); return sd.updateVariableNameAndReference(out, name);
} }
@ -328,27 +340,28 @@ public class SDBitwise extends SDOps {
public SDVariable rightShift(SDVariable x, SDVariable y) { public SDVariable rightShift(SDVariable x, SDVariable y) {
SDValidation.validateInteger("rightShift", "x", x); SDValidation.validateInteger("rightShift", "x", x);
SDValidation.validateInteger("rightShift", "y", y); SDValidation.validateInteger("rightShift", "y", y);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, y).outputVariable();
} }
/** /**
* Bitwise right shift operation. Supports broadcasting. <br> * Bitwise right shift operation. Supports broadcasting. <br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param x Input to be bit shifted (INT type) * @param x Input to be bit shifted (INT type)
* @param y Amount to shift elements of x array (INT type) * @param y Amount to shift elements of x array (INT type)
* @return output Bitwise shifted input x (INT type) * @return output Bitwise shifted input x (INT type)
*/ */
public SDVariable rightShift(String name, SDVariable x, SDVariable y) { public SDVariable rightShift(String name, SDVariable x, SDVariable y) {
SDValidation.validateInteger("rightShift", "x", x); SDValidation.validateInteger("rightShift", "x", x);
SDValidation.validateInteger("rightShift", "y", y); SDValidation.validateInteger("rightShift", "y", y);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable(); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x,
y).outputVariable();
return sd.updateVariableNameAndReference(out, name); return sd.updateVariableNameAndReference(out, name);
} }
/** /**
* Bitwise right cyclical shift operation. Supports broadcasting.<br> * Bitwise right cyclical shift operation. Supports broadcasting.<br> Unlike
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br> * {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br> * {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
* *
* @param x Input to be bit shifted (INT type) * @param x Input to be bit shifted (INT type)
@ -358,31 +371,32 @@ public class SDBitwise extends SDOps {
public SDVariable rightShiftCyclic(SDVariable x, SDVariable y) { public SDVariable rightShiftCyclic(SDVariable x, SDVariable y) {
SDValidation.validateInteger("rightShiftCyclic", "x", x); SDValidation.validateInteger("rightShiftCyclic", "x", x);
SDValidation.validateInteger("rightShiftCyclic", "y", y); SDValidation.validateInteger("rightShiftCyclic", "y", y);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x,
y).outputVariable();
} }
/** /**
* Bitwise right cyclical shift operation. Supports broadcasting.<br> * Bitwise right cyclical shift operation. Supports broadcasting.<br> Unlike
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br> * {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br> * {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param x Input to be bit shifted (INT type) * @param x Input to be bit shifted (INT type)
* @param y Amount to shift elements of x array (INT type) * @param y Amount to shift elements of x array (INT type)
* @return output Bitwise cyclic shifted input x (INT type) * @return output Bitwise cyclic shifted input x (INT type)
*/ */
public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y) { public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y) {
SDValidation.validateInteger("rightShiftCyclic", "x", x); SDValidation.validateInteger("rightShiftCyclic", "x", x);
SDValidation.validateInteger("rightShiftCyclic", "y", y); SDValidation.validateInteger("rightShiftCyclic", "y", y);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable(); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x,
y).outputVariable();
return sd.updateVariableNameAndReference(out, name); return sd.updateVariableNameAndReference(out, name);
} }
/** /**
* Bitwise XOR operation (exclusive OR). Supports broadcasting.<br> * Bitwise XOR operation (exclusive OR). Supports broadcasting.<br>
* * <p>
* Inputs must satisfy the following constraints: <br> * Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
* Must be same types: isSameType(x, y)<br>
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br> * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
* *
* @param x First input array (INT type) * @param x First input array (INT type)
@ -393,26 +407,26 @@ public class SDBitwise extends SDOps {
SDValidation.validateInteger("xor", "x", x); SDValidation.validateInteger("xor", "x", x);
SDValidation.validateInteger("xor", "y", y); SDValidation.validateInteger("xor", "y", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types"); Preconditions.checkArgument(isSameType(x, y), "Must be same types");
return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd, x, y).outputVariable();
} }
/** /**
* Bitwise XOR operation (exclusive OR). Supports broadcasting.<br> * Bitwise XOR operation (exclusive OR). Supports broadcasting.<br>
* * <p>
* Inputs must satisfy the following constraints: <br> * Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
* Must be same types: isSameType(x, y)<br>
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br> * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param x First input array (INT type) * @param x First input array (INT type)
* @param y First input array (INT type) * @param y First input array (INT type)
* @return output Bitwise XOR array (INT type) * @return output Bitwise XOR array (INT type)
*/ */
public SDVariable xor(String name, SDVariable x, SDVariable y) { public SDVariable xor(String name, SDVariable x, SDVariable y) {
SDValidation.validateInteger("xor", "x", x); SDValidation.validateInteger("xor", "x", x);
SDValidation.validateInteger("xor", "y", y); SDValidation.validateInteger("xor", "y", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types"); Preconditions.checkArgument(isSameType(x, y), "Must be same types");
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable(); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd, x,
y).outputVariable();
return sd.updateVariableNameAndReference(out, name); return sd.updateVariableNameAndReference(out, name);
} }
} }

View File

@ -355,7 +355,8 @@ public class SDImage extends SDOps {
* @param maxOutSize scalar representing the maximum number of boxes to be selected * @param maxOutSize scalar representing the maximum number of boxes to be selected
* @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU * @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
* @param scoreThreshold threshold for deciding when to remove boxes based on score * @param scoreThreshold threshold for deciding when to remove boxes based on score
* @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type) * @return output vectort of shape [M] representing the selected indices from the boxes tensor,
* where M &lt;= max_output_size (NUMERIC type)
*/ */
public SDVariable nonMaxSuppression(SDVariable boxes, SDVariable scores, int maxOutSize, public SDVariable nonMaxSuppression(SDVariable boxes, SDVariable scores, int maxOutSize,
double iouThreshold, double scoreThreshold) { double iouThreshold, double scoreThreshold) {
@ -373,7 +374,8 @@ public class SDImage extends SDOps {
* @param maxOutSize scalar representing the maximum number of boxes to be selected * @param maxOutSize scalar representing the maximum number of boxes to be selected
* @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU * @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
* @param scoreThreshold threshold for deciding when to remove boxes based on score * @param scoreThreshold threshold for deciding when to remove boxes based on score
* @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type) * @return output vectort of shape [M] representing the selected indices from the boxes tensor,
* where M &lt;= max_output_size (NUMERIC type)
*/ */
public SDVariable nonMaxSuppression(String name, SDVariable boxes, SDVariable scores, public SDVariable nonMaxSuppression(String name, SDVariable boxes, SDVariable scores,
int maxOutSize, double iouThreshold, double scoreThreshold) { int maxOutSize, double iouThreshold, double scoreThreshold) {

View File

@ -100,7 +100,7 @@ public class SDRandom extends SDOps {
* P(x) = lambda * exp(-lambda * x)<br> * P(x) = lambda * exp(-lambda * x)<br>
* *
* Inputs must satisfy the following constraints: <br> * Inputs must satisfy the following constraints: <br>
* Must be positive: lambda > 0<br> * Must be positive: lambda &gt; 0<br>
* *
* @param lambda lambda parameter * @param lambda lambda parameter
* @param datatype Data type of the output variable * @param datatype Data type of the output variable
@ -118,7 +118,7 @@ public class SDRandom extends SDOps {
* P(x) = lambda * exp(-lambda * x)<br> * P(x) = lambda * exp(-lambda * x)<br>
* *
* Inputs must satisfy the following constraints: <br> * Inputs must satisfy the following constraints: <br>
* Must be positive: lambda > 0<br> * Must be positive: lambda &gt; 0<br>
* *
* @param name name May be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param lambda lambda parameter * @param lambda lambda parameter

View File

@ -829,9 +829,9 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
* Precision based on guesses so far.<br> * Precision based on guesses so far.<br>
* Note: value returned will differ depending on number of classes and settings.<br> * Note: value returned will differ depending on number of classes and settings.<br>
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * 1. For binary classification, if the positive class is set (via default value of 1, via constructor,
* or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * or via {@link Evaluation#setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class
* only.<br> * only.<br>
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged * 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged
* across all classes. i.e., is macro-averaged precision, equivalent to {@code precision(EvaluationAveraging.Macro)}<br> * across all classes. i.e., is macro-averaged precision, equivalent to {@code precision(EvaluationAveraging.Macro)}<br>
* *
* @return the total precision based on guesses so far * @return the total precision based on guesses so far
@ -977,7 +977,7 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * 1. For binary classification, if the positive class is set (via default value of 1, via constructor,
* or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class
* only.<br> * only.<br>
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged * 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged
* across all classes. i.e., is macro-averaged recall, equivalent to {@code recall(EvaluationAveraging.Macro)}<br> * across all classes. i.e., is macro-averaged recall, equivalent to {@code recall(EvaluationAveraging.Macro)}<br>
* *
* @return the recall for the outcomes * @return the recall for the outcomes
@ -1173,12 +1173,12 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
/** /**
* False Alarm Rate (FAR) reflects rate of misclassified to classified records * False Alarm Rate (FAR) reflects rate of misclassified to classified records
* {@link }http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw}<br> * {@see http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw}<br>
* Note: value returned will differ depending on number of classes and settings.<br> * Note: value returned will differ depending on number of classes and settings.<br>
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * 1. For binary classification, if the positive class is set (via default value of 1, via constructor,
* or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * or via {@link Evaluation#setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class
* only.<br> * only.<br>
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged * 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged
* across all classes. i.e., is macro-averaged false alarm rate) * across all classes. i.e., is macro-averaged false alarm rate)
* *
* @return the fpr for the outcomes * @return the fpr for the outcomes
@ -1243,9 +1243,9 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
* <br> * <br>
* Note: value returned will differ depending on number of classes and settings.<br> * Note: value returned will differ depending on number of classes and settings.<br>
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * 1. For binary classification, if the positive class is set (via default value of 1, via constructor,
* or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * or via {@link Evaluation#setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class
* only.<br> * only.<br>
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged * 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged
* across all classes. i.e., is macro-averaged f1, equivalent to {@code f1(EvaluationAveraging.Macro)}<br> * across all classes. i.e., is macro-averaged f1, equivalent to {@code f1(EvaluationAveraging.Macro)}<br>
* *
* @return the f1 score or harmonic mean of precision and recall based on current guesses * @return the f1 score or harmonic mean of precision and recall based on current guesses

View File

@ -584,7 +584,7 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
/** /**
* False Alarm Rate (FAR) reflects rate of misclassified to classified records * False Alarm Rate (FAR) reflects rate of misclassified to classified records
* <a href="http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw">http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw</a><br> * http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&amp;context=isw<br>
* *
* @param outputNum Class index to calculate False Alarm Rate (FAR) * @param outputNum Class index to calculate False Alarm Rate (FAR)
* @return The FAR for the outcomes * @return The FAR for the outcomes
@ -611,7 +611,7 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
//Report: Accuracy, precision, recall, F1. Then: confusion matrix //Report: Accuracy, precision, recall, F1. Then: confusion matrix]
int maxLabelsLength = 15; int maxLabelsLength = 15;
if (labels != null) { if (labels != null) {

View File

@ -202,7 +202,7 @@ public abstract class BaseLapack implements Lapack {
* *
* @param jobz 'N' - no eigen vectors, 'V' - return eigenvectors * @param jobz 'N' - no eigen vectors, 'V' - return eigenvectors
* @param uplo upper or lower part of symmetric matrix to use * @param uplo upper or lower part of symmetric matrix to use
* @param N the number of rows & cols in the matrix A * @param N the number of rows &amp; cols in the matrix A
* @param A the matrix to calculate eigenvectors * @param A the matrix to calculate eigenvectors
* @param R an output array for eigenvalues ( may be null ) * @param R an output array for eigenvalues ( may be null )
*/ */

View File

@ -74,14 +74,14 @@ public interface DataBufferFactory {
DataBuffer createDouble(long offset, int length); DataBuffer createDouble(long offset, int length);
/** /**
* This method will create new DataBuffer of the same dataType & same length * This method will create new DataBuffer of the same dataType &amp; same length
* @param buffer * @param buffer
* @return * @return
*/ */
DataBuffer createSame(DataBuffer buffer, boolean init); DataBuffer createSame(DataBuffer buffer, boolean init);
/** /**
* This method will create new DataBuffer of the same dataType & same length * This method will create new DataBuffer of the same dataType &amp; same length
* @param buffer * @param buffer
* @return * @return
*/ */

View File

@ -132,7 +132,6 @@ public interface MemoryManager {
* *
* @param pointer * @param pointer
* @param kind * @param kind
* @return
*/ */
void release(Pointer pointer, MemoryKind kind); void release(Pointer pointer, MemoryKind kind);

View File

@ -137,7 +137,7 @@ public interface MemoryWorkspaceManager {
void destroyWorkspace(MemoryWorkspace workspace); void destroyWorkspace(MemoryWorkspace workspace);
/** /**
* This method destroys & deallocates all Workspaces for a calling Thread * This method destroys &amp; deallocates all Workspaces for a calling Thread
* *
* PLEASE NOTE: This method is NOT safe * PLEASE NOTE: This method is NOT safe
*/ */
@ -149,21 +149,21 @@ public interface MemoryWorkspaceManager {
void destroyWorkspace(); void destroyWorkspace();
/** /**
* This method gets & activates default workspace * This method gets and activates default workspace
* *
* @return * @return
*/ */
MemoryWorkspace getAndActivateWorkspace(); MemoryWorkspace getAndActivateWorkspace();
/** /**
* This method gets & activates workspace with a given Id * This method gets and activates workspace with a given Id
* *
* @return * @return
*/ */
MemoryWorkspace getAndActivateWorkspace(String id); MemoryWorkspace getAndActivateWorkspace(String id);
/** /**
* This method gets & activates default with a given configuration and Id * This method gets and activates default with a given configuration and Id
* *
* @return * @return
*/ */

View File

@ -47,7 +47,7 @@ public abstract class BaseShapeInfoProvider implements ShapeInfoProvider {
} }
/** /**
* This method creates shapeInformation buffer, based on shape & order being passed in * This method creates shapeInformation buffer, based on shape and order being passed in
* *
* @param shape * @param shape
* @param order * @param order

View File

@ -2216,7 +2216,7 @@ public interface INDArray extends Serializable, AutoCloseable {
* Dimshuffle: an extension of permute that adds the ability * Dimshuffle: an extension of permute that adds the ability
* to broadcast various dimensions. * to broadcast various dimensions.
* This will only accept integers and xs. * This will only accept integers and xs.
* <p/> * <p>
* An x indicates a dimension should be broadcasted rather than permuted. * An x indicates a dimension should be broadcasted rather than permuted.
* *
* Examples originally from the theano docs: * Examples originally from the theano docs:
@ -2226,15 +2226,15 @@ public interface INDArray extends Serializable, AutoCloseable {
A few examples of patterns and their effect: A few examples of patterns and their effect:
('x') -> make a 0d (scalar) into a 1d vector ('x') -&gt; make a 0d (scalar) into a 1d vector
(0, 1) -> identity for 2d vectors (0, 1) -&gt; identity for 2d vectors
(1, 0) -> inverts the first and second dimensions (1, 0) -&gt; inverts the first and second dimensions
('x', 0) -> make a row out of a 1d vector (N to 1xN) ('x', 0) -&gt; make a row out of a 1d vector (N to 1xN)
(0, 'x') -> make a column out of a 1d vector (N to Nx1) (0, 'x') -&gt; make a column out of a 1d vector (N to Nx1)
(2, 0, 1) -> AxBxC to CxAxB (2, 0, 1) -&gt; AxBxC to CxAxB
(0, 'x', 1) -> AxB to Ax1xB (0, 'x', 1) -&gt; AxB to Ax1xB
(1, 'x', 0) -> AxB to Bx1xA (1, 'x', 0) -&gt; AxB to Bx1xA
(1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A) (1,) -&gt; This remove dimensions 0. It must be a broadcastable dimension (1xA to A)
* @param rearrange the dimensions to swap to * @param rearrange the dimensions to swap to
* @param newOrder the new order (think permute) * @param newOrder the new order (think permute)
@ -2244,7 +2244,7 @@ public interface INDArray extends Serializable, AutoCloseable {
INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable); INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable);
/** /**
* See {@link #dimShuffle(Object[], int[], boolean[]) * See {@link #dimShuffle(Object[], int[], boolean[])}
*/ */
INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable); INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable);

View File

@ -33,7 +33,7 @@ public interface ShapeInfoProvider {
Pair<DataBuffer, long[]> createShapeInformation(long[] shape, DataType dataType); Pair<DataBuffer, long[]> createShapeInformation(long[] shape, DataType dataType);
/** /**
* This method creates long shapeInformation buffer, based on shape & order being passed in * This method creates long shapeInformation buffer, based on shape and order being passed in
* @param shape * @param shape
* @return * @return
*/ */

View File

@ -65,7 +65,8 @@ public interface OpContext extends AutoCloseable {
/** /**
* This method sets root-level seed for rng * This method sets root-level seed for rng
* @param seed * @param rootState
* @param nodeState
*/ */
void setRngStates(long rootState, long nodeState); void setRngStates(long rootState, long nodeState);

View File

@ -251,7 +251,7 @@ public interface Random extends AutoCloseable {
* The reason for this is due to ints * The reason for this is due to ints
* having the same space usage as floats. * having the same space usage as floats.
* This also plays nice with blas. * This also plays nice with blas.
* <p/> * <p>
* If the data opType is set to double, * If the data opType is set to double,
* then these will be whole doubles. * then these will be whole doubles.
* *
@ -272,7 +272,7 @@ public interface Random extends AutoCloseable {
* The reason for this is due to ints * The reason for this is due to ints
* having the same space usage as floats. * having the same space usage as floats.
* This also plays nice with blas. * This also plays nice with blas.
* <p/> * <p>
* If the data opType is set to double, * If the data opType is set to double,
* then these will be whole doubles. * then these will be whole doubles.
* *

View File

@ -35,233 +35,236 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.Iterator; import java.util.Iterator;
public abstract class BaseDistribution implements Distribution { public abstract class BaseDistribution implements Distribution {
protected Random random;
protected double solverAbsoluteAccuracy; protected Random random;
protected double solverAbsoluteAccuracy;
public BaseDistribution(Random rng) { public BaseDistribution(Random rng) {
this.random = rng; this.random = rng;
}
public BaseDistribution() {
this(Nd4j.getRandom());
}
/**
* For a random variable {@code X} whose values are distributed according to this distribution,
* this method returns {@code P(x0 < X <= x1)}.
*
* @param x0 Lower bound (excluded).
* @param x1 Upper bound (included).
* @return the probability that a random variable with this distribution takes a value between
* {@code x0} and {@code x1}, excluding the lower and including the upper endpoint.
* @throws org.apache.commons.math3.exception.NumberIsTooLargeException if {@code x0 > x1}.
* <p>
* The default implementation
* uses the identity
* {@code P(x0 < X <= x1) =
* P(X <= x1) - P(X <= x0)}
* @since 3.1
*/
public double probability(double x0, double x1) {
if (x0 > x1) {
throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0,
x1, true);
} }
return cumulativeProbability(x1) - cumulativeProbability(x0);
}
/**
public BaseDistribution() { * {@inheritDoc}
this(Nd4j.getRandom()); * <p>
} * The default implementation returns
* <ul>
/** * <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
* For a random variable {@code X} whose values are distributed according * <li>{@link #getSupportUpperBound()} for {@code p = 1}.</li>
* to this distribution, this method returns {@code P(x0 < X <= x1)}. * </ul>
*/
@Override
public double inverseCumulativeProbability(final double p) throws OutOfRangeException {
/*
* IMPLEMENTATION NOTES
* --------------------
* Where applicable, use is made of the one-sided Chebyshev inequality
* to bracket the root. This inequality states that
* P(X - mu >= k * sig) <= 1 / (1 + k^2),
* mu: mean, sig: standard deviation. Equivalently
* 1 - P(X < mu + k * sig) <= 1 / (1 + k^2),
* F(mu + k * sig) >= k^2 / (1 + k^2).
* *
* @param x0 Lower bound (excluded). * For k = sqrt(p / (1 - p)), we find
* @param x1 Upper bound (included). * F(mu + k * sig) >= p,
* @return the probability that a random variable with this distribution * and (mu + k * sig) is an upper-bound for the root.
* takes a value between {@code x0} and {@code x1}, excluding the lower *
* and including the upper endpoint. * Then, introducing Y = -X, mean(Y) = -mu, sd(Y) = sig, and
* @throws org.apache.commons.math3.exception.NumberIsTooLargeException if {@code x0 > x1}. * P(Y >= -mu + k * sig) <= 1 / (1 + k^2),
* <p/> * P(-X >= -mu + k * sig) <= 1 / (1 + k^2),
* The default implementation uses the identity * P(X <= mu - k * sig) <= 1 / (1 + k^2),
* {@code P(x0 < X <= x1) = P(X <= x1) - P(X <= x0)} * F(mu - k * sig) <= 1 / (1 + k^2).
* @since 3.1 *
* For k = sqrt((1 - p) / p), we find
* F(mu - k * sig) <= p,
* and (mu - k * sig) is a lower-bound for the root.
*
* In cases where the Chebyshev inequality does not apply, geometric
* progressions 1, 2, 4, ... and -1, -2, -4, ... are used to bracket
* the root.
*/ */
if (p < 0.0 || p > 1.0) {
public double probability(double x0, double x1) { throw new OutOfRangeException(p, 0, 1);
if (x0 > x1) {
throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0, x1, true);
}
return cumulativeProbability(x1) - cumulativeProbability(x0);
} }
/** double lowerBound = getSupportLowerBound();
* {@inheritDoc} if (p == 0.0) {
* <p/> return lowerBound;
* The default implementation returns }
* <ul>
* <li>{@link #getSupportLowerBound()} for {@code p = 0},</li> double upperBound = getSupportUpperBound();
* <li>{@link #getSupportUpperBound()} for {@code p = 1}.</li> if (p == 1.0) {
* </ul> return upperBound;
*/ }
@Override
public double inverseCumulativeProbability(final double p) throws OutOfRangeException { final double mu = getNumericalMean();
/* final double sig = FastMath.sqrt(getNumericalVariance());
* IMPLEMENTATION NOTES final boolean chebyshevApplies;
* -------------------- chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) || Double.isInfinite(sig)
* Where applicable, use is made of the one-sided Chebyshev inequality || Double.isNaN(sig));
* to bracket the root. This inequality states that
* P(X - mu >= k * sig) <= 1 / (1 + k^2), if (lowerBound == Double.NEGATIVE_INFINITY) {
* mu: mean, sig: standard deviation. Equivalently if (chebyshevApplies) {
* 1 - P(X < mu + k * sig) <= 1 / (1 + k^2), lowerBound = mu - sig * FastMath.sqrt((1. - p) / p);
* F(mu + k * sig) >= k^2 / (1 + k^2). } else {
* lowerBound = -1.0;
* For k = sqrt(p / (1 - p)), we find while (cumulativeProbability(lowerBound) >= p) {
* F(mu + k * sig) >= p, lowerBound *= 2.0;
* and (mu + k * sig) is an upper-bound for the root.
*
* Then, introducing Y = -X, mean(Y) = -mu, sd(Y) = sig, and
* P(Y >= -mu + k * sig) <= 1 / (1 + k^2),
* P(-X >= -mu + k * sig) <= 1 / (1 + k^2),
* P(X <= mu - k * sig) <= 1 / (1 + k^2),
* F(mu - k * sig) <= 1 / (1 + k^2).
*
* For k = sqrt((1 - p) / p), we find
* F(mu - k * sig) <= p,
* and (mu - k * sig) is a lower-bound for the root.
*
* In cases where the Chebyshev inequality does not apply, geometric
* progressions 1, 2, 4, ... and -1, -2, -4, ... are used to bracket
* the root.
*/
if (p < 0.0 || p > 1.0) {
throw new OutOfRangeException(p, 0, 1);
} }
}
}
double lowerBound = getSupportLowerBound(); if (upperBound == Double.POSITIVE_INFINITY) {
if (p == 0.0) { if (chebyshevApplies) {
return lowerBound; upperBound = mu + sig * FastMath.sqrt(p / (1. - p));
} else {
upperBound = 1.0;
while (cumulativeProbability(upperBound) < p) {
upperBound *= 2.0;
} }
}
}
double upperBound = getSupportUpperBound(); final UnivariateFunction toSolve = new UnivariateFunction() {
if (p == 1.0) {
return upperBound;
}
final double mu = getNumericalMean(); public double value(final double x) {
final double sig = FastMath.sqrt(getNumericalVariance()); return cumulativeProbability(x) - p;
final boolean chebyshevApplies; }
chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) || Double.isInfinite(sig) || Double.isNaN(sig)); };
if (lowerBound == Double.NEGATIVE_INFINITY) { double x = UnivariateSolverUtils.solve(toSolve, lowerBound, upperBound,
if (chebyshevApplies) { getSolverAbsoluteAccuracy());
lowerBound = mu - sig * FastMath.sqrt((1. - p) / p);
if (!isSupportConnected()) {
/* Test for plateau. */
final double dx = getSolverAbsoluteAccuracy();
if (x - dx >= getSupportLowerBound()) {
double px = cumulativeProbability(x);
if (cumulativeProbability(x - dx) == px) {
upperBound = x;
while (upperBound - lowerBound > dx) {
final double midPoint = 0.5 * (lowerBound + upperBound);
if (cumulativeProbability(midPoint) < px) {
lowerBound = midPoint;
} else { } else {
lowerBound = -1.0; upperBound = midPoint;
while (cumulativeProbability(lowerBound) >= p) {
lowerBound *= 2.0;
}
} }
}
return upperBound;
} }
}
if (upperBound == Double.POSITIVE_INFINITY) {
if (chebyshevApplies) {
upperBound = mu + sig * FastMath.sqrt(p / (1. - p));
} else {
upperBound = 1.0;
while (cumulativeProbability(upperBound) < p) {
upperBound *= 2.0;
}
}
}
final UnivariateFunction toSolve = new UnivariateFunction() {
public double value(final double x) {
return cumulativeProbability(x) - p;
}
};
double x = UnivariateSolverUtils.solve(toSolve, lowerBound, upperBound, getSolverAbsoluteAccuracy());
if (!isSupportConnected()) {
/* Test for plateau. */
final double dx = getSolverAbsoluteAccuracy();
if (x - dx >= getSupportLowerBound()) {
double px = cumulativeProbability(x);
if (cumulativeProbability(x - dx) == px) {
upperBound = x;
while (upperBound - lowerBound > dx) {
final double midPoint = 0.5 * (lowerBound + upperBound);
if (cumulativeProbability(midPoint) < px) {
lowerBound = midPoint;
} else {
upperBound = midPoint;
}
}
return upperBound;
}
}
}
return x;
} }
return x;
}
/** /**
* Returns the solver absolute accuracy for inverse cumulative computation. * Returns the solver absolute accuracy for inverse cumulative computation. You can override this
* You can override this method in order to use a Brent solver with an * method in order to use a Brent solver with an absolute accuracy different from the default.
* absolute accuracy different from the default. *
* * @return the maximum absolute error in inverse cumulative probability estimates
* @return the maximum absolute error in inverse cumulative probability estimates */
*/ protected double getSolverAbsoluteAccuracy() {
protected double getSolverAbsoluteAccuracy() { return solverAbsoluteAccuracy;
return solverAbsoluteAccuracy; }
}
/** /**
* {@inheritDoc} * {@inheritDoc}
*/ */
@Override @Override
public void reseedRandomGenerator(long seed) { public void reseedRandomGenerator(long seed) {
random.setSeed(seed); random.setSeed(seed);
} }
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * The default implementation uses the
* The default implementation uses the * <a href="http://en.wikipedia.org/wiki/Inverse_transform_sampling">
* <a href="http://en.wikipedia.org/wiki/Inverse_transform_sampling"> * inversion method.
* inversion method. * </a>
* </a> */
*/ @Override
@Override public double sample() {
public double sample() { return inverseCumulativeProbability(random.nextDouble());
return inverseCumulativeProbability(random.nextDouble()); }
}
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The default implementation generates the sample by calling * The default implementation generates the sample by calling {@link #sample()} in a loop.
* {@link #sample()} in a loop. */
*/ @Override
@Override public double[] sample(long sampleSize) {
public double[] sample(long sampleSize) { if (sampleSize <= 0) {
if (sampleSize <= 0) { throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
}
double[] out = new double[(int) sampleSize];
for (int i = 0; i < sampleSize; i++) {
out[i] = sample();
}
return out;
} }
double[] out = new double[(int) sampleSize];
for (int i = 0; i < sampleSize; i++) {
out[i] = sample();
}
return out;
}
/** /**
* {@inheritDoc} * {@inheritDoc}
* *
* @return zero. * @return zero.
* @since 3.1 * @since 3.1
*/ */
@Override @Override
public double probability(double x) { public double probability(double x) {
return 0d; return 0d;
} }
@Override @Override
public INDArray sample(int[] shape) { public INDArray sample(int[] shape) {
INDArray ret = Nd4j.create(shape); INDArray ret = Nd4j.create(shape);
return sample(ret); return sample(ret);
} }
@Override @Override
public INDArray sample(long[] shape) { public INDArray sample(long[] shape) {
INDArray ret = Nd4j.create(shape); INDArray ret = Nd4j.create(shape);
return sample(ret); return sample(ret);
} }
@Override @Override
public INDArray sample(INDArray target) { public INDArray sample(INDArray target) {
Iterator<long[]> idxIter = new NdIndexIterator(target.shape()); //For consistent values irrespective of c vs. fortran ordering Iterator<long[]> idxIter = new NdIndexIterator(
long len = target.length(); target.shape()); //For consistent values irrespective of c vs. fortran ordering
for (long i = 0; i < len; i++) { long len = target.length();
target.putScalar(idxIter.next(), sample()); for (long i = 0; i < len; i++) {
} target.putScalar(idxIter.next(), sample());
return target;
} }
return target;
}
} }

View File

@ -89,8 +89,8 @@ public interface Distribution {
* variable {@code X} distributed according to this distribution, the * variable {@code X} distributed according to this distribution, the
* returned value is * returned value is
* <ul> * <ul>
* <li><code>inf{x in R | P(X<=x) >= p}</code> for {@code 0 < p <= 1},</li> * <li>{@code inf{x in R | P(X<=x) >= p}} for {@code 0 < p <= 1},</li>
* <li><code>inf{x in R | P(X<=x) > 0}</code> for {@code p = 0}.</li> * <li>{@code inf{x in R | P(X<=x) > 0}} for {@code p = 0}.</li>
* </ul> * </ul>
* *
* @param p the cumulative probability * @param p the cumulative probability
@ -122,7 +122,7 @@ public interface Distribution {
* Access the lower bound of the support. This method must return the same * Access the lower bound of the support. This method must return the same
* value as {@code inverseCumulativeProbability(0)}. In other words, this * value as {@code inverseCumulativeProbability(0)}. In other words, this
* method must return * method must return
* <p><code>inf {x in R | P(X <= x) > 0}</code>.</p> * <p>{@code inf {x in R | P(X <= x) > 0}}.</p>
* *
* @return lower bound of the support (might be * @return lower bound of the support (might be
* {@code Double.NEGATIVE_INFINITY}) * {@code Double.NEGATIVE_INFINITY})
@ -133,7 +133,7 @@ public interface Distribution {
* Access the upper bound of the support. This method must return the same * Access the upper bound of the support. This method must return the same
* value as {@code inverseCumulativeProbability(1)}. In other words, this * value as {@code inverseCumulativeProbability(1)}. In other words, this
* method must return * method must return
* <p><code>inf {x in R | P(X <= x) = 1}</code>.</p> * <p>{@code inf {x in R | P(X <= x) = 1}}.</p>
* *
* @return upper bound of the support (might be * @return upper bound of the support (might be
* {@code Double.POSITIVE_INFINITY}) * {@code Double.POSITIVE_INFINITY})

View File

@ -166,7 +166,7 @@ public class BinomialDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* For {@code n} trials and probability parameter {@code p}, the mean is * For {@code n} trials and probability parameter {@code p}, the mean is
* {@code n * p}. * {@code n * p}.
*/ */
@ -177,7 +177,7 @@ public class BinomialDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* For {@code n} trials and probability parameter {@code p}, the variance is * For {@code n} trials and probability parameter {@code p}, the variance is
* {@code n * p * (1 - p)}. * {@code n * p * (1 - p)}.
*/ */
@ -189,7 +189,7 @@ public class BinomialDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The lower bound of the support is always 0 except for the probability * The lower bound of the support is always 0 except for the probability
* parameter {@code p = 1}. * parameter {@code p = 1}.
* *
@ -203,7 +203,7 @@ public class BinomialDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The upper bound of the support is the number of trials except for the * The upper bound of the support is the number of trials except for the
* probability parameter {@code p = 0}. * probability parameter {@code p = 0}.
* *
@ -227,7 +227,7 @@ public class BinomialDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The support of this distribution is connected. * The support of this distribution is connected.
* *
* @return {@code true} * @return {@code true}

View File

@ -83,7 +83,7 @@ public class ConstantDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1 * If {@code x} is more than 40 standard deviations from the mean, 0 or 1
* is returned, as in these cases the actual value is within * is returned, as in these cases the actual value is within
* {@code Double.MIN_VALUE} of 0 or 1. * {@code Double.MIN_VALUE} of 0 or 1.
@ -131,7 +131,7 @@ public class ConstantDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* For mean parameter {@code mu}, the mean is {@code mu}. * For mean parameter {@code mu}, the mean is {@code mu}.
*/ */
public double getNumericalMean() { public double getNumericalMean() {
@ -140,7 +140,7 @@ public class ConstantDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* For standard deviation parameter {@code s}, the variance is {@code s^2}. * For standard deviation parameter {@code s}, the variance is {@code s^2}.
*/ */
public double getNumericalVariance() { public double getNumericalVariance() {
@ -150,7 +150,7 @@ public class ConstantDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The lower bound of the support is always negative infinity * The lower bound of the support is always negative infinity
* no matter the parameters. * no matter the parameters.
* *
@ -163,7 +163,7 @@ public class ConstantDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The upper bound of the support is always positive infinity * The upper bound of the support is always positive infinity
* no matter the parameters. * no matter the parameters.
* *
@ -190,7 +190,7 @@ public class ConstantDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The support of this distribution is connected. * The support of this distribution is connected.
* *
* @return {@code true} * @return {@code true}

View File

@ -172,7 +172,6 @@ public class LogNormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/>
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1 * If {@code x} is more than 40 standard deviations from the mean, 0 or 1
* is returned, as in these cases the actual value is within * is returned, as in these cases the actual value is within
* {@code Double.MIN_VALUE} of 0 or 1. * {@code Double.MIN_VALUE} of 0 or 1.
@ -238,7 +237,6 @@ public class LogNormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/>
* For mean parameter {@code mu}, the mean is {@code mu}. * For mean parameter {@code mu}, the mean is {@code mu}.
*/ */
public double getNumericalMean() { public double getNumericalMean() {
@ -247,7 +245,6 @@ public class LogNormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/>
* For standard deviation parameter {@code s}, the variance is {@code s^2}. * For standard deviation parameter {@code s}, the variance is {@code s^2}.
*/ */
public double getNumericalVariance() { public double getNumericalVariance() {
@ -257,7 +254,6 @@ public class LogNormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/>
* The lower bound of the support is always negative infinity * The lower bound of the support is always negative infinity
* no matter the parameters. * no matter the parameters.
* *
@ -270,7 +266,7 @@ public class LogNormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The upper bound of the support is always positive infinity * The upper bound of the support is always positive infinity
* no matter the parameters. * no matter the parameters.
* *
@ -297,7 +293,7 @@ public class LogNormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The support of this distribution is connected. * The support of this distribution is connected.
* *
* @return {@code true} * @return {@code true}

View File

@ -176,7 +176,6 @@ public class NormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/>
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1 * If {@code x} is more than 40 standard deviations from the mean, 0 or 1
* is returned, as in these cases the actual value is within * is returned, as in these cases the actual value is within
* {@code Double.MIN_VALUE} of 0 or 1. * {@code Double.MIN_VALUE} of 0 or 1.
@ -242,7 +241,6 @@ public class NormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/>
* For mean parameter {@code mu}, the mean is {@code mu}. * For mean parameter {@code mu}, the mean is {@code mu}.
*/ */
public double getNumericalMean() { public double getNumericalMean() {
@ -251,7 +249,6 @@ public class NormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/>
* For standard deviation parameter {@code s}, the variance is {@code s^2}. * For standard deviation parameter {@code s}, the variance is {@code s^2}.
*/ */
public double getNumericalVariance() { public double getNumericalVariance() {
@ -261,7 +258,6 @@ public class NormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/>
* The lower bound of the support is always negative infinity * The lower bound of the support is always negative infinity
* no matter the parameters. * no matter the parameters.
* *
@ -274,7 +270,6 @@ public class NormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/>
* The upper bound of the support is always positive infinity * The upper bound of the support is always positive infinity
* no matter the parameters. * no matter the parameters.
* *
@ -301,7 +296,6 @@ public class NormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/>
* The support of this distribution is connected. * The support of this distribution is connected.
* *
* @return {@code true} * @return {@code true}

View File

@ -34,27 +34,28 @@ import org.nd4j.common.util.ArrayUtil;
@Slf4j @Slf4j
public class OrthogonalDistribution extends BaseDistribution { public class OrthogonalDistribution extends BaseDistribution {
/**
* Default inverse cumulative probability accuracy.
*
* @since 2.1
*/
public static final double DEFAULT_INVERSE_ABSOLUTE_ACCURACY = 1e-9;
/**
* Serializable version identifier.
*/
private static final long serialVersionUID = 8589540077390120676L;
/** /**
* Mean of this distribution. * Default inverse cumulative probability accuracy.
*/ *
private final double gain; * @since 2.1
private INDArray gains; */
public static final double DEFAULT_INVERSE_ABSOLUTE_ACCURACY = 1e-9;
/**
* Serializable version identifier.
*/
private static final long serialVersionUID = 8589540077390120676L;
public OrthogonalDistribution(double gain) { /**
this.gain = gain; * Mean of this distribution.
this.random = Nd4j.getRandom(); */
} private final double gain;
private INDArray gains;
public OrthogonalDistribution(double gain) {
this.gain = gain;
this.random = Nd4j.getRandom();
}
/* /*
max doesn't want this distripution max doesn't want this distripution
public OrthogonalDistribution(@NonNull INDArray gains) { public OrthogonalDistribution(@NonNull INDArray gains) {
@ -62,196 +63,192 @@ public class OrthogonalDistribution extends BaseDistribution {
this.random = Nd4j.getRandom(); this.random = Nd4j.getRandom();
} }
*/ */
/**
* Access the mean. /**
* * Access the mean.
* @return the mean for this distribution. *
*/ * @return the mean for this distribution.
public double getMean() { */
throw new UnsupportedOperationException(); public double getMean() {
throw new UnsupportedOperationException();
}
/**
* Access the standard deviation.
*
* @return the standard deviation for this distribution.
*/
public double getStandardDeviation() {
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc}
*/
public double density(double x) {
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc} If {@code x} is more than 40 standard deviations from the mean, 0 or 1 is
* returned, as in these cases the actual value is within {@code Double.MIN_VALUE} of 0 or 1.
*/
public double cumulativeProbability(double x) {
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc}
*
* @since 3.2
*/
@Override
public double inverseCumulativeProbability(final double p) throws OutOfRangeException {
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc}
*
* @deprecated See
* {@link org.apache.commons.math3.distribution.RealDistribution#cumulativeProbability(double,
* double)}
*/
@Override
@Deprecated
public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException {
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc}
*/
@Override
public double probability(double x0, double x1) throws NumberIsTooLargeException {
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc}
*/
@Override
protected double getSolverAbsoluteAccuracy() {
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc} For mean parameter {@code mu}, the mean is {@code mu}.
*/
public double getNumericalMean() {
return getMean();
}
/**
* {@inheritDoc} For standard deviation parameter {@code s}, the variance is {@code s^2}.
*/
public double getNumericalVariance() {
final double s = getStandardDeviation();
return s * s;
}
/**
* {@inheritDoc} The lower bound of the support is always negative infinity no matter the
* parameters.
*
* @return lower bound of the support (always {@code Double.NEGATIVE_INFINITY})
*/
public double getSupportLowerBound() {
return Double.NEGATIVE_INFINITY;
}
/**
* {@inheritDoc}
* <p>
* The upper bound of the support is always positive infinity no matter the parameters.
*
* @return upper bound of the support (always {@code Double.POSITIVE_INFINITY})
*/
public double getSupportUpperBound() {
return Double.POSITIVE_INFINITY;
}
/**
* {@inheritDoc}
*/
public boolean isSupportLowerBoundInclusive() {
return false;
}
/**
* {@inheritDoc}
*/
public boolean isSupportUpperBoundInclusive() {
return false;
}
/**
* {@inheritDoc}
* <p>
* The support of this distribution is connected.
*
* @return {@code true}
*/
public boolean isSupportConnected() {
return true;
}
/**
* {@inheritDoc}
*/
@Override
public double sample() {
throw new UnsupportedOperationException();
}
@Override
public INDArray sample(int[] shape) {
return sample(ArrayUtil.toLongArray(shape));
}
@Override
public INDArray sample(long[] shape) {
long numRows = 1;
for (int i = 0; i < shape.length - 1; i++) {
numRows *= shape[i];
} }
long numCols = shape[shape.length - 1];
/** val dtype = Nd4j.defaultFloatingPointType();
* Access the standard deviation.
* val flatShape = new long[]{numRows, numCols};
* @return the standard deviation for this distribution. val flatRng = Nd4j.getExecutioner().exec(
*/ new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0,
public double getStandardDeviation() { 1.0), random);
throw new UnsupportedOperationException();
val m = flatRng.rows();
val n = flatRng.columns();
val s = Nd4j.create(dtype, m < n ? m : n);
val u = Nd4j.create(dtype, m, m);
val v = Nd4j.create(dtype, new long[]{n, n}, 'f');
Nd4j.exec(new Svd(flatRng, true, s, u, v));
if (gains == null) {
if (u.rows() >= numRows && u.columns() >= numCols) {
return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain)
.reshape(shape);
} else {
return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain)
.reshape(shape);
}
} else {
throw new UnsupportedOperationException();
} }
}
/** @Override
* {@inheritDoc} public INDArray sample(INDArray target) {
*/ return target.assign(sample(target.shape()));
public double density(double x) { }
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc}
* <p/>
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1
* is returned, as in these cases the actual value is within
* {@code Double.MIN_VALUE} of 0 or 1.
*/
public double cumulativeProbability(double x) {
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc}
*
* @since 3.2
*/
@Override
public double inverseCumulativeProbability(final double p) throws OutOfRangeException {
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc}
*
* @deprecated See {@link org.apache.commons.math3.distribution.RealDistribution#cumulativeProbability(double, double)}
*/
@Override
@Deprecated
public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException {
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc}
*/
@Override
public double probability(double x0, double x1) throws NumberIsTooLargeException {
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc}
*/
@Override
protected double getSolverAbsoluteAccuracy() {
throw new UnsupportedOperationException();
}
/**
* {@inheritDoc}
* <p/>
* For mean parameter {@code mu}, the mean is {@code mu}.
*/
public double getNumericalMean() {
return getMean();
}
/**
* {@inheritDoc}
* <p/>
* For standard deviation parameter {@code s}, the variance is {@code s^2}.
*/
public double getNumericalVariance() {
final double s = getStandardDeviation();
return s * s;
}
/**
* {@inheritDoc}
* <p/>
* The lower bound of the support is always negative infinity
* no matter the parameters.
*
* @return lower bound of the support (always
* {@code Double.NEGATIVE_INFINITY})
*/
public double getSupportLowerBound() {
return Double.NEGATIVE_INFINITY;
}
/**
* {@inheritDoc}
* <p/>
* The upper bound of the support is always positive infinity
* no matter the parameters.
*
* @return upper bound of the support (always
* {@code Double.POSITIVE_INFINITY})
*/
public double getSupportUpperBound() {
return Double.POSITIVE_INFINITY;
}
/**
* {@inheritDoc}
*/
public boolean isSupportLowerBoundInclusive() {
return false;
}
/**
* {@inheritDoc}
*/
public boolean isSupportUpperBoundInclusive() {
return false;
}
/**
* {@inheritDoc}
* <p/>
* The support of this distribution is connected.
*
* @return {@code true}
*/
public boolean isSupportConnected() {
return true;
}
/**
* {@inheritDoc}
*/
@Override
public double sample() {
throw new UnsupportedOperationException();
}
@Override
public INDArray sample(int[] shape) {
return sample(ArrayUtil.toLongArray(shape));
}
@Override
public INDArray sample(long[] shape){
long numRows = 1;
for (int i = 0; i < shape.length - 1; i++)
numRows *= shape[i];
long numCols = shape[shape.length - 1];
val dtype = Nd4j.defaultFloatingPointType();
val flatShape = new long[]{numRows, numCols};
val flatRng = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, 1.0), random);
val m = flatRng.rows();
val n = flatRng.columns();
val s = Nd4j.create(dtype, m < n ? m : n);
val u = Nd4j.create(dtype, m, m);
val v = Nd4j.create(dtype, new long[] {n, n}, 'f');
Nd4j.exec(new Svd(flatRng, true, s, u, v));
if (gains == null) {
if (u.rows() >= numRows && u.columns() >= numCols) {
return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
} else {
return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
}
} else {
throw new UnsupportedOperationException();
}
}
@Override
public INDArray sample(INDArray target){
return target.assign(sample(target.shape()));
}
} }

View File

@ -84,7 +84,6 @@ public class SaddlePointExpansion {
* href="http://mathworld.wolfram.com/StirlingsSeries.html"> * href="http://mathworld.wolfram.com/StirlingsSeries.html">
* http://mathworld.wolfram.com/StirlingsSeries.html</a></li> * http://mathworld.wolfram.com/StirlingsSeries.html</a></li>
* </ol> * </ol>
* </p>
* *
* @param z the value. * @param z the value.
* @return the Striling's series error. * @return the Striling's series error.
@ -117,7 +116,6 @@ public class SaddlePointExpansion {
* href="http://www.herine.net/stat/papers/dbinom.pdf"> * href="http://www.herine.net/stat/papers/dbinom.pdf">
* http://www.herine.net/stat/papers/dbinom.pdf</a></li> * http://www.herine.net/stat/papers/dbinom.pdf</a></li>
* </ol> * </ol>
* </p>
* *
* @param x the x value. * @param x the x value.
* @param mu the average. * @param mu the average.

View File

@ -172,7 +172,7 @@ public class TruncatedNormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1 * If {@code x} is more than 40 standard deviations from the mean, 0 or 1
* is returned, as in these cases the actual value is within * is returned, as in these cases the actual value is within
* {@code Double.MIN_VALUE} of 0 or 1. * {@code Double.MIN_VALUE} of 0 or 1.
@ -238,7 +238,7 @@ public class TruncatedNormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* For mean parameter {@code mu}, the mean is {@code mu}. * For mean parameter {@code mu}, the mean is {@code mu}.
*/ */
public double getNumericalMean() { public double getNumericalMean() {
@ -247,7 +247,7 @@ public class TruncatedNormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* For standard deviation parameter {@code s}, the variance is {@code s^2}. * For standard deviation parameter {@code s}, the variance is {@code s^2}.
*/ */
public double getNumericalVariance() { public double getNumericalVariance() {
@ -257,7 +257,7 @@ public class TruncatedNormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The lower bound of the support is always negative infinity * The lower bound of the support is always negative infinity
* no matter the parameters. * no matter the parameters.
* *
@ -270,7 +270,7 @@ public class TruncatedNormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The upper bound of the support is always positive infinity * The upper bound of the support is always positive infinity
* no matter the parameters. * no matter the parameters.
* *
@ -297,7 +297,7 @@ public class TruncatedNormalDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The support of this distribution is connected. * The support of this distribution is connected.
* *
* @return {@code true} * @return {@code true}

View File

@ -105,7 +105,7 @@ public class UniformDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* For lower bound {@code lower} and upper bound {@code upper}, the mean is * For lower bound {@code lower} and upper bound {@code upper}, the mean is
* {@code 0.5 * (lower + upper)}. * {@code 0.5 * (lower + upper)}.
*/ */
@ -115,7 +115,7 @@ public class UniformDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* For lower bound {@code lower} and upper bound {@code upper}, the * For lower bound {@code lower} and upper bound {@code upper}, the
* variance is {@code (upper - lower)^2 / 12}. * variance is {@code (upper - lower)^2 / 12}.
*/ */
@ -126,7 +126,7 @@ public class UniformDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The lower bound of the support is equal to the lower bound parameter * The lower bound of the support is equal to the lower bound parameter
* of the distribution. * of the distribution.
* *
@ -138,7 +138,7 @@ public class UniformDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The upper bound of the support is equal to the upper bound parameter * The upper bound of the support is equal to the upper bound parameter
* of the distribution. * of the distribution.
* *
@ -164,7 +164,7 @@ public class UniformDistribution extends BaseDistribution {
/** /**
* {@inheritDoc} * {@inheritDoc}
* <p/> * <p>
* The support of this distribution is connected. * The support of this distribution is connected.
* *
* @return {@code true} * @return {@code true}

View File

@ -58,7 +58,7 @@ public class NDArrayCreationUtil {
} }
/** Get an array of INDArrays (2d) all with the specified shape. Pair<INDArray,String> returned to aid /** Get an array of INDArrays (2d) all with the specified shape. {@code Pair<INDArray,String>} returned to aid
* debugging: String contains information on how to reproduce the matrix (i.e., which function, and arguments) * debugging: String contains information on how to reproduce the matrix (i.e., which function, and arguments)
* Each NDArray in the returned array has been obtained by applying an operation such as transpose, tensorAlongDimension, * Each NDArray in the returned array has been obtained by applying an operation such as transpose, tensorAlongDimension,
* etc to an original array. * etc to an original array.
@ -88,7 +88,7 @@ public class NDArrayCreationUtil {
* eg. rank 2: 1,1; 1,2; 2,1; 2,2; 3,4 * eg. rank 2: 1,1; 1,2; 2,1; 2,2; 3,4
* Motivated by TADs that often hit bugs when a "1" occurs as the size of a dimension * Motivated by TADs that often hit bugs when a "1" occurs as the size of a dimension
* *
* @param rank any rank including true scalars i.e rank >= 0 * @param rank any rank including true scalars i.e rank &gt;= 0
* @param order what order array to return i.e 'c' or 'f' order arrays * @param order what order array to return i.e 'c' or 'f' order arrays
* @return List of arrays and the shapes as strings * @return List of arrays and the shapes as strings
*/ */

View File

@ -355,7 +355,7 @@ public class AsyncDataSetIterator implements DataSetIterator {
* yet been called, or the {@code remove} method has already * yet been called, or the {@code remove} method has already
* been called after the last call to the {@code next} * been called after the last call to the {@code next}
* method * method
* @implSpec The default implementation throws an instance of * The default implementation throws an instance of
* {@link UnsupportedOperationException} and performs no other action. * {@link UnsupportedOperationException} and performs no other action.
*/ */
@Override @Override

View File

@ -299,7 +299,7 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator {
* yet been called, or the {@code remove} method has already * yet been called, or the {@code remove} method has already
* been called after the last call to the {@code next} * been called after the last call to the {@code next}
* method * method
* @implSpec The default implementation throws an instance of * The default implementation throws an instance of
* {@link UnsupportedOperationException} and performs no other action. * {@link UnsupportedOperationException} and performs no other action.
*/ */
@Override @Override

View File

@ -560,7 +560,6 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet {
/** /**
* @Deprecated
* Subtract by the column means and divide by the standard deviation * Subtract by the column means and divide by the standard deviation
*/ */
@Deprecated @Deprecated

View File

@ -117,7 +117,6 @@ public class KFoldIterator implements DataSetIterator {
/** /**
* Shuffles the dataset and resets to the first fold * Shuffles the dataset and resets to the first fold
* *
* @return void
*/ */
@Override @Override
public void reset() { public void reset() {
@ -129,7 +128,7 @@ public class KFoldIterator implements DataSetIterator {
/** /**
* The number of examples in every fold is (N / k), * The number of examples in every fold is (N / k),
* except when (N % k) > 0, when the first (N % k) folds contain (N / k) + 1 examples * except when (N % k) &gt; 0, when the first (N % k) folds contain (N / k) + 1 examples
* *
* @return examples in a fold * @return examples in a fold
*/ */

View File

@ -49,7 +49,6 @@ public class TestDataSetIterator implements DataSetIterator {
* Initializes with a default batch of 5 * Initializes with a default batch of 5
* *
* @param dataset the dataset to make the iterator from * @param dataset the dataset to make the iterator from
* @param batch the batchsize for the iterator
*/ */
public TestDataSetIterator(DataSet dataset) { public TestDataSetIterator(DataSet dataset) {
this(dataset, 5); this(dataset, 5);

View File

@ -65,9 +65,9 @@ public class RandomProjection {
* The minimum number n' of components to guarantee the eps-embedding is * The minimum number n' of components to guarantee the eps-embedding is
* given by: * given by:
* *
* n' >= 4 log(n) / (eps² / 2 - eps³ / 3) * {@code n' >= 4 log(n) / (eps² / 2 - eps³ / 3)}
* *
* see http://cseweb.ucsd.edu/~dasgupta/papers/jl.pdf §2.1 * http://cseweb.ucsd.edu/~dasgupta/papers/jl.pdf §2.1
* @param n Number of samples. If an array is given, it will compute * @param n Number of samples. If an array is given, it will compute
* a safe number of components array-wise. * a safe number of components array-wise.
* @param eps Maximum distortion rate as defined by the Johnson-Lindenstrauss lemma. * @param eps Maximum distortion rate as defined by the Johnson-Lindenstrauss lemma.

View File

@ -30,7 +30,6 @@ public interface EnvironmentalAction {
/** /**
* This method will be executed with corresponding Env Var value * This method will be executed with corresponding Env Var value
* *
* @param name
* @param value * @param value
*/ */
void process(String value); void process(String value);

View File

@ -276,7 +276,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
* Rotate a matrix 90 degrees * Rotate a matrix 90 degrees
* *
* @param toRotate the matrix to rotate * @param toRotate the matrix to rotate
* @return the rotated matrix
*/ */
@Override @Override
public void rot90(INDArray toRotate) { public void rot90(INDArray toRotate) {

View File

@ -33,7 +33,7 @@ public interface BlasWrapper {
*/ */
/** /**
* Compute x <-> y (swap two matrices) * Compute {@code x <-> y} (swap two matrices)
*/ */
INDArray swap(INDArray x, INDArray y); INDArray swap(INDArray x, INDArray y);
@ -69,14 +69,14 @@ public interface BlasWrapper {
INDArray scal(double alpha, INDArray x); INDArray scal(double alpha, INDArray x);
/** /**
* Compute x <- alpha * x (scale a matrix) * Compute {@code x <- alpha * x} (scale a matrix)
*/ */
@Deprecated @Deprecated
INDArray scal(float alpha, INDArray x); INDArray scal(float alpha, INDArray x);
/** /**
* Compute y <- x (copy a matrix) * Compute {@code y <- x} (copy a matrix)
*/ */
INDArray copy(INDArray x, INDArray y); INDArray copy(INDArray x, INDArray y);
@ -84,13 +84,13 @@ public interface BlasWrapper {
INDArray axpy(double da, INDArray dx, INDArray dy); INDArray axpy(double da, INDArray dx, INDArray dy);
/** /**
* Compute y <- alpha * x + y (elementwise addition) * Compute {@code y <- alpha * x + y }(elementwise addition)
*/ */
@Deprecated @Deprecated
INDArray axpy(float da, INDArray dx, INDArray dy); INDArray axpy(float da, INDArray dx, INDArray dy);
/** /**
* Compute y <- y + x * alpha * Compute {@code y <- y + x * alpha}
* @param da the alpha to multiply by * @param da the alpha to multiply by
* @param dx * @param dx
* @param dy * @param dy
@ -130,7 +130,7 @@ public interface BlasWrapper {
INDArray gemv(double alpha, INDArray a, INDArray x, double beta, INDArray y); INDArray gemv(double alpha, INDArray a, INDArray x, double beta, INDArray y);
/** /**
* Compute y <- alpha*op(a)*x + beta * y (general matrix vector * Compute {@code y <- alpha*op(a)*x + beta * y} (general matrix vector
* multiplication) * multiplication)
*/ */
@Deprecated @Deprecated
@ -142,7 +142,7 @@ public interface BlasWrapper {
INDArray ger(double alpha, INDArray x, INDArray y, INDArray a); INDArray ger(double alpha, INDArray x, INDArray y, INDArray a);
/** /**
* Compute A <- alpha * x * y^T + A (general rank-1 update) * Compute {@code A <- alpha * x * y^T + A} (general rank-1 update)
*/ */
INDArray ger(float alpha, INDArray x, INDArray y, INDArray a); INDArray ger(float alpha, INDArray x, INDArray y, INDArray a);
@ -193,14 +193,14 @@ public interface BlasWrapper {
/** /**
* Generalized Least Squares via *GELSD. * Generalized Least Squares via *GELSD.
* <p/> * <p>
* Note that B must be padded to contain the solution matrix. This occurs when A has fewer rows * Note that B must be padded to contain the solution matrix. This occurs when A has fewer rows
* than columns. * than columns.
* <p/> * <p>
* For example: in A * X = B, A is (m,n), X is (n,k) and B is (m,k). Now if m < n, since B is overwritten to contain * For example: in A * X = B, A is (m,n), X is (n,k) and B is (m,k). Now if m &lt; n, since B is overwritten to contain
* the solution (in classical LAPACK style), B needs to be padded to be an (n,k) matrix. * the solution (in classical LAPACK style), B needs to be padded to be an (n,k) matrix.
* <p/> * <p>
* Likewise, if m > n, the solution consists only of the first n rows of B. * Likewise, if m &gt; n, the solution consists only of the first n rows of B.
* *
* @param A an (m,n) matrix * @param A an (m,n) matrix
* @param B an (max(m,n), k) matrix (well, at least) * @param B an (max(m,n), k) matrix (well, at least)

View File

@ -193,7 +193,6 @@ public interface NDArrayFactory {
* Rotate a matrix 90 degrees * Rotate a matrix 90 degrees
* *
* @param toRotate the matrix to rotate * @param toRotate the matrix to rotate
* @return the rotated matrix
*/ */
void rot90(INDArray toRotate); void rot90(INDArray toRotate);
@ -340,7 +339,6 @@ public interface NDArrayFactory {
* *
* @param array the ndarray to shuffle * @param array the ndarray to shuffle
* @param dimension the dimension to do the shuffle * @param dimension the dimension to do the shuffle
* @return
*/ */
void shuffle(INDArray array, Random rnd, int... dimension); void shuffle(INDArray array, Random rnd, int... dimension);
@ -350,7 +348,6 @@ public interface NDArrayFactory {
* *
* @param array the ndarray to shuffle * @param array the ndarray to shuffle
* @param dimension the dimension to do the shuffle * @param dimension the dimension to do the shuffle
* @return
*/ */
void shuffle(Collection<INDArray> array, Random rnd, int... dimension); void shuffle(Collection<INDArray> array, Random rnd, int... dimension);
@ -360,7 +357,6 @@ public interface NDArrayFactory {
* *
* @param array the ndarray to shuffle * @param array the ndarray to shuffle
* @param dimensions the dimensions to do the shuffle * @param dimensions the dimensions to do the shuffle
* @return
*/ */
void shuffle(List<INDArray> array, Random rnd, List<int[]> dimensions); void shuffle(List<INDArray> array, Random rnd, List<int[]> dimensions);
@ -1370,9 +1366,9 @@ public interface NDArrayFactory {
INDArray createFromNpyFile(File file); INDArray createFromNpyFile(File file);
/** /**
* Create a Map<String, INDArray> from given npz file. * Create a {@code Map<String, INDArray>} from given npz file.
* @param file the file to create the map from * @param file the file to create the map from
* @return Map<String, INDArray> * @return {@code Map<String, INDArray>}
*/ */
Map<String, INDArray> createFromNpzFile(File file) throws Exception; Map<String, INDArray> createFromNpzFile(File file) throws Exception;
@ -1386,7 +1382,7 @@ public interface NDArrayFactory {
* *
* *
* @param array the array to convert * @param array the array to convert
* @returnthe created pointer representing * @return the created pointer representing
* a pointer to a numpy header * a pointer to a numpy header
*/ */
Pointer convertToNumpy(INDArray array); Pointer convertToNumpy(INDArray array);

View File

@ -1441,7 +1441,7 @@ public class Nd4j {
} }
/** /**
* See {@link #createBuffer(DataType dataType, long length, boolean initialize) with default datatype. * See {@link #createBuffer(DataType dataType, long length, boolean initialize)} with default datatype.
*/ */
public static DataBuffer createBuffer(long length, boolean initialize) { public static DataBuffer createBuffer(long length, boolean initialize) {
return createBuffer(Nd4j.dataType(), length, initialize); return createBuffer(Nd4j.dataType(), length, initialize);
@ -2828,7 +2828,7 @@ public class Nd4j {
} }
/** /**
* @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...)) * @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...)}
*/ */
@Deprecated @Deprecated
public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) { public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) {
@ -3306,7 +3306,7 @@ public class Nd4j {
* Generate an array with random values generated according to a binomial distribution with the specified * Generate an array with random values generated according to a binomial distribution with the specified
* number of trials and probability * number of trials and probability
* *
* @param nTrials Number of trials. Must be >= 0 * @param nTrials Number of trials. Must be &gt;= 0
* @param p Probability. Must be in range 0 to 1 * @param p Probability. Must be in range 0 to 1
* @param shape Shape of the result array * @param shape Shape of the result array
* @return Result array * @return Result array
@ -3319,7 +3319,7 @@ public class Nd4j {
* Fill the target array with random values generated according to a binomial distribution with the specified * Fill the target array with random values generated according to a binomial distribution with the specified
* number of trials and probability * number of trials and probability
* *
* @param nTrials Number of trials. Must be >= 0 * @param nTrials Number of trials. Must be &gt;= 0
* @param p Probability. Must be in range 0 to 1 * @param p Probability. Must be in range 0 to 1
* @param target Result array * @param target Result array
* @return Result array * @return Result array
@ -3333,7 +3333,7 @@ public class Nd4j {
/** /**
* Exponential distribution: P(x) = lambda * exp(-lambda * x) * Exponential distribution: P(x) = lambda * exp(-lambda * x)
* *
* @param lambda Must be > 0 * @param lambda Must be &gt; 0
* @param shape Shape of the array to generate * @param shape Shape of the array to generate
*/ */
public static INDArray randomExponential(double lambda, long... shape) { public static INDArray randomExponential(double lambda, long... shape) {
@ -3341,9 +3341,9 @@ public class Nd4j {
} }
/** /**
* Exponential distribution: P(x) = lambda * exp(-lambda * x) * Exponential distribution: {@code P(x) = lambda * exp(-lambda * x)}
* *
* @param lambda Must be > 0 * @param lambda Must be &gt; 0
* @param target Array to hold the result * @param target Array to hold the result
*/ */
public static INDArray randomExponential(double lambda, INDArray target) { public static INDArray randomExponential(double lambda, INDArray target) {
@ -3925,7 +3925,7 @@ public class Nd4j {
} }
/** /**
* See {@link @see #create(int, int, int[], char)} * See {@link Nd4j#create(int, int, int[], char)}
*/ */
public static INDArray zeros(int rows, int columns, int[] stride) { public static INDArray zeros(int rows, int columns, int[] stride) {
return create(rows, columns, stride, order()); return create(rows, columns, stride, order());
@ -4630,7 +4630,7 @@ public class Nd4j {
/** /**
* Concatenates two matrices vertically. Matrices must have identical numbers of columns.<br> * Concatenates two matrices vertically. Matrices must have identical numbers of columns.<br>
* Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3] * Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -&gt; [2,3]
* *
* @param arrs Arrays to vstack * @param arrs Arrays to vstack
*/ */
@ -4646,7 +4646,7 @@ public class Nd4j {
/** /**
* Concatenates two matrices vertically. Matrices must have identical numbers of columns.<br> * Concatenates two matrices vertically. Matrices must have identical numbers of columns.<br>
* Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3] * Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -&gt; [2,3]
* *
* @param arrs Arrays to vstack * @param arrs Arrays to vstack
*/ */
@ -5462,7 +5462,7 @@ public class Nd4j {
Examples Examples
-------- --------
>>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1) {@code >>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1)
array([[ 1, 2, 3], array([[ 1, 2, 3],
[ 4, 5, 6], [ 4, 5, 6],
[ 0, 8, 9], [ 0, 8, 9],
@ -5473,6 +5473,7 @@ public class Nd4j {
mask = tri(*m.shape[-2:], k=k-1, dtype=bool) mask = tri(*m.shape[-2:], k=k-1, dtype=bool)
return where(mask, zeros(1, m.dtype), m) return where(mask, zeros(1, m.dtype), m)
}
* @param m source array * @param m source array
* @param k to zero below the k-th diagonal * @param k to zero below the k-th diagonal
@ -5517,8 +5518,8 @@ public class Nd4j {
* @param n number of rows in the array * @param n number of rows in the array
* @param m number of columns in the array ( can be just equal to n) * @param m number of columns in the array ( can be just equal to n)
* @param k The sub-diagonal at and below which the array is filled. * @param k The sub-diagonal at and below which the array is filled.
`k` = 0 is the main diagonal, while `k` < 0 is below it, `k` = 0 is the main diagonal, while `k` &gt; 0 is below it,
and `k` > 0 is above. The default is 0. and `k` &gt; 0 is above. The default is 0.
* @return array with ones at and below the given diagonal and zeros elsewhere * @return array with ones at and below the given diagonal and zeros elsewhere
*/ */
public static INDArray tri(int n,int m,int k) { public static INDArray tri(int n,int m,int k) {

View File

@ -269,14 +269,14 @@ public abstract class Nd4jBackend {
/** /**
* Constructs a new exception with the specified cause and a detail * Constructs a new exception with the specified cause and a detail
* message of <tt>(cause==null ? null : cause.toString())</tt> (which * message of {@code (cause==null ? null : cause.toString())} (which
* typically contains the class and detail message of <tt>cause</tt>). * typically contains the class and detail message of cause).
* This constructor is useful for exceptions that are little more than * This constructor is useful for exceptions that are little more than
* wrappers for other throwables (for example, {@link * wrappers for other throwables (for example, {@link
* PrivilegedActionException}). * PrivilegedActionException}).
* *
* @param cause the cause (which is saved for later retrieval by the * @param cause the cause (which is saved for later retrieval by the
* {@link #getCause()} method). (A <tt>null</tt> value is * {@link #getCause()} method). (A null value is
* permitted, and indicates that the cause is nonexistent or * permitted, and indicates that the cause is nonexistent or
* unknown.) * unknown.)
* @since 1.4 * @since 1.4

View File

@ -173,13 +173,13 @@ public class Indices {
/** /**
* Fill in the missing indices to be the * Fill in the missing indices to be the
* same length as the original shape. * same length as the original shape.
* <p/> * <p>
* Think of this as what fills in the indices for numpy or matlab: * Think of this as what fills in the indices for numpy or matlab:
* Given a which is (4,3,2) in numpy: * Given a which is (4,3,2) in numpy:
* <p/> * <p>
* a[1:3] is filled in by the rest * a[1:3] is filled in by the rest
* to give back the full slice * to give back the full slice
* <p/> * <p>
* This algorithm fills in that delta * This algorithm fills in that delta
* *
* @param shape the original shape * @param shape the original shape
@ -244,7 +244,7 @@ public class Indices {
/** /**
* Calculate the shape for the given set of indices. * Calculate the shape for the given set of indices.
* <p/> * <p>
* The shape is defined as (for each dimension) * The shape is defined as (for each dimension)
* the difference between the end index + 1 and * the difference between the end index + 1 and
* the begin index * the begin index
@ -344,12 +344,12 @@ public class Indices {
/** /**
* Calculate the shape for the given set of indices and offsets. * Calculate the shape for the given set of indices and offsets.
* <p/> * <p>
* The shape is defined as (for each dimension) * The shape is defined as (for each dimension)
* the difference between the end index + 1 and * the difference between the end index + 1 and
* the begin index * the begin index
* <p/> * <p>
* If specified, this will check for whether any of the indices are >= to end - 1 * If specified, this will check for whether any of the indices are &gt;= to end - 1
* and if so, prune it down * and if so, prune it down
* *
* @param shape the original shape * @param shape the original shape

View File

@ -90,7 +90,6 @@ public class AdaBeliefUpdater implements GradientUpdater<AdaBelief> {
* *
* @param gradient the gradient to get the update for * @param gradient the gradient to get the update for
* @param iteration * @param iteration
* @return the gradient
*/ */
@Override @Override
public void applyUpdater(INDArray gradient, int iteration, int epoch) { public void applyUpdater(INDArray gradient, int iteration, int epoch) {

View File

@ -32,78 +32,84 @@ import java.util.Map;
@Data @Data
public class AdaDeltaUpdater implements GradientUpdater<AdaDelta> { public class AdaDeltaUpdater implements GradientUpdater<AdaDelta> {
public static final String MSG_STATE = "msg";
public static final String MSDX_STATE = "msdx";
private final AdaDelta config; public static final String MSG_STATE = "msg";
public static final String MSDX_STATE = "msdx";
private INDArray msg; //E[g^2]_t by arxiv paper, algorithm 1 private final AdaDelta config;
private INDArray msdx; //E[delta x^2]_t by arxiv paper, algorithm 1
private INDArray msg; //E[g^2]_t by arxiv paper, algorithm 1
private INDArray msdx; //E[delta x^2]_t by arxiv paper, algorithm 1
public AdaDeltaUpdater(AdaDelta config) {
this.config = config;
}
public AdaDeltaUpdater(AdaDelta config) { @Override
this.config = config; public void setState(Map<String, INDArray> stateMap, boolean initialize) {
if (!stateMap.containsKey(MSG_STATE) || !stateMap.containsKey(MSDX_STATE)
|| stateMap.size() != 2) {
throw new IllegalStateException(
"State map should contain only keys [" + MSG_STATE + "," + MSDX_STATE + "] but has keys "
+ stateMap.keySet());
} }
this.msg = stateMap.get(MSG_STATE);
this.msdx = stateMap.get(MSDX_STATE);
}
@Override @Override
public void setState(Map<String, INDArray> stateMap, boolean initialize) { public Map<String, INDArray> getState() {
if(!stateMap.containsKey(MSG_STATE) || !stateMap.containsKey(MSDX_STATE) || stateMap.size() != 2){ Map<String, INDArray> r = new HashMap<>();
throw new IllegalStateException("State map should contain only keys [" + MSG_STATE + "," + MSDX_STATE + "] but has keys " + stateMap.keySet()); r.put(MSG_STATE, msg);
} r.put(MSDX_STATE, msdx);
this.msg = stateMap.get(MSG_STATE); return r;
this.msdx = stateMap.get(MSDX_STATE); }
}
@Override @Override
public Map<String, INDArray> getState() { public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder,
Map<String,INDArray> r = new HashMap<>(); boolean initialize) {
r.put(MSG_STATE, msg); if (!viewArray.isRowVector()) {
r.put(MSDX_STATE, msdx); throw new IllegalArgumentException("Invalid input: expect row vector input");
return r; }
} if (initialize) {
viewArray.assign(0);
}
long length = viewArray.length();
this.msg = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2));
this.msdx = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length));
@Override //Reshape to match the expected shape of the input gradient arrays
public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) { this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, gradientOrder == 'f');
if (!viewArray.isRowVector()) this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, gradientOrder == 'f');
throw new IllegalArgumentException("Invalid input: expect row vector input"); if (msg == null || msdx == null) {
if (initialize) throw new IllegalStateException("Could not correctly reshape gradient view arrays");
viewArray.assign(0); }
long length = viewArray.length(); }
this.msg = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2));
this.msdx = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length));
//Reshape to match the expected shape of the input gradient arrays /**
this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, gradientOrder == 'f'); * Get the updated gradient for the given gradient and also update the state of ada delta.
this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, gradientOrder == 'f'); *
if (msg == null || msdx == null) * @param gradient the gradient to get the updated gradient for
throw new IllegalStateException("Could not correctly reshape gradient view arrays"); * @param iteration
} */
@Override
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
if (msg == null || msdx == null) {
throw new IllegalStateException("Updater has not been initialized with view state");
}
/** double rho = config.getRho();
* Get the updated gradient for the given gradient double epsilon = config.getEpsilon();
* and also update the state of ada delta.
*
* @param gradient the gradient to get the
* updated gradient for
* @param iteration
* @return the update gradient
*/
@Override
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
if (msg == null || msdx == null)
throw new IllegalStateException("Updater has not been initialized with view state");
double rho = config.getRho(); //Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf
double epsilon = config.getEpsilon(); //E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t
//Calculate update:
//dX = - g * RMS[delta x]_{t-1} / RMS[g]_t
//Note: negative is applied in the DL4J step function: params -= update rather than params += update
//Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2
//Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho,
//E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t epsilon));
//Calculate update: }
//dX = - g * RMS[delta x]_{t-1} / RMS[g]_t
//Note: negative is applied in the DL4J step function: params -= update rather than params += update
//Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho, epsilon));
}
} }

View File

@ -92,7 +92,6 @@ public class AdaMaxUpdater implements GradientUpdater<AdaMax> {
* *
* @param gradient the gradient to get the update for * @param gradient the gradient to get the update for
* @param iteration * @param iteration
* @return the gradient
*/ */
@Override @Override
public void applyUpdater(INDArray gradient, int iteration, int epoch) { public void applyUpdater(INDArray gradient, int iteration, int epoch) {

View File

@ -93,7 +93,6 @@ public class AdamUpdater implements GradientUpdater<Adam> {
* *
* @param gradient the gradient to get the update for * @param gradient the gradient to get the update for
* @param iteration * @param iteration
* @return the gradient
*/ */
@Override @Override
public void applyUpdater(INDArray gradient, int iteration, int epoch) { public void applyUpdater(INDArray gradient, int iteration, int epoch) {

View File

@ -48,7 +48,6 @@ public interface GradientUpdater<T extends IUpdater> {
* *
* @param gradient the gradient to modify * @param gradient the gradient to modify
* @param iteration * @param iteration
* @return the modified gradient
*/ */
void applyUpdater(INDArray gradient, int iteration, int epoch); void applyUpdater(INDArray gradient, int iteration, int epoch);
} }

View File

@ -92,7 +92,6 @@ public class NadamUpdater implements GradientUpdater<Nadam> {
* *
* @param gradient the gradient to get the update for * @param gradient the gradient to get the update for
* @param iteration * @param iteration
* @return the gradient
*/ */
@Override @Override
public void applyUpdater(INDArray gradient, int iteration, int epoch) { public void applyUpdater(INDArray gradient, int iteration, int epoch) {

View File

@ -77,7 +77,6 @@ public class NesterovsUpdater implements GradientUpdater<Nesterovs> {
* *
* @param gradient the gradient to get the update for * @param gradient the gradient to get the update for
* @param iteration * @param iteration
* @return
*/ */
@Override @Override
public void applyUpdater(INDArray gradient, int iteration, int epoch) { public void applyUpdater(INDArray gradient, int iteration, int epoch) {

View File

@ -152,12 +152,12 @@ public class MultiDimensionalMap<K, T, V> implements Serializable {
/** /**
* Returns the value to which the specified key is mapped, * Returns the value to which the specified key is mapped,
* or {@code null} if this map contains no mapping for the key. * or {@code null} if this map contains no mapping for the key.
* <p/> * <p>
* <p>More formally, if this map contains a mapping from a key * <p>More formally, if this map contains a mapping from a key
* {@code k} to a value {@code v} such that {@code (key==null ? k==null : * {@code k} to a value {@code v} such that {@code (key==null ? k==null :
* key.equals(k))}, then this method returns {@code v}; otherwise * key.equals(k))}, then this method returns {@code v}; otherwise
* it returns {@code null}. (There can be at most one such mapping.) * it returns {@code null}. (There can be at most one such mapping.)
* <p/> * <p>
* <p>If this map permits null values, then a return value of * <p>If this map permits null values, then a return value of
* {@code null} does not <i>necessarily</i> indicate that the map * {@code null} does not <i>necessarily</i> indicate that the map
* contains no mapping for the key; it's also possible that the map * contains no mapping for the key; it's also possible that the map
@ -214,15 +214,15 @@ public class MultiDimensionalMap<K, T, V> implements Serializable {
* from key <tt>k</tt> to value <tt>v</tt> such that * from key <tt>k</tt> to value <tt>v</tt> such that
* <code>(key==null ? k==null : key.equals(k))</code>, that mapping * <code>(key==null ? k==null : key.equals(k))</code>, that mapping
* is removed. (The map can contain at most one such mapping.) * is removed. (The map can contain at most one such mapping.)
* <p/> * <p>
* <p>Returns the value to which this map previously associated the key, * <p>Returns the value to which this map previously associated the key,
* or <tt>null</tt> if the map contained no mapping for the key. * or <tt>null</tt> if the map contained no mapping for the key.
* <p/> * <p>
* <p>If this map permits null values, then a return value of * <p>If this map permits null values, then a return value of
* <tt>null</tt> does not <i>necessarily</i> indicate that the map * <tt>null</tt> does not <i>necessarily</i> indicate that the map
* contained no mapping for the key; it's also possible that the map * contained no mapping for the key; it's also possible that the map
* explicitly mapped the key to <tt>null</tt>. * explicitly mapped the key to <tt>null</tt>.
* <p/> * <p>
* <p>The map will not contain a mapping for the specified key once the * <p>The map will not contain a mapping for the specified key once the
* call returns. * call returns.
* *

View File

@ -108,12 +108,12 @@ public class MultiDimensionalSet<K, V> implements Set<Pair<K, V>> {
* If this applyTransformToDestination makes any guarantees as to what order its elements * If this applyTransformToDestination makes any guarantees as to what order its elements
* are returned by its iterator, this method must return the * are returned by its iterator, this method must return the
* elements in the same order. * elements in the same order.
* <p/> * <p>
* <p>The returned array will be "safe" in that no references to it * <p>The returned array will be "safe" in that no references to it
* are maintained by this applyTransformToDestination. (In other words, this method must * are maintained by this applyTransformToDestination. (In other words, this method must
* allocate a new array even if this applyTransformToDestination is backed by an array). * allocate a new array even if this applyTransformToDestination is backed by an array).
* The caller is thus free to modify the returned array. * The caller is thus free to modify the returned array.
* <p/> * <p>
* <p>This method acts as bridge between array-based and collection-based * <p>This method acts as bridge between array-based and collection-based
* APIs. * APIs.
* *
@ -130,27 +130,27 @@ public class MultiDimensionalSet<K, V> implements Set<Pair<K, V>> {
* If the applyTransformToDestination fits in the specified array, it is returned therein. * If the applyTransformToDestination fits in the specified array, it is returned therein.
* Otherwise, a new array is allocated with the runtime type of the * Otherwise, a new array is allocated with the runtime type of the
* specified array and the size of this applyTransformToDestination. * specified array and the size of this applyTransformToDestination.
* <p/> * <p>
* <p>If this applyTransformToDestination fits in the specified array with room to spare * <p>If this applyTransformToDestination fits in the specified array with room to spare
* (i.e., the array has more elements than this applyTransformToDestination), the element in * (i.e., the array has more elements than this applyTransformToDestination), the element in
* the array immediately following the end of the applyTransformToDestination is applyTransformToDestination to * the array immediately following the end of the applyTransformToDestination is applyTransformToDestination to
* <tt>null</tt>. (This is useful in determining the length of this * <tt>null</tt>. (This is useful in determining the length of this
* applyTransformToDestination <i>only</i> if the caller knows that this applyTransformToDestination does not contain * applyTransformToDestination <i>only</i> if the caller knows that this applyTransformToDestination does not contain
* any null elements.) * any null elements.)
* <p/> * <p>
* <p>If this applyTransformToDestination makes any guarantees as to what order its elements * <p>If this applyTransformToDestination makes any guarantees as to what order its elements
* are returned by its iterator, this method must return the elements * are returned by its iterator, this method must return the elements
* in the same order. * in the same order.
* <p/> * <p>
* <p>Like the {@link #toArray()} method, this method acts as bridge between * <p>Like the {@link #toArray()} method, this method acts as bridge between
* array-based and collection-based APIs. Further, this method allows * array-based and collection-based APIs. Further, this method allows
* precise control over the runtime type of the output array, and may, * precise control over the runtime type of the output array, and may,
* under certain circumstances, be used to save allocation costs. * under certain circumstances, be used to save allocation costs.
* <p/> * <p>
* <p>Suppose <tt>x</tt> is a applyTransformToDestination known to contain only strings. * <p>Suppose <tt>x</tt> is a applyTransformToDestination known to contain only strings.
* The following code can be used to dump the applyTransformToDestination into a newly allocated * The following code can be used to dump the applyTransformToDestination into a newly allocated
* array of <tt>String</tt>: * array of <tt>String</tt>:
* <p/> * <p>
* <pre> * <pre>
* String[] y = x.toArray(new String[0]);</pre> * String[] y = x.toArray(new String[0]);</pre>
* *
@ -181,7 +181,7 @@ public class MultiDimensionalSet<K, V> implements Set<Pair<K, V>> {
* unchanged and returns <tt>false</tt>. In combination with the * unchanged and returns <tt>false</tt>. In combination with the
* restriction on constructors, this ensures that sets never contain * restriction on constructors, this ensures that sets never contain
* duplicate elements. * duplicate elements.
* <p/> * <p>
* <p>The stipulation above does not imply that sets must accept all * <p>The stipulation above does not imply that sets must accept all
* elements; sets may refuse to add any particular element, including * elements; sets may refuse to add any particular element, including
* <tt>null</tt>, and throw an exception, as described in the * <tt>null</tt>, and throw an exception, as described in the

View File

@ -204,9 +204,9 @@ public class ArrayUtil {
/** /**
* Credit to mikio braun from jblas * Credit to mikio braun from jblas
* <p/> * <p>
* Create a random permutation of the numbers 0, ..., size - 1. * Create a random permutation of the numbers 0, ..., size - 1.
* <p/> * <p>
* see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145 * see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145
*/ */
public static int[] randomPermutation(int size) { public static int[] randomPermutation(int size) {

View File

@ -64,7 +64,7 @@ public class HelperUtils {
if("CUDA".equalsIgnoreCase(backend) && cudnnHelperClassName != null && !cudnnHelperClassName.isEmpty()) { if("CUDA".equalsIgnoreCase(backend) && cudnnHelperClassName != null && !cudnnHelperClassName.isEmpty()) {
if(DL4JClassLoading.loadClassByName(cudnnHelperClassName) != null) { if(DL4JClassLoading.loadClassByName(cudnnHelperClassName) != null) {
log.debug("Attempting to initialize cudnn helper {}",cudnnHelperClassName); log.debug("Attempting to initialize cudnn helper {}",cudnnHelperClassName);
helperRet = (LayerHelper) DL4JClassLoading.<LayerHelper>createNewInstance( helperRet = DL4JClassLoading.<LayerHelper>createNewInstance(
cudnnHelperClassName, cudnnHelperClassName,
(Class<? super LayerHelper>) layerHelperSuperClass, (Class<? super LayerHelper>) layerHelperSuperClass,
new Object[]{arguments}); new Object[]{arguments});
@ -76,7 +76,7 @@ public class HelperUtils {
ClassLoader classLoader = DL4JClassLoading.getDl4jClassloader(); ClassLoader classLoader = DL4JClassLoading.getDl4jClassloader();
DL4JClassLoading.setDl4jClassloaderFromClass(layerHelperSuperClass); DL4JClassLoading.setDl4jClassloaderFromClass(layerHelperSuperClass);
try { try {
helperRet = (LayerHelper) DL4JClassLoading.<LayerHelper>createNewInstance( helperRet = DL4JClassLoading.<LayerHelper>createNewInstance(
cudnnHelperClassName, cudnnHelperClassName,
(Class<? super LayerHelper>) layerHelperSuperClass, (Class<? super LayerHelper>) layerHelperSuperClass,
arguments); arguments);
@ -99,7 +99,7 @@ public class HelperUtils {
} }
} else if("CPU".equalsIgnoreCase(backend) && oneDnnClassName != null && !oneDnnClassName.isEmpty()) { } else if("CPU".equalsIgnoreCase(backend) && oneDnnClassName != null && !oneDnnClassName.isEmpty()) {
helperRet = DL4JClassLoading.<LayerHelper>createNewInstance( helperRet = DL4JClassLoading.createNewInstance(
oneDnnClassName, oneDnnClassName,
arguments); arguments);
log.trace("Created oneDNN helper: {}, layer {}", oneDnnClassName,layerName); log.trace("Created oneDNN helper: {}, layer {}", oneDnnClassName,layerName);

View File

@ -3,11 +3,14 @@ plugins {
id 'maven-publish' id 'maven-publish'
} }
/*
configurations.archives.artifacts.with { archives -> configurations.archives.artifacts.with { archives ->
archives.each { archives.each {
println(it.name) println(it.name)
} }
} }
*/
dependencies { dependencies {
//Todo clean this //Todo clean this
@ -19,7 +22,7 @@ dependencies {
//TODO for the two below.. either platform specific uber jars or a single big one with all platforms //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64"
//api group: "org.bytedeco", name: "javacpp", version: "1.5.7" //api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" // api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
//api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT' //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT'
rootProject.getAllprojects().each { Project sproj -> rootProject.getAllprojects().each { Project sproj ->
if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")

View File

@ -212,7 +212,7 @@ tasks.withType(org.bytedeco.gradle.javacpp.BuildTask) {
// Disable the standard javacpp generated tasks and use own // Disable the standard javacpp generated tasks and use own
// versions below. This allows to build for each variant // versions below. This allows to build for each variant
[javacppBuildParser, javacppBuildCommand, javacppCompileJava, javacppBuildCompiler].each { [javacppBuildParser, javacppBuildCommand, javacppCompileJava, javacppBuildCompiler].each {
it.enabled false; it.enabled false
} }
chipList.each { thisChip -> chipList.each { thisChip ->

View File

@ -152,12 +152,12 @@ public class MultiDimensionalMap<K, T, V> implements Serializable {
/** /**
* Returns the value to which the specified key is mapped, * Returns the value to which the specified key is mapped,
* or {@code null} if this map contains no mapping for the key. * or {@code null} if this map contains no mapping for the key.
* <p/> * <p>
* <p>More formally, if this map contains a mapping from a key * <p>More formally, if this map contains a mapping from a key
* {@code k} to a value {@code v} such that {@code (key==null ? k==null : * {@code k} to a value {@code v} such that {@code (key==null ? k==null :
* key.equals(k))}, then this method returns {@code v}; otherwise * key.equals(k))}, then this method returns {@code v}; otherwise
* it returns {@code null}. (There can be at most one such mapping.) * it returns {@code null}. (There can be at most one such mapping.)
* <p/> * <p>
* <p>If this map permits null values, then a return value of * <p>If this map permits null values, then a return value of
* {@code null} does not <i>necessarily</i> indicate that the map * {@code null} does not <i>necessarily</i> indicate that the map
* contains no mapping for the key; it's also possible that the map * contains no mapping for the key; it's also possible that the map
@ -214,15 +214,15 @@ public class MultiDimensionalMap<K, T, V> implements Serializable {
* from key <tt>k</tt> to value <tt>v</tt> such that * from key <tt>k</tt> to value <tt>v</tt> such that
* <code>(key==null ? k==null : key.equals(k))</code>, that mapping * <code>(key==null ? k==null : key.equals(k))</code>, that mapping
* is removed. (The map can contain at most one such mapping.) * is removed. (The map can contain at most one such mapping.)
* <p/> * <p>
* <p>Returns the value to which this map previously associated the key, * <p>Returns the value to which this map previously associated the key,
* or <tt>null</tt> if the map contained no mapping for the key. * or <tt>null</tt> if the map contained no mapping for the key.
* <p/> * <p>
* <p>If this map permits null values, then a return value of * <p>If this map permits null values, then a return value of
* <tt>null</tt> does not <i>necessarily</i> indicate that the map * <tt>null</tt> does not <i>necessarily</i> indicate that the map
* contained no mapping for the key; it's also possible that the map * contained no mapping for the key; it's also possible that the map
* explicitly mapped the key to <tt>null</tt>. * explicitly mapped the key to <tt>null</tt>.
* <p/> * <p>
* <p>The map will not contain a mapping for the specified key once the * <p>The map will not contain a mapping for the specified key once the
* call returns. * call returns.
* *

View File

@ -108,12 +108,12 @@ public class MultiDimensionalSet<K, V> implements Set<Pair<K, V>> {
* If this applyTransformToDestination makes any guarantees as to what order its elements * If this applyTransformToDestination makes any guarantees as to what order its elements
* are returned by its iterator, this method must return the * are returned by its iterator, this method must return the
* elements in the same order. * elements in the same order.
* <p/> * <p>
* <p>The returned array will be "safe" in that no references to it * <p>The returned array will be "safe" in that no references to it
* are maintained by this applyTransformToDestination. (In other words, this method must * are maintained by this applyTransformToDestination. (In other words, this method must
* allocate a new array even if this applyTransformToDestination is backed by an array). * allocate a new array even if this applyTransformToDestination is backed by an array).
* The caller is thus free to modify the returned array. * The caller is thus free to modify the returned array.
* <p/> * <p>
* <p>This method acts as bridge between array-based and collection-based * <p>This method acts as bridge between array-based and collection-based
* APIs. * APIs.
* *
@ -130,27 +130,27 @@ public class MultiDimensionalSet<K, V> implements Set<Pair<K, V>> {
* If the applyTransformToDestination fits in the specified array, it is returned therein. * If the applyTransformToDestination fits in the specified array, it is returned therein.
* Otherwise, a new array is allocated with the runtime type of the * Otherwise, a new array is allocated with the runtime type of the
* specified array and the size of this applyTransformToDestination. * specified array and the size of this applyTransformToDestination.
* <p/> * <p>
* <p>If this applyTransformToDestination fits in the specified array with room to spare * <p>If this applyTransformToDestination fits in the specified array with room to spare
* (i.e., the array has more elements than this applyTransformToDestination), the element in * (i.e., the array has more elements than this applyTransformToDestination), the element in
* the array immediately following the end of the applyTransformToDestination is applyTransformToDestination to * the array immediately following the end of the applyTransformToDestination is applyTransformToDestination to
* <tt>null</tt>. (This is useful in determining the length of this * <tt>null</tt>. (This is useful in determining the length of this
* applyTransformToDestination <i>only</i> if the caller knows that this applyTransformToDestination does not contain * applyTransformToDestination <i>only</i> if the caller knows that this applyTransformToDestination does not contain
* any null elements.) * any null elements.)
* <p/> * <p>
* <p>If this applyTransformToDestination makes any guarantees as to what order its elements * <p>If this applyTransformToDestination makes any guarantees as to what order its elements
* are returned by its iterator, this method must return the elements * are returned by its iterator, this method must return the elements
* in the same order. * in the same order.
* <p/> * <p>
* <p>Like the {@link #toArray()} method, this method acts as bridge between * <p>Like the {@link #toArray()} method, this method acts as bridge between
* array-based and collection-based APIs. Further, this method allows * array-based and collection-based APIs. Further, this method allows
* precise control over the runtime type of the output array, and may, * precise control over the runtime type of the output array, and may,
* under certain circumstances, be used to save allocation costs. * under certain circumstances, be used to save allocation costs.
* <p/> * <p>
* <p>Suppose <tt>x</tt> is a applyTransformToDestination known to contain only strings. * <p>Suppose <tt>x</tt> is a applyTransformToDestination known to contain only strings.
* The following code can be used to dump the applyTransformToDestination into a newly allocated * The following code can be used to dump the applyTransformToDestination into a newly allocated
* array of <tt>String</tt>: * array of <tt>String</tt>:
* <p/> * <p>
* <pre> * <pre>
* String[] y = x.toArray(new String[0]);</pre> * String[] y = x.toArray(new String[0]);</pre>
* *
@ -181,7 +181,7 @@ public class MultiDimensionalSet<K, V> implements Set<Pair<K, V>> {
* unchanged and returns <tt>false</tt>. In combination with the * unchanged and returns <tt>false</tt>. In combination with the
* restriction on constructors, this ensures that sets never contain * restriction on constructors, this ensures that sets never contain
* duplicate elements. * duplicate elements.
* <p/> * <p>
* <p>The stipulation above does not imply that sets must accept all * <p>The stipulation above does not imply that sets must accept all
* elements; sets may refuse to add any particular element, including * elements; sets may refuse to add any particular element, including
* <tt>null</tt>, and throw an exception, as described in the * <tt>null</tt>, and throw an exception, as described in the

View File

@ -204,9 +204,9 @@ public class ArrayUtil {
/** /**
* Credit to mikio braun from jblas * Credit to mikio braun from jblas
* <p/> * <p>
* Create a random permutation of the numbers 0, ..., size - 1. * Create a random permutation of the numbers 0, ..., size - 1.
* <p/> * <p>
* see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145 * see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145
*/ */
public static int[] randomPermutation(int size) { public static int[] randomPermutation(int size) {

View File

@ -148,7 +148,6 @@ include ':cavis-ui:cavis-ui-standalone'
include ':cavis-ui:cavis-ui-vertx' include ':cavis-ui:cavis-ui-vertx'
include ':cavis-zoo' include ':cavis-zoo'
include ':cavis-zoo:cavis-zoo-models' include ':cavis-zoo:cavis-zoo-models'
include ':brutex-extended-tests' include ':brutex-extended-tests'
include ':cavis-full' include ':cavis-full'