Fix javadoc and cleanup
parent
2b4d44ea67
commit
f92b786836
19
build.gradle
19
build.gradle
|
@ -55,6 +55,7 @@ configurations.all {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
allprojects { Project proj ->
|
allprojects { Project proj ->
|
||||||
apply plugin: 'com.google.osdetector'
|
apply plugin: 'com.google.osdetector'
|
||||||
|
|
||||||
|
@ -161,3 +162,21 @@ allprojects { Project proj ->
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
task aggregatedJavadocs(type: Javadoc, description: 'Generate javadocs from all child projects as if it was a single project', group: 'Documentation') {
|
||||||
|
subprojects.each { proj ->
|
||||||
|
proj.tasks.withType(Javadoc).each { javadocTask ->
|
||||||
|
logger.quiet("Adding javadoc for project " + proj.name)
|
||||||
|
source += javadocTask.source
|
||||||
|
classpath += javadocTask.classpath
|
||||||
|
excludes += javadocTask.excludes
|
||||||
|
includes += javadocTask.includes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
destinationDir = file("$buildDir/docs/javadoc")
|
||||||
|
title = "$project.name $version API"
|
||||||
|
options.author true
|
||||||
|
options.links 'http://docs.oracle.com/javase/8/docs/api/'
|
||||||
|
options.addStringOption('Xdoclint:none', '-quiet')
|
||||||
|
}
|
|
@ -130,3 +130,19 @@ echo "nameserver 8.8.8.8" | sudo tee -a /etc/resolv.conf
|
||||||
|
|
||||||
-P\<xxx>\
|
-P\<xxx>\
|
||||||
CAVIS_AVX_EXTENSION = {avx2 | avx512}, default is avx2
|
CAVIS_AVX_EXTENSION = {avx2 | avx512}, default is avx2
|
||||||
|
|
||||||
|
# Zeppelin Spark dependencies #
|
||||||
|
3
|
||||||
|
|
||||||
|
|
||||||
|
To add the dependency to the language models, use the following format in the Dependencies section of the of the Spark Interpreter configuration (Interpreters -> Spark -> Edit -> Dependencies):
|
||||||
|
|
||||||
|
groupId:artifactId:packaging:classifier:version
|
||||||
|
|
||||||
|
In your case it should work with
|
||||||
|
|
||||||
|
edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0
|
||||||
|
|
||||||
|
|
||||||
|
Native cpu code under linux needs libc6-dev
|
||||||
|
/lib/x86_64-linux-gnu/libm.so.6: version `GLIBC_2.29' not found
|
|
@ -90,7 +90,7 @@ public class Operands {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns array identified its numeric id
|
* This method returns array identified its numeric id
|
||||||
* @param name
|
* @param id
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public INDArray getById(int id) {
|
public INDArray getById(int id) {
|
||||||
|
@ -99,7 +99,8 @@ public class Operands {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns array identified its numeric id and index
|
* This method returns array identified its numeric id and index
|
||||||
* @param name
|
* @param id
|
||||||
|
* @param index
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public INDArray getById(int id, int index) {
|
public INDArray getById(int id, int index) {
|
||||||
|
@ -121,7 +122,7 @@ public class Operands {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns contents of this entity as collection of key->value pairs
|
* This method returns contents of this entity as collection of key->value pairs
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Collection<Pair<NodeDescriptor, INDArray>> asCollection() {
|
public Collection<Pair<NodeDescriptor, INDArray>> asCollection() {
|
||||||
|
|
|
@ -50,7 +50,7 @@ public class ExecDebuggingListener extends BaseListener {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param printMode Print mode, see {@link PrintMode}
|
* @param printMode Print mode, see {@link PrintMode}
|
||||||
* @param maxIterations Maximum number of iterations to print. <= 0 for "all iterations"
|
* @param maxIterations Maximum number of iterations to print. <= 0 for "all iterations"
|
||||||
* @param logIter If true: prefix iteration/epoch, such as "(iter=1,epoch=0,op=3)" to the output
|
* @param logIter If true: prefix iteration/epoch, such as "(iter=1,epoch=0,op=3)" to the output
|
||||||
*/
|
*/
|
||||||
public ExecDebuggingListener(PrintMode printMode, int maxIterations, boolean logIter){
|
public ExecDebuggingListener(PrintMode printMode, int maxIterations, boolean logIter){
|
||||||
|
|
|
@ -573,7 +573,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the function by the {@link DifferentialFunction#getOwnName()}
|
* Get the function by the {@link org.nd4j.autodiff.functions.DifferentialFunction#getOwnName()}
|
||||||
*
|
*
|
||||||
* @param id the id of the function
|
* @param id the id of the function
|
||||||
* @return the function for the given id if it exists
|
* @return the function for the given id if it exists
|
||||||
|
@ -1348,9 +1348,9 @@ public class SameDiff extends SDBaseOps {
|
||||||
/**
|
/**
|
||||||
* Get the names of variables (if any) that have been marked as loss variables to be minimized.<br>
|
* Get the names of variables (if any) that have been marked as loss variables to be minimized.<br>
|
||||||
* Variables can be marked as loss variables in a few different ways:<br>
|
* Variables can be marked as loss variables in a few different ways:<br>
|
||||||
* (a) Losses are automatically added when creating loss functions via {@link #sd()}<br>
|
* (a) Losses are automatically added when creating loss functions via {@link SameDiff#sd}<br>
|
||||||
* (b) Via {@link #setLossVariables(String...)}, @link #addLossVariable(String)} or {@link SDVariable#markAsLoss()}<br>
|
* (b) Via {@link #setLossVariables(String...)}, @link #addLossVariable(String)} or {@link SDVariable#markAsLoss()}<br>
|
||||||
* (c) Via {@link TrainingConfig#setLossVariables(List)}<br>
|
* (c) Via {@link org.nd4j.autodiff.samediff.TrainingConfig#setLossVariables(List)}<br>
|
||||||
*/
|
*/
|
||||||
public List<String> getLossVariables() {
|
public List<String> getLossVariables() {
|
||||||
return Collections.unmodifiableList(this.lossVariables);
|
return Collections.unmodifiableList(this.lossVariables);
|
||||||
|
|
|
@ -54,12 +54,12 @@ public abstract class AbstractDependencyTracker<T, D> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return A new map where the dependents (i.e., Y in "X -> Y") are the key
|
* @return A new map where the dependents (i.e., Y in "X -> Y") are the key
|
||||||
*/
|
*/
|
||||||
protected abstract Map<T, ?> newTMap();
|
protected abstract Map<T, ?> newTMap();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return A new set where the dependents (i.e., Y in "X -> Y") are the key
|
* @return A new set where the dependents (i.e., Y in "X -> Y") are the key
|
||||||
*/
|
*/
|
||||||
protected abstract Set<T> newTSet();
|
protected abstract Set<T> newTSet();
|
||||||
|
|
||||||
|
@ -103,7 +103,7 @@ public abstract class AbstractDependencyTracker<T, D> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mark the specified value as satisfied.
|
* Mark the specified value as satisfied.
|
||||||
* For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true)
|
* For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true)
|
||||||
* call, both of these dependencies are considered satisfied.
|
* call, both of these dependencies are considered satisfied.
|
||||||
*
|
*
|
||||||
* @param x Value to mark
|
* @param x Value to mark
|
||||||
|
@ -191,7 +191,7 @@ public abstract class AbstractDependencyTracker<T, D> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)}
|
* Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)}
|
||||||
* or {@link #addOrDependency(Object, Object, Object)}
|
* or {@link #addOrDependency(Object, Object, Object)}
|
||||||
*
|
*
|
||||||
* @param y Dependent to check
|
* @param y Dependent to check
|
||||||
|
@ -207,7 +207,7 @@ public abstract class AbstractDependencyTracker<T, D> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get all dependencies x, for x -> y, and (x1 or x2) -> y
|
* Get all dependencies x, for x -> y, and (x1 or x2) -> y
|
||||||
*
|
*
|
||||||
* @param y Dependent to get dependencies for
|
* @param y Dependent to get dependencies for
|
||||||
* @return List of dependencies
|
* @return List of dependencies
|
||||||
|
@ -223,7 +223,7 @@ public abstract class AbstractDependencyTracker<T, D> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a dependency: y depends on x, as in x -> y
|
* Add a dependency: y depends on x, as in x -> y
|
||||||
*
|
*
|
||||||
* @param y The dependent
|
* @param y The dependent
|
||||||
* @param x The dependee that is required for Y
|
* @param x The dependee that is required for Y
|
||||||
|
@ -302,7 +302,7 @@ public abstract class AbstractDependencyTracker<T, D> {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Remove a dependency (x -> y)
|
* Remove a dependency (x -> y)
|
||||||
*
|
*
|
||||||
* @param y The dependent that currently requires X
|
* @param y The dependent that currently requires X
|
||||||
* @param x The dependee that is no longer required for Y
|
* @param x The dependee that is no longer required for Y
|
||||||
|
@ -357,7 +357,7 @@ public abstract class AbstractDependencyTracker<T, D> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y<br>
|
* Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y<br>
|
||||||
* If either x1 or x2 (or both) are marked satisfied via {@link #markSatisfied(Object, boolean)} then the
|
* If either x1 or x2 (or both) are marked satisfied via {@link #markSatisfied(Object, boolean)} then the
|
||||||
* dependency is considered satisfied
|
* dependency is considered satisfied
|
||||||
*
|
*
|
||||||
|
@ -382,16 +382,16 @@ public abstract class AbstractDependencyTracker<T, D> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y)
|
* @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y)
|
||||||
*/
|
*/
|
||||||
public boolean hasNewAllSatisfied() {
|
public boolean hasNewAllSatisfied() {
|
||||||
return !allSatisfiedQueue.isEmpty();
|
return !allSatisfiedQueue.isEmpty();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)}
|
* Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)}
|
||||||
* Throws an exception if {@link #hasNewAllSatisfied()} returns false.<br>
|
* Throws an exception if {@link #hasNewAllSatisfied()} returns false.<br>
|
||||||
* Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value;
|
* Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value;
|
||||||
* the value is considered "processed" at this point.
|
* the value is considered "processed" at this point.
|
||||||
*
|
*
|
||||||
* @return The next new "all satisfied dependent"
|
* @return The next new "all satisfied dependent"
|
||||||
|
|
|
@ -487,7 +487,7 @@ public abstract class AbstractSession<T, O> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add the control dependency from Op -> variable
|
* Add the control dependency from Op -> variable
|
||||||
*
|
*
|
||||||
* @param es Execution step for the variable
|
* @param es Execution step for the variable
|
||||||
* @param v Variable
|
* @param v Variable
|
||||||
|
@ -542,7 +542,7 @@ public abstract class AbstractSession<T, O> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update the descendant dependencies
|
* Update the descendant dependencies
|
||||||
* So if the graph structure is X -> A, then add all (X,Y,Z,...) -> A to the dependency tracker
|
* So if the graph structure is X -> A, then add all (X,Y,Z,...) -> A to the dependency tracker
|
||||||
* This is for a specific frame and iteration, for both sides of the dependency (in and out)
|
* This is for a specific frame and iteration, for both sides of the dependency (in and out)
|
||||||
*
|
*
|
||||||
* @param justExecuted The execution step that has just completed
|
* @param justExecuted The execution step that has just completed
|
||||||
|
@ -621,7 +621,7 @@ public abstract class AbstractSession<T, O> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Suppose operation X has just been executed.
|
* Suppose operation X has just been executed.
|
||||||
* For X -> someOp, add all dependencies for someOp, i.e., all Z -> someOp
|
* For X -> someOp, add all dependencies for someOp, i.e., all Z -> someOp
|
||||||
* (which includes X, but may not only be X)
|
* (which includes X, but may not only be X)
|
||||||
*
|
*
|
||||||
* @param opName Name of the op
|
* @param opName Name of the op
|
||||||
|
|
|
@ -28,15 +28,15 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
|
||||||
public class SDBitwise extends SDOps {
|
public class SDBitwise extends SDOps {
|
||||||
|
|
||||||
public SDBitwise(SameDiff sameDiff) {
|
public SDBitwise(SameDiff sameDiff) {
|
||||||
super(sameDiff);
|
super(sameDiff);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise AND operation. Supports broadcasting.<br>
|
* Bitwise AND operation. Supports broadcasting.<br>
|
||||||
*
|
* <p>
|
||||||
* Inputs must satisfy the following constraints: <br>
|
* Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
|
||||||
* Must be same types: isSameType(x, y)<br>
|
|
||||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||||
*
|
*
|
||||||
* @param x First input array (INT type)
|
* @param x First input array (INT type)
|
||||||
|
@ -52,9 +52,8 @@ public class SDBitwise extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise AND operation. Supports broadcasting.<br>
|
* Bitwise AND operation. Supports broadcasting.<br>
|
||||||
*
|
* <p>
|
||||||
* Inputs must satisfy the following constraints: <br>
|
* Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
|
||||||
* Must be same types: isSameType(x, y)<br>
|
|
||||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
|
@ -66,12 +65,13 @@ public class SDBitwise extends SDOps {
|
||||||
SDValidation.validateInteger("and", "x", x);
|
SDValidation.validateInteger("and", "x", x);
|
||||||
SDValidation.validateInteger("and", "y", y);
|
SDValidation.validateInteger("and", "y", y);
|
||||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd, x,
|
||||||
|
y).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)<br>
|
* Roll integer bits to the left, i.e. {@code var << 4 | var >> (32 - 4)}<br>
|
||||||
*
|
*
|
||||||
* @param x Input 1 (INT type)
|
* @param x Input 1 (INT type)
|
||||||
* @param shift Number of bits to shift. (INT type)
|
* @param shift Number of bits to shift. (INT type)
|
||||||
|
@ -80,11 +80,12 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable bitRotl(SDVariable x, SDVariable shift) {
|
public SDVariable bitRotl(SDVariable x, SDVariable shift) {
|
||||||
SDValidation.validateInteger("bitRotl", "x", x);
|
SDValidation.validateInteger("bitRotl", "x", x);
|
||||||
SDValidation.validateInteger("bitRotl", "shift", shift);
|
SDValidation.validateInteger("bitRotl", "shift", shift);
|
||||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x,
|
||||||
|
shift).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)<br>
|
* Roll integer bits to the left, i.e. {@code var << 4 | var >> (32 - 4)}<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param x Input 1 (INT type)
|
* @param x Input 1 (INT type)
|
||||||
|
@ -94,12 +95,13 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) {
|
public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) {
|
||||||
SDValidation.validateInteger("bitRotl", "x", x);
|
SDValidation.validateInteger("bitRotl", "x", x);
|
||||||
SDValidation.validateInteger("bitRotl", "shift", shift);
|
SDValidation.validateInteger("bitRotl", "shift", shift);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x,
|
||||||
|
shift).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)<br>
|
* Roll integer bits to the right, i.e. {@code var >> 4 | var << (32 - 4)}<br>
|
||||||
*
|
*
|
||||||
* @param x Input 1 (INT type)
|
* @param x Input 1 (INT type)
|
||||||
* @param shift Number of bits to shift. (INT type)
|
* @param shift Number of bits to shift. (INT type)
|
||||||
|
@ -108,11 +110,12 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable bitRotr(SDVariable x, SDVariable shift) {
|
public SDVariable bitRotr(SDVariable x, SDVariable shift) {
|
||||||
SDValidation.validateInteger("bitRotr", "x", x);
|
SDValidation.validateInteger("bitRotr", "x", x);
|
||||||
SDValidation.validateInteger("bitRotr", "shift", shift);
|
SDValidation.validateInteger("bitRotr", "shift", shift);
|
||||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x,
|
||||||
|
shift).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)<br>
|
* Roll integer bits to the right, i.e. {@code var >> 4 | var << (32 - 4)}<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param x Input 1 (INT type)
|
* @param x Input 1 (INT type)
|
||||||
|
@ -122,12 +125,13 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) {
|
public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) {
|
||||||
SDValidation.validateInteger("bitRotr", "x", x);
|
SDValidation.validateInteger("bitRotr", "x", x);
|
||||||
SDValidation.validateInteger("bitRotr", "shift", shift);
|
SDValidation.validateInteger("bitRotr", "shift", shift);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x,
|
||||||
|
shift).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Shift integer bits to the left, i.e. var << 4<br>
|
* Shift integer bits to the left, i.e. {@code var << 4}<br>
|
||||||
*
|
*
|
||||||
* @param x Input 1 (INT type)
|
* @param x Input 1 (INT type)
|
||||||
* @param shift Number of bits to shift. (INT type)
|
* @param shift Number of bits to shift. (INT type)
|
||||||
|
@ -136,11 +140,12 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable bitShift(SDVariable x, SDVariable shift) {
|
public SDVariable bitShift(SDVariable x, SDVariable shift) {
|
||||||
SDValidation.validateInteger("bitShift", "x", x);
|
SDValidation.validateInteger("bitShift", "x", x);
|
||||||
SDValidation.validateInteger("bitShift", "shift", shift);
|
SDValidation.validateInteger("bitShift", "shift", shift);
|
||||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x,
|
||||||
|
shift).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Shift integer bits to the left, i.e. var << 4<br>
|
* Shift integer bits to the left, i.e. {@code var << 4}<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param x Input 1 (INT type)
|
* @param x Input 1 (INT type)
|
||||||
|
@ -150,12 +155,13 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable bitShift(String name, SDVariable x, SDVariable shift) {
|
public SDVariable bitShift(String name, SDVariable x, SDVariable shift) {
|
||||||
SDValidation.validateInteger("bitShift", "x", x);
|
SDValidation.validateInteger("bitShift", "x", x);
|
||||||
SDValidation.validateInteger("bitShift", "shift", shift);
|
SDValidation.validateInteger("bitShift", "shift", shift);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x,
|
||||||
|
shift).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Shift integer bits to the right, i.e. var >> 4<br>
|
* Shift integer bits to the right, i.e. {@code var >> 4}<br>
|
||||||
*
|
*
|
||||||
* @param x Input 1 (INT type)
|
* @param x Input 1 (INT type)
|
||||||
* @param shift Number of bits to shift. (INT type)
|
* @param shift Number of bits to shift. (INT type)
|
||||||
|
@ -164,11 +170,12 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable bitShiftRight(SDVariable x, SDVariable shift) {
|
public SDVariable bitShiftRight(SDVariable x, SDVariable shift) {
|
||||||
SDValidation.validateInteger("bitShiftRight", "x", x);
|
SDValidation.validateInteger("bitShiftRight", "x", x);
|
||||||
SDValidation.validateInteger("bitShiftRight", "shift", shift);
|
SDValidation.validateInteger("bitShiftRight", "shift", shift);
|
||||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x,
|
||||||
|
shift).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Shift integer bits to the right, i.e. var >> 4<br>
|
* Shift integer bits to the right, i.e. {@code var >> 4}<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param x Input 1 (INT type)
|
* @param x Input 1 (INT type)
|
||||||
|
@ -178,16 +185,17 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) {
|
public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) {
|
||||||
SDValidation.validateInteger("bitShiftRight", "x", x);
|
SDValidation.validateInteger("bitShiftRight", "x", x);
|
||||||
SDValidation.validateInteger("bitShiftRight", "shift", shift);
|
SDValidation.validateInteger("bitShiftRight", "shift", shift);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x,
|
||||||
|
shift).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
|
* Bitwise Hamming distance reduction over all elements of both input arrays.<br> For example, if
|
||||||
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)<br>
|
* x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at
|
||||||
*
|
* positions 0 and 1)<br>
|
||||||
* Inputs must satisfy the following constraints: <br>
|
* <p>
|
||||||
* Must be same types: isSameType(x, y)<br>
|
* Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
|
||||||
*
|
*
|
||||||
* @param x First input array. (INT type)
|
* @param x First input array. (INT type)
|
||||||
* @param y Second input array. (INT type)
|
* @param y Second input array. (INT type)
|
||||||
|
@ -197,15 +205,16 @@ public class SDBitwise extends SDOps {
|
||||||
SDValidation.validateInteger("bitsHammingDistance", "x", x);
|
SDValidation.validateInteger("bitsHammingDistance", "x", x);
|
||||||
SDValidation.validateInteger("bitsHammingDistance", "y", y);
|
SDValidation.validateInteger("bitsHammingDistance", "y", y);
|
||||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd, x,
|
||||||
|
y).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
|
* Bitwise Hamming distance reduction over all elements of both input arrays.<br> For example, if
|
||||||
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)<br>
|
* x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at
|
||||||
*
|
* positions 0 and 1)<br>
|
||||||
* Inputs must satisfy the following constraints: <br>
|
* <p>
|
||||||
* Must be same types: isSameType(x, y)<br>
|
* Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param x First input array. (INT type)
|
* @param x First input array. (INT type)
|
||||||
|
@ -216,7 +225,8 @@ public class SDBitwise extends SDOps {
|
||||||
SDValidation.validateInteger("bitsHammingDistance", "x", x);
|
SDValidation.validateInteger("bitsHammingDistance", "x", x);
|
||||||
SDValidation.validateInteger("bitsHammingDistance", "y", y);
|
SDValidation.validateInteger("bitsHammingDistance", "y", y);
|
||||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd, x,
|
||||||
|
y).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -244,13 +254,14 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable leftShift(String name, SDVariable x, SDVariable y) {
|
public SDVariable leftShift(String name, SDVariable x, SDVariable y) {
|
||||||
SDValidation.validateInteger("leftShift", "x", x);
|
SDValidation.validateInteger("leftShift", "x", x);
|
||||||
SDValidation.validateInteger("leftShift", "y", y);
|
SDValidation.validateInteger("leftShift", "y", y);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x,
|
||||||
|
y).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
* Bitwise left cyclical shift operation. Supports broadcasting.<br> Unlike
|
||||||
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
* {@link SDBitwise#leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||||
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
||||||
*
|
*
|
||||||
* @param x Input to be bit shifted (INT type)
|
* @param x Input to be bit shifted (INT type)
|
||||||
|
@ -260,12 +271,13 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable leftShiftCyclic(SDVariable x, SDVariable y) {
|
public SDVariable leftShiftCyclic(SDVariable x, SDVariable y) {
|
||||||
SDValidation.validateInteger("leftShiftCyclic", "x", x);
|
SDValidation.validateInteger("leftShiftCyclic", "x", x);
|
||||||
SDValidation.validateInteger("leftShiftCyclic", "y", y);
|
SDValidation.validateInteger("leftShiftCyclic", "y", y);
|
||||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x,
|
||||||
|
y).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
* Bitwise left cyclical shift operation. Supports broadcasting.<br> Unlike
|
||||||
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
* {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||||
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
|
@ -276,15 +288,15 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y) {
|
public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y) {
|
||||||
SDValidation.validateInteger("leftShiftCyclic", "x", x);
|
SDValidation.validateInteger("leftShiftCyclic", "x", x);
|
||||||
SDValidation.validateInteger("leftShiftCyclic", "y", y);
|
SDValidation.validateInteger("leftShiftCyclic", "y", y);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x,
|
||||||
|
y).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise OR operation. Supports broadcasting.<br>
|
* Bitwise OR operation. Supports broadcasting.<br>
|
||||||
*
|
* <p>
|
||||||
* Inputs must satisfy the following constraints: <br>
|
* Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
|
||||||
* Must be same types: isSameType(x, y)<br>
|
|
||||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||||
*
|
*
|
||||||
* @param x First input array (INT type)
|
* @param x First input array (INT type)
|
||||||
|
@ -300,9 +312,8 @@ public class SDBitwise extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise OR operation. Supports broadcasting.<br>
|
* Bitwise OR operation. Supports broadcasting.<br>
|
||||||
*
|
* <p>
|
||||||
* Inputs must satisfy the following constraints: <br>
|
* Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
|
||||||
* Must be same types: isSameType(x, y)<br>
|
|
||||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
|
@ -314,7 +325,8 @@ public class SDBitwise extends SDOps {
|
||||||
SDValidation.validateInteger("or", "x", x);
|
SDValidation.validateInteger("or", "x", x);
|
||||||
SDValidation.validateInteger("or", "y", y);
|
SDValidation.validateInteger("or", "y", y);
|
||||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd, x,
|
||||||
|
y).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -342,13 +354,14 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable rightShift(String name, SDVariable x, SDVariable y) {
|
public SDVariable rightShift(String name, SDVariable x, SDVariable y) {
|
||||||
SDValidation.validateInteger("rightShift", "x", x);
|
SDValidation.validateInteger("rightShift", "x", x);
|
||||||
SDValidation.validateInteger("rightShift", "y", y);
|
SDValidation.validateInteger("rightShift", "y", y);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x,
|
||||||
|
y).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
* Bitwise right cyclical shift operation. Supports broadcasting.<br> Unlike
|
||||||
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
* {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||||
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
||||||
*
|
*
|
||||||
* @param x Input to be bit shifted (INT type)
|
* @param x Input to be bit shifted (INT type)
|
||||||
|
@ -358,12 +371,13 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable rightShiftCyclic(SDVariable x, SDVariable y) {
|
public SDVariable rightShiftCyclic(SDVariable x, SDVariable y) {
|
||||||
SDValidation.validateInteger("rightShiftCyclic", "x", x);
|
SDValidation.validateInteger("rightShiftCyclic", "x", x);
|
||||||
SDValidation.validateInteger("rightShiftCyclic", "y", y);
|
SDValidation.validateInteger("rightShiftCyclic", "y", y);
|
||||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x,
|
||||||
|
y).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
* Bitwise right cyclical shift operation. Supports broadcasting.<br> Unlike
|
||||||
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
* {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||||
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
|
@ -374,15 +388,15 @@ public class SDBitwise extends SDOps {
|
||||||
public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y) {
|
public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y) {
|
||||||
SDValidation.validateInteger("rightShiftCyclic", "x", x);
|
SDValidation.validateInteger("rightShiftCyclic", "x", x);
|
||||||
SDValidation.validateInteger("rightShiftCyclic", "y", y);
|
SDValidation.validateInteger("rightShiftCyclic", "y", y);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x,
|
||||||
|
y).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise XOR operation (exclusive OR). Supports broadcasting.<br>
|
* Bitwise XOR operation (exclusive OR). Supports broadcasting.<br>
|
||||||
*
|
* <p>
|
||||||
* Inputs must satisfy the following constraints: <br>
|
* Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
|
||||||
* Must be same types: isSameType(x, y)<br>
|
|
||||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||||
*
|
*
|
||||||
* @param x First input array (INT type)
|
* @param x First input array (INT type)
|
||||||
|
@ -398,9 +412,8 @@ public class SDBitwise extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise XOR operation (exclusive OR). Supports broadcasting.<br>
|
* Bitwise XOR operation (exclusive OR). Supports broadcasting.<br>
|
||||||
*
|
* <p>
|
||||||
* Inputs must satisfy the following constraints: <br>
|
* Inputs must satisfy the following constraints: <br> Must be same types: isSameType(x, y)<br>
|
||||||
* Must be same types: isSameType(x, y)<br>
|
|
||||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
|
@ -412,7 +425,8 @@ public class SDBitwise extends SDOps {
|
||||||
SDValidation.validateInteger("xor", "x", x);
|
SDValidation.validateInteger("xor", "x", x);
|
||||||
SDValidation.validateInteger("xor", "y", y);
|
SDValidation.validateInteger("xor", "y", y);
|
||||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd, x,
|
||||||
|
y).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -355,7 +355,8 @@ public class SDImage extends SDOps {
|
||||||
* @param maxOutSize scalar representing the maximum number of boxes to be selected
|
* @param maxOutSize scalar representing the maximum number of boxes to be selected
|
||||||
* @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
|
* @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
|
||||||
* @param scoreThreshold threshold for deciding when to remove boxes based on score
|
* @param scoreThreshold threshold for deciding when to remove boxes based on score
|
||||||
* @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type)
|
* @return output vectort of shape [M] representing the selected indices from the boxes tensor,
|
||||||
|
* where M <= max_output_size (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public SDVariable nonMaxSuppression(SDVariable boxes, SDVariable scores, int maxOutSize,
|
public SDVariable nonMaxSuppression(SDVariable boxes, SDVariable scores, int maxOutSize,
|
||||||
double iouThreshold, double scoreThreshold) {
|
double iouThreshold, double scoreThreshold) {
|
||||||
|
@ -373,7 +374,8 @@ public class SDImage extends SDOps {
|
||||||
* @param maxOutSize scalar representing the maximum number of boxes to be selected
|
* @param maxOutSize scalar representing the maximum number of boxes to be selected
|
||||||
* @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
|
* @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
|
||||||
* @param scoreThreshold threshold for deciding when to remove boxes based on score
|
* @param scoreThreshold threshold for deciding when to remove boxes based on score
|
||||||
* @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type)
|
* @return output vectort of shape [M] representing the selected indices from the boxes tensor,
|
||||||
|
* where M <= max_output_size (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public SDVariable nonMaxSuppression(String name, SDVariable boxes, SDVariable scores,
|
public SDVariable nonMaxSuppression(String name, SDVariable boxes, SDVariable scores,
|
||||||
int maxOutSize, double iouThreshold, double scoreThreshold) {
|
int maxOutSize, double iouThreshold, double scoreThreshold) {
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -100,7 +100,7 @@ public class SDRandom extends SDOps {
|
||||||
* P(x) = lambda * exp(-lambda * x)<br>
|
* P(x) = lambda * exp(-lambda * x)<br>
|
||||||
*
|
*
|
||||||
* Inputs must satisfy the following constraints: <br>
|
* Inputs must satisfy the following constraints: <br>
|
||||||
* Must be positive: lambda > 0<br>
|
* Must be positive: lambda > 0<br>
|
||||||
*
|
*
|
||||||
* @param lambda lambda parameter
|
* @param lambda lambda parameter
|
||||||
* @param datatype Data type of the output variable
|
* @param datatype Data type of the output variable
|
||||||
|
@ -118,7 +118,7 @@ public class SDRandom extends SDOps {
|
||||||
* P(x) = lambda * exp(-lambda * x)<br>
|
* P(x) = lambda * exp(-lambda * x)<br>
|
||||||
*
|
*
|
||||||
* Inputs must satisfy the following constraints: <br>
|
* Inputs must satisfy the following constraints: <br>
|
||||||
* Must be positive: lambda > 0<br>
|
* Must be positive: lambda > 0<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param lambda lambda parameter
|
* @param lambda lambda parameter
|
||||||
|
|
|
@ -829,9 +829,9 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
|
||||||
* Precision based on guesses so far.<br>
|
* Precision based on guesses so far.<br>
|
||||||
* Note: value returned will differ depending on number of classes and settings.<br>
|
* Note: value returned will differ depending on number of classes and settings.<br>
|
||||||
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor,
|
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor,
|
||||||
* or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class
|
* or via {@link Evaluation#setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class
|
||||||
* only.<br>
|
* only.<br>
|
||||||
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged
|
* 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged
|
||||||
* across all classes. i.e., is macro-averaged precision, equivalent to {@code precision(EvaluationAveraging.Macro)}<br>
|
* across all classes. i.e., is macro-averaged precision, equivalent to {@code precision(EvaluationAveraging.Macro)}<br>
|
||||||
*
|
*
|
||||||
* @return the total precision based on guesses so far
|
* @return the total precision based on guesses so far
|
||||||
|
@ -977,7 +977,7 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
|
||||||
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor,
|
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor,
|
||||||
* or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class
|
* or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class
|
||||||
* only.<br>
|
* only.<br>
|
||||||
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged
|
* 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged
|
||||||
* across all classes. i.e., is macro-averaged recall, equivalent to {@code recall(EvaluationAveraging.Macro)}<br>
|
* across all classes. i.e., is macro-averaged recall, equivalent to {@code recall(EvaluationAveraging.Macro)}<br>
|
||||||
*
|
*
|
||||||
* @return the recall for the outcomes
|
* @return the recall for the outcomes
|
||||||
|
@ -1173,12 +1173,12 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* False Alarm Rate (FAR) reflects rate of misclassified to classified records
|
* False Alarm Rate (FAR) reflects rate of misclassified to classified records
|
||||||
* {@link }http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw}<br>
|
* {@see http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw}<br>
|
||||||
* Note: value returned will differ depending on number of classes and settings.<br>
|
* Note: value returned will differ depending on number of classes and settings.<br>
|
||||||
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor,
|
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor,
|
||||||
* or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class
|
* or via {@link Evaluation#setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class
|
||||||
* only.<br>
|
* only.<br>
|
||||||
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged
|
* 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged
|
||||||
* across all classes. i.e., is macro-averaged false alarm rate)
|
* across all classes. i.e., is macro-averaged false alarm rate)
|
||||||
*
|
*
|
||||||
* @return the fpr for the outcomes
|
* @return the fpr for the outcomes
|
||||||
|
@ -1243,9 +1243,9 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
|
||||||
* <br>
|
* <br>
|
||||||
* Note: value returned will differ depending on number of classes and settings.<br>
|
* Note: value returned will differ depending on number of classes and settings.<br>
|
||||||
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor,
|
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor,
|
||||||
* or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class
|
* or via {@link Evaluation#setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class
|
||||||
* only.<br>
|
* only.<br>
|
||||||
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged
|
* 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged
|
||||||
* across all classes. i.e., is macro-averaged f1, equivalent to {@code f1(EvaluationAveraging.Macro)}<br>
|
* across all classes. i.e., is macro-averaged f1, equivalent to {@code f1(EvaluationAveraging.Macro)}<br>
|
||||||
*
|
*
|
||||||
* @return the f1 score or harmonic mean of precision and recall based on current guesses
|
* @return the f1 score or harmonic mean of precision and recall based on current guesses
|
||||||
|
|
|
@ -584,7 +584,7 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* False Alarm Rate (FAR) reflects rate of misclassified to classified records
|
* False Alarm Rate (FAR) reflects rate of misclassified to classified records
|
||||||
* <a href="http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw">http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw</a><br>
|
* http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw<br>
|
||||||
*
|
*
|
||||||
* @param outputNum Class index to calculate False Alarm Rate (FAR)
|
* @param outputNum Class index to calculate False Alarm Rate (FAR)
|
||||||
* @return The FAR for the outcomes
|
* @return The FAR for the outcomes
|
||||||
|
@ -611,7 +611,7 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
|
||||||
|
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
|
|
||||||
//Report: Accuracy, precision, recall, F1. Then: confusion matrix
|
//Report: Accuracy, precision, recall, F1. Then: confusion matrix]
|
||||||
|
|
||||||
int maxLabelsLength = 15;
|
int maxLabelsLength = 15;
|
||||||
if (labels != null) {
|
if (labels != null) {
|
||||||
|
|
|
@ -202,7 +202,7 @@ public abstract class BaseLapack implements Lapack {
|
||||||
*
|
*
|
||||||
* @param jobz 'N' - no eigen vectors, 'V' - return eigenvectors
|
* @param jobz 'N' - no eigen vectors, 'V' - return eigenvectors
|
||||||
* @param uplo upper or lower part of symmetric matrix to use
|
* @param uplo upper or lower part of symmetric matrix to use
|
||||||
* @param N the number of rows & cols in the matrix A
|
* @param N the number of rows & cols in the matrix A
|
||||||
* @param A the matrix to calculate eigenvectors
|
* @param A the matrix to calculate eigenvectors
|
||||||
* @param R an output array for eigenvalues ( may be null )
|
* @param R an output array for eigenvalues ( may be null )
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -74,14 +74,14 @@ public interface DataBufferFactory {
|
||||||
DataBuffer createDouble(long offset, int length);
|
DataBuffer createDouble(long offset, int length);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method will create new DataBuffer of the same dataType & same length
|
* This method will create new DataBuffer of the same dataType & same length
|
||||||
* @param buffer
|
* @param buffer
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
DataBuffer createSame(DataBuffer buffer, boolean init);
|
DataBuffer createSame(DataBuffer buffer, boolean init);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method will create new DataBuffer of the same dataType & same length
|
* This method will create new DataBuffer of the same dataType & same length
|
||||||
* @param buffer
|
* @param buffer
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -132,7 +132,6 @@ public interface MemoryManager {
|
||||||
*
|
*
|
||||||
* @param pointer
|
* @param pointer
|
||||||
* @param kind
|
* @param kind
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
void release(Pointer pointer, MemoryKind kind);
|
void release(Pointer pointer, MemoryKind kind);
|
||||||
|
|
||||||
|
|
|
@ -137,7 +137,7 @@ public interface MemoryWorkspaceManager {
|
||||||
void destroyWorkspace(MemoryWorkspace workspace);
|
void destroyWorkspace(MemoryWorkspace workspace);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method destroys & deallocates all Workspaces for a calling Thread
|
* This method destroys & deallocates all Workspaces for a calling Thread
|
||||||
*
|
*
|
||||||
* PLEASE NOTE: This method is NOT safe
|
* PLEASE NOTE: This method is NOT safe
|
||||||
*/
|
*/
|
||||||
|
@ -149,21 +149,21 @@ public interface MemoryWorkspaceManager {
|
||||||
void destroyWorkspace();
|
void destroyWorkspace();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method gets & activates default workspace
|
* This method gets and activates default workspace
|
||||||
*
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
MemoryWorkspace getAndActivateWorkspace();
|
MemoryWorkspace getAndActivateWorkspace();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method gets & activates workspace with a given Id
|
* This method gets and activates workspace with a given Id
|
||||||
*
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
MemoryWorkspace getAndActivateWorkspace(String id);
|
MemoryWorkspace getAndActivateWorkspace(String id);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method gets & activates default with a given configuration and Id
|
* This method gets and activates default with a given configuration and Id
|
||||||
*
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -47,7 +47,7 @@ public abstract class BaseShapeInfoProvider implements ShapeInfoProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method creates shapeInformation buffer, based on shape & order being passed in
|
* This method creates shapeInformation buffer, based on shape and order being passed in
|
||||||
*
|
*
|
||||||
* @param shape
|
* @param shape
|
||||||
* @param order
|
* @param order
|
||||||
|
|
|
@ -2216,7 +2216,7 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
* Dimshuffle: an extension of permute that adds the ability
|
* Dimshuffle: an extension of permute that adds the ability
|
||||||
* to broadcast various dimensions.
|
* to broadcast various dimensions.
|
||||||
* This will only accept integers and xs.
|
* This will only accept integers and xs.
|
||||||
* <p/>
|
* <p>
|
||||||
* An x indicates a dimension should be broadcasted rather than permuted.
|
* An x indicates a dimension should be broadcasted rather than permuted.
|
||||||
*
|
*
|
||||||
* Examples originally from the theano docs:
|
* Examples originally from the theano docs:
|
||||||
|
@ -2226,15 +2226,15 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
|
|
||||||
A few examples of patterns and their effect:
|
A few examples of patterns and their effect:
|
||||||
|
|
||||||
('x') -> make a 0d (scalar) into a 1d vector
|
('x') -> make a 0d (scalar) into a 1d vector
|
||||||
(0, 1) -> identity for 2d vectors
|
(0, 1) -> identity for 2d vectors
|
||||||
(1, 0) -> inverts the first and second dimensions
|
(1, 0) -> inverts the first and second dimensions
|
||||||
('x', 0) -> make a row out of a 1d vector (N to 1xN)
|
('x', 0) -> make a row out of a 1d vector (N to 1xN)
|
||||||
(0, 'x') -> make a column out of a 1d vector (N to Nx1)
|
(0, 'x') -> make a column out of a 1d vector (N to Nx1)
|
||||||
(2, 0, 1) -> AxBxC to CxAxB
|
(2, 0, 1) -> AxBxC to CxAxB
|
||||||
(0, 'x', 1) -> AxB to Ax1xB
|
(0, 'x', 1) -> AxB to Ax1xB
|
||||||
(1, 'x', 0) -> AxB to Bx1xA
|
(1, 'x', 0) -> AxB to Bx1xA
|
||||||
(1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A)
|
(1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A)
|
||||||
|
|
||||||
* @param rearrange the dimensions to swap to
|
* @param rearrange the dimensions to swap to
|
||||||
* @param newOrder the new order (think permute)
|
* @param newOrder the new order (think permute)
|
||||||
|
@ -2244,7 +2244,7 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable);
|
INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* See {@link #dimShuffle(Object[], int[], boolean[])
|
* See {@link #dimShuffle(Object[], int[], boolean[])}
|
||||||
*/
|
*/
|
||||||
INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable);
|
INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable);
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ public interface ShapeInfoProvider {
|
||||||
Pair<DataBuffer, long[]> createShapeInformation(long[] shape, DataType dataType);
|
Pair<DataBuffer, long[]> createShapeInformation(long[] shape, DataType dataType);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method creates long shapeInformation buffer, based on shape & order being passed in
|
* This method creates long shapeInformation buffer, based on shape and order being passed in
|
||||||
* @param shape
|
* @param shape
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -65,7 +65,8 @@ public interface OpContext extends AutoCloseable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method sets root-level seed for rng
|
* This method sets root-level seed for rng
|
||||||
* @param seed
|
* @param rootState
|
||||||
|
* @param nodeState
|
||||||
*/
|
*/
|
||||||
void setRngStates(long rootState, long nodeState);
|
void setRngStates(long rootState, long nodeState);
|
||||||
|
|
||||||
|
|
|
@ -251,7 +251,7 @@ public interface Random extends AutoCloseable {
|
||||||
* The reason for this is due to ints
|
* The reason for this is due to ints
|
||||||
* having the same space usage as floats.
|
* having the same space usage as floats.
|
||||||
* This also plays nice with blas.
|
* This also plays nice with blas.
|
||||||
* <p/>
|
* <p>
|
||||||
* If the data opType is set to double,
|
* If the data opType is set to double,
|
||||||
* then these will be whole doubles.
|
* then these will be whole doubles.
|
||||||
*
|
*
|
||||||
|
@ -272,7 +272,7 @@ public interface Random extends AutoCloseable {
|
||||||
* The reason for this is due to ints
|
* The reason for this is due to ints
|
||||||
* having the same space usage as floats.
|
* having the same space usage as floats.
|
||||||
* This also plays nice with blas.
|
* This also plays nice with blas.
|
||||||
* <p/>
|
* <p>
|
||||||
* If the data opType is set to double,
|
* If the data opType is set to double,
|
||||||
* then these will be whole doubles.
|
* then these will be whole doubles.
|
||||||
*
|
*
|
||||||
|
|
|
@ -35,6 +35,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
|
|
||||||
public abstract class BaseDistribution implements Distribution {
|
public abstract class BaseDistribution implements Distribution {
|
||||||
|
|
||||||
protected Random random;
|
protected Random random;
|
||||||
protected double solverAbsoluteAccuracy;
|
protected double solverAbsoluteAccuracy;
|
||||||
|
|
||||||
|
@ -49,31 +50,33 @@ public abstract class BaseDistribution implements Distribution {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* For a random variable {@code X} whose values are distributed according
|
* For a random variable {@code X} whose values are distributed according to this distribution,
|
||||||
* to this distribution, this method returns {@code P(x0 < X <= x1)}.
|
* this method returns {@code P(x0 < X <= x1)}.
|
||||||
*
|
*
|
||||||
* @param x0 Lower bound (excluded).
|
* @param x0 Lower bound (excluded).
|
||||||
* @param x1 Upper bound (included).
|
* @param x1 Upper bound (included).
|
||||||
* @return the probability that a random variable with this distribution
|
* @return the probability that a random variable with this distribution takes a value between
|
||||||
* takes a value between {@code x0} and {@code x1}, excluding the lower
|
* {@code x0} and {@code x1}, excluding the lower and including the upper endpoint.
|
||||||
* and including the upper endpoint.
|
|
||||||
* @throws org.apache.commons.math3.exception.NumberIsTooLargeException if {@code x0 > x1}.
|
* @throws org.apache.commons.math3.exception.NumberIsTooLargeException if {@code x0 > x1}.
|
||||||
* <p/>
|
* <p>
|
||||||
* The default implementation uses the identity
|
* The default implementation
|
||||||
* {@code P(x0 < X <= x1) = P(X <= x1) - P(X <= x0)}
|
* uses the identity
|
||||||
|
* {@code P(x0 < X <= x1) =
|
||||||
|
* P(X <= x1) - P(X <= x0)}
|
||||||
* @since 3.1
|
* @since 3.1
|
||||||
*/
|
*/
|
||||||
|
|
||||||
public double probability(double x0, double x1) {
|
public double probability(double x0, double x1) {
|
||||||
if (x0 > x1) {
|
if (x0 > x1) {
|
||||||
throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0, x1, true);
|
throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0,
|
||||||
|
x1, true);
|
||||||
}
|
}
|
||||||
return cumulativeProbability(x1) - cumulativeProbability(x0);
|
return cumulativeProbability(x1) - cumulativeProbability(x0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The default implementation returns
|
* The default implementation returns
|
||||||
* <ul>
|
* <ul>
|
||||||
* <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
|
* <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
|
||||||
|
@ -127,7 +130,8 @@ public abstract class BaseDistribution implements Distribution {
|
||||||
final double mu = getNumericalMean();
|
final double mu = getNumericalMean();
|
||||||
final double sig = FastMath.sqrt(getNumericalVariance());
|
final double sig = FastMath.sqrt(getNumericalVariance());
|
||||||
final boolean chebyshevApplies;
|
final boolean chebyshevApplies;
|
||||||
chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) || Double.isInfinite(sig) || Double.isNaN(sig));
|
chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) || Double.isInfinite(sig)
|
||||||
|
|| Double.isNaN(sig));
|
||||||
|
|
||||||
if (lowerBound == Double.NEGATIVE_INFINITY) {
|
if (lowerBound == Double.NEGATIVE_INFINITY) {
|
||||||
if (chebyshevApplies) {
|
if (chebyshevApplies) {
|
||||||
|
@ -158,7 +162,8 @@ public abstract class BaseDistribution implements Distribution {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
double x = UnivariateSolverUtils.solve(toSolve, lowerBound, upperBound, getSolverAbsoluteAccuracy());
|
double x = UnivariateSolverUtils.solve(toSolve, lowerBound, upperBound,
|
||||||
|
getSolverAbsoluteAccuracy());
|
||||||
|
|
||||||
if (!isSupportConnected()) {
|
if (!isSupportConnected()) {
|
||||||
/* Test for plateau. */
|
/* Test for plateau. */
|
||||||
|
@ -183,9 +188,8 @@ public abstract class BaseDistribution implements Distribution {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the solver absolute accuracy for inverse cumulative computation.
|
* Returns the solver absolute accuracy for inverse cumulative computation. You can override this
|
||||||
* You can override this method in order to use a Brent solver with an
|
* method in order to use a Brent solver with an absolute accuracy different from the default.
|
||||||
* absolute accuracy different from the default.
|
|
||||||
*
|
*
|
||||||
* @return the maximum absolute error in inverse cumulative probability estimates
|
* @return the maximum absolute error in inverse cumulative probability estimates
|
||||||
*/
|
*/
|
||||||
|
@ -203,7 +207,6 @@ public abstract class BaseDistribution implements Distribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
|
||||||
* The default implementation uses the
|
* The default implementation uses the
|
||||||
* <a href="http://en.wikipedia.org/wiki/Inverse_transform_sampling">
|
* <a href="http://en.wikipedia.org/wiki/Inverse_transform_sampling">
|
||||||
* inversion method.
|
* inversion method.
|
||||||
|
@ -216,9 +219,8 @@ public abstract class BaseDistribution implements Distribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The default implementation generates the sample by calling
|
* The default implementation generates the sample by calling {@link #sample()} in a loop.
|
||||||
* {@link #sample()} in a loop.
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public double[] sample(long sampleSize) {
|
public double[] sample(long sampleSize) {
|
||||||
|
@ -257,7 +259,8 @@ public abstract class BaseDistribution implements Distribution {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray sample(INDArray target) {
|
public INDArray sample(INDArray target) {
|
||||||
Iterator<long[]> idxIter = new NdIndexIterator(target.shape()); //For consistent values irrespective of c vs. fortran ordering
|
Iterator<long[]> idxIter = new NdIndexIterator(
|
||||||
|
target.shape()); //For consistent values irrespective of c vs. fortran ordering
|
||||||
long len = target.length();
|
long len = target.length();
|
||||||
for (long i = 0; i < len; i++) {
|
for (long i = 0; i < len; i++) {
|
||||||
target.putScalar(idxIter.next(), sample());
|
target.putScalar(idxIter.next(), sample());
|
||||||
|
|
|
@ -89,8 +89,8 @@ public interface Distribution {
|
||||||
* variable {@code X} distributed according to this distribution, the
|
* variable {@code X} distributed according to this distribution, the
|
||||||
* returned value is
|
* returned value is
|
||||||
* <ul>
|
* <ul>
|
||||||
* <li><code>inf{x in R | P(X<=x) >= p}</code> for {@code 0 < p <= 1},</li>
|
* <li>{@code inf{x in R | P(X<=x) >= p}} for {@code 0 < p <= 1},</li>
|
||||||
* <li><code>inf{x in R | P(X<=x) > 0}</code> for {@code p = 0}.</li>
|
* <li>{@code inf{x in R | P(X<=x) > 0}} for {@code p = 0}.</li>
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* @param p the cumulative probability
|
* @param p the cumulative probability
|
||||||
|
@ -122,7 +122,7 @@ public interface Distribution {
|
||||||
* Access the lower bound of the support. This method must return the same
|
* Access the lower bound of the support. This method must return the same
|
||||||
* value as {@code inverseCumulativeProbability(0)}. In other words, this
|
* value as {@code inverseCumulativeProbability(0)}. In other words, this
|
||||||
* method must return
|
* method must return
|
||||||
* <p><code>inf {x in R | P(X <= x) > 0}</code>.</p>
|
* <p>{@code inf {x in R | P(X <= x) > 0}}.</p>
|
||||||
*
|
*
|
||||||
* @return lower bound of the support (might be
|
* @return lower bound of the support (might be
|
||||||
* {@code Double.NEGATIVE_INFINITY})
|
* {@code Double.NEGATIVE_INFINITY})
|
||||||
|
@ -133,7 +133,7 @@ public interface Distribution {
|
||||||
* Access the upper bound of the support. This method must return the same
|
* Access the upper bound of the support. This method must return the same
|
||||||
* value as {@code inverseCumulativeProbability(1)}. In other words, this
|
* value as {@code inverseCumulativeProbability(1)}. In other words, this
|
||||||
* method must return
|
* method must return
|
||||||
* <p><code>inf {x in R | P(X <= x) = 1}</code>.</p>
|
* <p>{@code inf {x in R | P(X <= x) = 1}}.</p>
|
||||||
*
|
*
|
||||||
* @return upper bound of the support (might be
|
* @return upper bound of the support (might be
|
||||||
* {@code Double.POSITIVE_INFINITY})
|
* {@code Double.POSITIVE_INFINITY})
|
||||||
|
|
|
@ -166,7 +166,7 @@ public class BinomialDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* For {@code n} trials and probability parameter {@code p}, the mean is
|
* For {@code n} trials and probability parameter {@code p}, the mean is
|
||||||
* {@code n * p}.
|
* {@code n * p}.
|
||||||
*/
|
*/
|
||||||
|
@ -177,7 +177,7 @@ public class BinomialDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* For {@code n} trials and probability parameter {@code p}, the variance is
|
* For {@code n} trials and probability parameter {@code p}, the variance is
|
||||||
* {@code n * p * (1 - p)}.
|
* {@code n * p * (1 - p)}.
|
||||||
*/
|
*/
|
||||||
|
@ -189,7 +189,7 @@ public class BinomialDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The lower bound of the support is always 0 except for the probability
|
* The lower bound of the support is always 0 except for the probability
|
||||||
* parameter {@code p = 1}.
|
* parameter {@code p = 1}.
|
||||||
*
|
*
|
||||||
|
@ -203,7 +203,7 @@ public class BinomialDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The upper bound of the support is the number of trials except for the
|
* The upper bound of the support is the number of trials except for the
|
||||||
* probability parameter {@code p = 0}.
|
* probability parameter {@code p = 0}.
|
||||||
*
|
*
|
||||||
|
@ -227,7 +227,7 @@ public class BinomialDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The support of this distribution is connected.
|
* The support of this distribution is connected.
|
||||||
*
|
*
|
||||||
* @return {@code true}
|
* @return {@code true}
|
||||||
|
|
|
@ -83,7 +83,7 @@ public class ConstantDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1
|
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1
|
||||||
* is returned, as in these cases the actual value is within
|
* is returned, as in these cases the actual value is within
|
||||||
* {@code Double.MIN_VALUE} of 0 or 1.
|
* {@code Double.MIN_VALUE} of 0 or 1.
|
||||||
|
@ -131,7 +131,7 @@ public class ConstantDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* For mean parameter {@code mu}, the mean is {@code mu}.
|
* For mean parameter {@code mu}, the mean is {@code mu}.
|
||||||
*/
|
*/
|
||||||
public double getNumericalMean() {
|
public double getNumericalMean() {
|
||||||
|
@ -140,7 +140,7 @@ public class ConstantDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* For standard deviation parameter {@code s}, the variance is {@code s^2}.
|
* For standard deviation parameter {@code s}, the variance is {@code s^2}.
|
||||||
*/
|
*/
|
||||||
public double getNumericalVariance() {
|
public double getNumericalVariance() {
|
||||||
|
@ -150,7 +150,7 @@ public class ConstantDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The lower bound of the support is always negative infinity
|
* The lower bound of the support is always negative infinity
|
||||||
* no matter the parameters.
|
* no matter the parameters.
|
||||||
*
|
*
|
||||||
|
@ -163,7 +163,7 @@ public class ConstantDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The upper bound of the support is always positive infinity
|
* The upper bound of the support is always positive infinity
|
||||||
* no matter the parameters.
|
* no matter the parameters.
|
||||||
*
|
*
|
||||||
|
@ -190,7 +190,7 @@ public class ConstantDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The support of this distribution is connected.
|
* The support of this distribution is connected.
|
||||||
*
|
*
|
||||||
* @return {@code true}
|
* @return {@code true}
|
||||||
|
|
|
@ -172,7 +172,6 @@ public class LogNormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
|
||||||
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1
|
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1
|
||||||
* is returned, as in these cases the actual value is within
|
* is returned, as in these cases the actual value is within
|
||||||
* {@code Double.MIN_VALUE} of 0 or 1.
|
* {@code Double.MIN_VALUE} of 0 or 1.
|
||||||
|
@ -238,7 +237,6 @@ public class LogNormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
|
||||||
* For mean parameter {@code mu}, the mean is {@code mu}.
|
* For mean parameter {@code mu}, the mean is {@code mu}.
|
||||||
*/
|
*/
|
||||||
public double getNumericalMean() {
|
public double getNumericalMean() {
|
||||||
|
@ -247,7 +245,6 @@ public class LogNormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
|
||||||
* For standard deviation parameter {@code s}, the variance is {@code s^2}.
|
* For standard deviation parameter {@code s}, the variance is {@code s^2}.
|
||||||
*/
|
*/
|
||||||
public double getNumericalVariance() {
|
public double getNumericalVariance() {
|
||||||
|
@ -257,7 +254,6 @@ public class LogNormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
|
||||||
* The lower bound of the support is always negative infinity
|
* The lower bound of the support is always negative infinity
|
||||||
* no matter the parameters.
|
* no matter the parameters.
|
||||||
*
|
*
|
||||||
|
@ -270,7 +266,7 @@ public class LogNormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The upper bound of the support is always positive infinity
|
* The upper bound of the support is always positive infinity
|
||||||
* no matter the parameters.
|
* no matter the parameters.
|
||||||
*
|
*
|
||||||
|
@ -297,7 +293,7 @@ public class LogNormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The support of this distribution is connected.
|
* The support of this distribution is connected.
|
||||||
*
|
*
|
||||||
* @return {@code true}
|
* @return {@code true}
|
||||||
|
|
|
@ -176,7 +176,6 @@ public class NormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
|
||||||
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1
|
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1
|
||||||
* is returned, as in these cases the actual value is within
|
* is returned, as in these cases the actual value is within
|
||||||
* {@code Double.MIN_VALUE} of 0 or 1.
|
* {@code Double.MIN_VALUE} of 0 or 1.
|
||||||
|
@ -242,7 +241,6 @@ public class NormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
|
||||||
* For mean parameter {@code mu}, the mean is {@code mu}.
|
* For mean parameter {@code mu}, the mean is {@code mu}.
|
||||||
*/
|
*/
|
||||||
public double getNumericalMean() {
|
public double getNumericalMean() {
|
||||||
|
@ -251,7 +249,6 @@ public class NormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
|
||||||
* For standard deviation parameter {@code s}, the variance is {@code s^2}.
|
* For standard deviation parameter {@code s}, the variance is {@code s^2}.
|
||||||
*/
|
*/
|
||||||
public double getNumericalVariance() {
|
public double getNumericalVariance() {
|
||||||
|
@ -261,7 +258,6 @@ public class NormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
|
||||||
* The lower bound of the support is always negative infinity
|
* The lower bound of the support is always negative infinity
|
||||||
* no matter the parameters.
|
* no matter the parameters.
|
||||||
*
|
*
|
||||||
|
@ -274,7 +270,6 @@ public class NormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
|
||||||
* The upper bound of the support is always positive infinity
|
* The upper bound of the support is always positive infinity
|
||||||
* no matter the parameters.
|
* no matter the parameters.
|
||||||
*
|
*
|
||||||
|
@ -301,7 +296,6 @@ public class NormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
|
||||||
* The support of this distribution is connected.
|
* The support of this distribution is connected.
|
||||||
*
|
*
|
||||||
* @return {@code true}
|
* @return {@code true}
|
||||||
|
|
|
@ -34,6 +34,7 @@ import org.nd4j.common.util.ArrayUtil;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class OrthogonalDistribution extends BaseDistribution {
|
public class OrthogonalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Default inverse cumulative probability accuracy.
|
* Default inverse cumulative probability accuracy.
|
||||||
*
|
*
|
||||||
|
@ -62,6 +63,7 @@ public class OrthogonalDistribution extends BaseDistribution {
|
||||||
this.random = Nd4j.getRandom();
|
this.random = Nd4j.getRandom();
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Access the mean.
|
* Access the mean.
|
||||||
*
|
*
|
||||||
|
@ -88,11 +90,8 @@ public class OrthogonalDistribution extends BaseDistribution {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc} If {@code x} is more than 40 standard deviations from the mean, 0 or 1 is
|
||||||
* <p/>
|
* returned, as in these cases the actual value is within {@code Double.MIN_VALUE} of 0 or 1.
|
||||||
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1
|
|
||||||
* is returned, as in these cases the actual value is within
|
|
||||||
* {@code Double.MIN_VALUE} of 0 or 1.
|
|
||||||
*/
|
*/
|
||||||
public double cumulativeProbability(double x) {
|
public double cumulativeProbability(double x) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
|
@ -111,7 +110,9 @@ public class OrthogonalDistribution extends BaseDistribution {
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
*
|
*
|
||||||
* @deprecated See {@link org.apache.commons.math3.distribution.RealDistribution#cumulativeProbability(double, double)}
|
* @deprecated See
|
||||||
|
* {@link org.apache.commons.math3.distribution.RealDistribution#cumulativeProbability(double,
|
||||||
|
* double)}
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
@Deprecated
|
@Deprecated
|
||||||
|
@ -136,18 +137,14 @@ public class OrthogonalDistribution extends BaseDistribution {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc} For mean parameter {@code mu}, the mean is {@code mu}.
|
||||||
* <p/>
|
|
||||||
* For mean parameter {@code mu}, the mean is {@code mu}.
|
|
||||||
*/
|
*/
|
||||||
public double getNumericalMean() {
|
public double getNumericalMean() {
|
||||||
return getMean();
|
return getMean();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc} For standard deviation parameter {@code s}, the variance is {@code s^2}.
|
||||||
* <p/>
|
|
||||||
* For standard deviation parameter {@code s}, the variance is {@code s^2}.
|
|
||||||
*/
|
*/
|
||||||
public double getNumericalVariance() {
|
public double getNumericalVariance() {
|
||||||
final double s = getStandardDeviation();
|
final double s = getStandardDeviation();
|
||||||
|
@ -155,13 +152,10 @@ public class OrthogonalDistribution extends BaseDistribution {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc} The lower bound of the support is always negative infinity no matter the
|
||||||
* <p/>
|
* parameters.
|
||||||
* The lower bound of the support is always negative infinity
|
|
||||||
* no matter the parameters.
|
|
||||||
*
|
*
|
||||||
* @return lower bound of the support (always
|
* @return lower bound of the support (always {@code Double.NEGATIVE_INFINITY})
|
||||||
* {@code Double.NEGATIVE_INFINITY})
|
|
||||||
*/
|
*/
|
||||||
public double getSupportLowerBound() {
|
public double getSupportLowerBound() {
|
||||||
return Double.NEGATIVE_INFINITY;
|
return Double.NEGATIVE_INFINITY;
|
||||||
|
@ -169,12 +163,10 @@ public class OrthogonalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The upper bound of the support is always positive infinity
|
* The upper bound of the support is always positive infinity no matter the parameters.
|
||||||
* no matter the parameters.
|
|
||||||
*
|
*
|
||||||
* @return upper bound of the support (always
|
* @return upper bound of the support (always {@code Double.POSITIVE_INFINITY})
|
||||||
* {@code Double.POSITIVE_INFINITY})
|
|
||||||
*/
|
*/
|
||||||
public double getSupportUpperBound() {
|
public double getSupportUpperBound() {
|
||||||
return Double.POSITIVE_INFINITY;
|
return Double.POSITIVE_INFINITY;
|
||||||
|
@ -196,7 +188,7 @@ public class OrthogonalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The support of this distribution is connected.
|
* The support of this distribution is connected.
|
||||||
*
|
*
|
||||||
* @return {@code true}
|
* @return {@code true}
|
||||||
|
@ -221,14 +213,17 @@ public class OrthogonalDistribution extends BaseDistribution {
|
||||||
@Override
|
@Override
|
||||||
public INDArray sample(long[] shape) {
|
public INDArray sample(long[] shape) {
|
||||||
long numRows = 1;
|
long numRows = 1;
|
||||||
for (int i = 0; i < shape.length - 1; i++)
|
for (int i = 0; i < shape.length - 1; i++) {
|
||||||
numRows *= shape[i];
|
numRows *= shape[i];
|
||||||
|
}
|
||||||
long numCols = shape[shape.length - 1];
|
long numCols = shape[shape.length - 1];
|
||||||
|
|
||||||
val dtype = Nd4j.defaultFloatingPointType();
|
val dtype = Nd4j.defaultFloatingPointType();
|
||||||
|
|
||||||
val flatShape = new long[]{numRows, numCols};
|
val flatShape = new long[]{numRows, numCols};
|
||||||
val flatRng = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, 1.0), random);
|
val flatRng = Nd4j.getExecutioner().exec(
|
||||||
|
new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0,
|
||||||
|
1.0), random);
|
||||||
|
|
||||||
val m = flatRng.rows();
|
val m = flatRng.rows();
|
||||||
val n = flatRng.columns();
|
val n = flatRng.columns();
|
||||||
|
@ -241,9 +236,11 @@ public class OrthogonalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
if (gains == null) {
|
if (gains == null) {
|
||||||
if (u.rows() >= numRows && u.columns() >= numCols) {
|
if (u.rows() >= numRows && u.columns() >= numCols) {
|
||||||
return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
|
return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain)
|
||||||
|
.reshape(shape);
|
||||||
} else {
|
} else {
|
||||||
return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
|
return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain)
|
||||||
|
.reshape(shape);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
|
|
|
@ -84,7 +84,6 @@ public class SaddlePointExpansion {
|
||||||
* href="http://mathworld.wolfram.com/StirlingsSeries.html">
|
* href="http://mathworld.wolfram.com/StirlingsSeries.html">
|
||||||
* http://mathworld.wolfram.com/StirlingsSeries.html</a></li>
|
* http://mathworld.wolfram.com/StirlingsSeries.html</a></li>
|
||||||
* </ol>
|
* </ol>
|
||||||
* </p>
|
|
||||||
*
|
*
|
||||||
* @param z the value.
|
* @param z the value.
|
||||||
* @return the Striling's series error.
|
* @return the Striling's series error.
|
||||||
|
@ -117,7 +116,6 @@ public class SaddlePointExpansion {
|
||||||
* href="http://www.herine.net/stat/papers/dbinom.pdf">
|
* href="http://www.herine.net/stat/papers/dbinom.pdf">
|
||||||
* http://www.herine.net/stat/papers/dbinom.pdf</a></li>
|
* http://www.herine.net/stat/papers/dbinom.pdf</a></li>
|
||||||
* </ol>
|
* </ol>
|
||||||
* </p>
|
|
||||||
*
|
*
|
||||||
* @param x the x value.
|
* @param x the x value.
|
||||||
* @param mu the average.
|
* @param mu the average.
|
||||||
|
|
|
@ -172,7 +172,7 @@ public class TruncatedNormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1
|
* If {@code x} is more than 40 standard deviations from the mean, 0 or 1
|
||||||
* is returned, as in these cases the actual value is within
|
* is returned, as in these cases the actual value is within
|
||||||
* {@code Double.MIN_VALUE} of 0 or 1.
|
* {@code Double.MIN_VALUE} of 0 or 1.
|
||||||
|
@ -238,7 +238,7 @@ public class TruncatedNormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* For mean parameter {@code mu}, the mean is {@code mu}.
|
* For mean parameter {@code mu}, the mean is {@code mu}.
|
||||||
*/
|
*/
|
||||||
public double getNumericalMean() {
|
public double getNumericalMean() {
|
||||||
|
@ -247,7 +247,7 @@ public class TruncatedNormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* For standard deviation parameter {@code s}, the variance is {@code s^2}.
|
* For standard deviation parameter {@code s}, the variance is {@code s^2}.
|
||||||
*/
|
*/
|
||||||
public double getNumericalVariance() {
|
public double getNumericalVariance() {
|
||||||
|
@ -257,7 +257,7 @@ public class TruncatedNormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The lower bound of the support is always negative infinity
|
* The lower bound of the support is always negative infinity
|
||||||
* no matter the parameters.
|
* no matter the parameters.
|
||||||
*
|
*
|
||||||
|
@ -270,7 +270,7 @@ public class TruncatedNormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The upper bound of the support is always positive infinity
|
* The upper bound of the support is always positive infinity
|
||||||
* no matter the parameters.
|
* no matter the parameters.
|
||||||
*
|
*
|
||||||
|
@ -297,7 +297,7 @@ public class TruncatedNormalDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The support of this distribution is connected.
|
* The support of this distribution is connected.
|
||||||
*
|
*
|
||||||
* @return {@code true}
|
* @return {@code true}
|
||||||
|
|
|
@ -105,7 +105,7 @@ public class UniformDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* For lower bound {@code lower} and upper bound {@code upper}, the mean is
|
* For lower bound {@code lower} and upper bound {@code upper}, the mean is
|
||||||
* {@code 0.5 * (lower + upper)}.
|
* {@code 0.5 * (lower + upper)}.
|
||||||
*/
|
*/
|
||||||
|
@ -115,7 +115,7 @@ public class UniformDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* For lower bound {@code lower} and upper bound {@code upper}, the
|
* For lower bound {@code lower} and upper bound {@code upper}, the
|
||||||
* variance is {@code (upper - lower)^2 / 12}.
|
* variance is {@code (upper - lower)^2 / 12}.
|
||||||
*/
|
*/
|
||||||
|
@ -126,7 +126,7 @@ public class UniformDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The lower bound of the support is equal to the lower bound parameter
|
* The lower bound of the support is equal to the lower bound parameter
|
||||||
* of the distribution.
|
* of the distribution.
|
||||||
*
|
*
|
||||||
|
@ -138,7 +138,7 @@ public class UniformDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The upper bound of the support is equal to the upper bound parameter
|
* The upper bound of the support is equal to the upper bound parameter
|
||||||
* of the distribution.
|
* of the distribution.
|
||||||
*
|
*
|
||||||
|
@ -164,7 +164,7 @@ public class UniformDistribution extends BaseDistribution {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@inheritDoc}
|
* {@inheritDoc}
|
||||||
* <p/>
|
* <p>
|
||||||
* The support of this distribution is connected.
|
* The support of this distribution is connected.
|
||||||
*
|
*
|
||||||
* @return {@code true}
|
* @return {@code true}
|
||||||
|
|
|
@ -58,7 +58,7 @@ public class NDArrayCreationUtil {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/** Get an array of INDArrays (2d) all with the specified shape. Pair<INDArray,String> returned to aid
|
/** Get an array of INDArrays (2d) all with the specified shape. {@code Pair<INDArray,String>} returned to aid
|
||||||
* debugging: String contains information on how to reproduce the matrix (i.e., which function, and arguments)
|
* debugging: String contains information on how to reproduce the matrix (i.e., which function, and arguments)
|
||||||
* Each NDArray in the returned array has been obtained by applying an operation such as transpose, tensorAlongDimension,
|
* Each NDArray in the returned array has been obtained by applying an operation such as transpose, tensorAlongDimension,
|
||||||
* etc to an original array.
|
* etc to an original array.
|
||||||
|
@ -88,7 +88,7 @@ public class NDArrayCreationUtil {
|
||||||
* eg. rank 2: 1,1; 1,2; 2,1; 2,2; 3,4
|
* eg. rank 2: 1,1; 1,2; 2,1; 2,2; 3,4
|
||||||
* Motivated by TADs that often hit bugs when a "1" occurs as the size of a dimension
|
* Motivated by TADs that often hit bugs when a "1" occurs as the size of a dimension
|
||||||
*
|
*
|
||||||
* @param rank any rank including true scalars i.e rank >= 0
|
* @param rank any rank including true scalars i.e rank >= 0
|
||||||
* @param order what order array to return i.e 'c' or 'f' order arrays
|
* @param order what order array to return i.e 'c' or 'f' order arrays
|
||||||
* @return List of arrays and the shapes as strings
|
* @return List of arrays and the shapes as strings
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -355,7 +355,7 @@ public class AsyncDataSetIterator implements DataSetIterator {
|
||||||
* yet been called, or the {@code remove} method has already
|
* yet been called, or the {@code remove} method has already
|
||||||
* been called after the last call to the {@code next}
|
* been called after the last call to the {@code next}
|
||||||
* method
|
* method
|
||||||
* @implSpec The default implementation throws an instance of
|
* The default implementation throws an instance of
|
||||||
* {@link UnsupportedOperationException} and performs no other action.
|
* {@link UnsupportedOperationException} and performs no other action.
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -299,7 +299,7 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator {
|
||||||
* yet been called, or the {@code remove} method has already
|
* yet been called, or the {@code remove} method has already
|
||||||
* been called after the last call to the {@code next}
|
* been called after the last call to the {@code next}
|
||||||
* method
|
* method
|
||||||
* @implSpec The default implementation throws an instance of
|
* The default implementation throws an instance of
|
||||||
* {@link UnsupportedOperationException} and performs no other action.
|
* {@link UnsupportedOperationException} and performs no other action.
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -560,7 +560,6 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @Deprecated
|
|
||||||
* Subtract by the column means and divide by the standard deviation
|
* Subtract by the column means and divide by the standard deviation
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
|
|
|
@ -117,7 +117,6 @@ public class KFoldIterator implements DataSetIterator {
|
||||||
/**
|
/**
|
||||||
* Shuffles the dataset and resets to the first fold
|
* Shuffles the dataset and resets to the first fold
|
||||||
*
|
*
|
||||||
* @return void
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void reset() {
|
public void reset() {
|
||||||
|
@ -129,7 +128,7 @@ public class KFoldIterator implements DataSetIterator {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The number of examples in every fold is (N / k),
|
* The number of examples in every fold is (N / k),
|
||||||
* except when (N % k) > 0, when the first (N % k) folds contain (N / k) + 1 examples
|
* except when (N % k) > 0, when the first (N % k) folds contain (N / k) + 1 examples
|
||||||
*
|
*
|
||||||
* @return examples in a fold
|
* @return examples in a fold
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -49,7 +49,6 @@ public class TestDataSetIterator implements DataSetIterator {
|
||||||
* Initializes with a default batch of 5
|
* Initializes with a default batch of 5
|
||||||
*
|
*
|
||||||
* @param dataset the dataset to make the iterator from
|
* @param dataset the dataset to make the iterator from
|
||||||
* @param batch the batchsize for the iterator
|
|
||||||
*/
|
*/
|
||||||
public TestDataSetIterator(DataSet dataset) {
|
public TestDataSetIterator(DataSet dataset) {
|
||||||
this(dataset, 5);
|
this(dataset, 5);
|
||||||
|
|
|
@ -65,9 +65,9 @@ public class RandomProjection {
|
||||||
* The minimum number n' of components to guarantee the eps-embedding is
|
* The minimum number n' of components to guarantee the eps-embedding is
|
||||||
* given by:
|
* given by:
|
||||||
*
|
*
|
||||||
* n' >= 4 log(n) / (eps² / 2 - eps³ / 3)
|
* {@code n' >= 4 log(n) / (eps² / 2 - eps³ / 3)}
|
||||||
*
|
*
|
||||||
* see http://cseweb.ucsd.edu/~dasgupta/papers/jl.pdf §2.1
|
* http://cseweb.ucsd.edu/~dasgupta/papers/jl.pdf §2.1
|
||||||
* @param n Number of samples. If an array is given, it will compute
|
* @param n Number of samples. If an array is given, it will compute
|
||||||
* a safe number of components array-wise.
|
* a safe number of components array-wise.
|
||||||
* @param eps Maximum distortion rate as defined by the Johnson-Lindenstrauss lemma.
|
* @param eps Maximum distortion rate as defined by the Johnson-Lindenstrauss lemma.
|
||||||
|
|
|
@ -30,7 +30,6 @@ public interface EnvironmentalAction {
|
||||||
/**
|
/**
|
||||||
* This method will be executed with corresponding Env Var value
|
* This method will be executed with corresponding Env Var value
|
||||||
*
|
*
|
||||||
* @param name
|
|
||||||
* @param value
|
* @param value
|
||||||
*/
|
*/
|
||||||
void process(String value);
|
void process(String value);
|
||||||
|
|
|
@ -276,7 +276,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
||||||
* Rotate a matrix 90 degrees
|
* Rotate a matrix 90 degrees
|
||||||
*
|
*
|
||||||
* @param toRotate the matrix to rotate
|
* @param toRotate the matrix to rotate
|
||||||
* @return the rotated matrix
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void rot90(INDArray toRotate) {
|
public void rot90(INDArray toRotate) {
|
||||||
|
|
|
@ -33,7 +33,7 @@ public interface BlasWrapper {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute x <-> y (swap two matrices)
|
* Compute {@code x <-> y} (swap two matrices)
|
||||||
*/
|
*/
|
||||||
INDArray swap(INDArray x, INDArray y);
|
INDArray swap(INDArray x, INDArray y);
|
||||||
|
|
||||||
|
@ -69,14 +69,14 @@ public interface BlasWrapper {
|
||||||
INDArray scal(double alpha, INDArray x);
|
INDArray scal(double alpha, INDArray x);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute x <- alpha * x (scale a matrix)
|
* Compute {@code x <- alpha * x} (scale a matrix)
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
INDArray scal(float alpha, INDArray x);
|
INDArray scal(float alpha, INDArray x);
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute y <- x (copy a matrix)
|
* Compute {@code y <- x} (copy a matrix)
|
||||||
*/
|
*/
|
||||||
INDArray copy(INDArray x, INDArray y);
|
INDArray copy(INDArray x, INDArray y);
|
||||||
|
|
||||||
|
@ -84,13 +84,13 @@ public interface BlasWrapper {
|
||||||
INDArray axpy(double da, INDArray dx, INDArray dy);
|
INDArray axpy(double da, INDArray dx, INDArray dy);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute y <- alpha * x + y (elementwise addition)
|
* Compute {@code y <- alpha * x + y }(elementwise addition)
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
INDArray axpy(float da, INDArray dx, INDArray dy);
|
INDArray axpy(float da, INDArray dx, INDArray dy);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute y <- y + x * alpha
|
* Compute {@code y <- y + x * alpha}
|
||||||
* @param da the alpha to multiply by
|
* @param da the alpha to multiply by
|
||||||
* @param dx
|
* @param dx
|
||||||
* @param dy
|
* @param dy
|
||||||
|
@ -130,7 +130,7 @@ public interface BlasWrapper {
|
||||||
INDArray gemv(double alpha, INDArray a, INDArray x, double beta, INDArray y);
|
INDArray gemv(double alpha, INDArray a, INDArray x, double beta, INDArray y);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute y <- alpha*op(a)*x + beta * y (general matrix vector
|
* Compute {@code y <- alpha*op(a)*x + beta * y} (general matrix vector
|
||||||
* multiplication)
|
* multiplication)
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
|
@ -142,7 +142,7 @@ public interface BlasWrapper {
|
||||||
INDArray ger(double alpha, INDArray x, INDArray y, INDArray a);
|
INDArray ger(double alpha, INDArray x, INDArray y, INDArray a);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute A <- alpha * x * y^T + A (general rank-1 update)
|
* Compute {@code A <- alpha * x * y^T + A} (general rank-1 update)
|
||||||
*/
|
*/
|
||||||
INDArray ger(float alpha, INDArray x, INDArray y, INDArray a);
|
INDArray ger(float alpha, INDArray x, INDArray y, INDArray a);
|
||||||
|
|
||||||
|
@ -193,14 +193,14 @@ public interface BlasWrapper {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generalized Least Squares via *GELSD.
|
* Generalized Least Squares via *GELSD.
|
||||||
* <p/>
|
* <p>
|
||||||
* Note that B must be padded to contain the solution matrix. This occurs when A has fewer rows
|
* Note that B must be padded to contain the solution matrix. This occurs when A has fewer rows
|
||||||
* than columns.
|
* than columns.
|
||||||
* <p/>
|
* <p>
|
||||||
* For example: in A * X = B, A is (m,n), X is (n,k) and B is (m,k). Now if m < n, since B is overwritten to contain
|
* For example: in A * X = B, A is (m,n), X is (n,k) and B is (m,k). Now if m < n, since B is overwritten to contain
|
||||||
* the solution (in classical LAPACK style), B needs to be padded to be an (n,k) matrix.
|
* the solution (in classical LAPACK style), B needs to be padded to be an (n,k) matrix.
|
||||||
* <p/>
|
* <p>
|
||||||
* Likewise, if m > n, the solution consists only of the first n rows of B.
|
* Likewise, if m > n, the solution consists only of the first n rows of B.
|
||||||
*
|
*
|
||||||
* @param A an (m,n) matrix
|
* @param A an (m,n) matrix
|
||||||
* @param B an (max(m,n), k) matrix (well, at least)
|
* @param B an (max(m,n), k) matrix (well, at least)
|
||||||
|
|
|
@ -193,7 +193,6 @@ public interface NDArrayFactory {
|
||||||
* Rotate a matrix 90 degrees
|
* Rotate a matrix 90 degrees
|
||||||
*
|
*
|
||||||
* @param toRotate the matrix to rotate
|
* @param toRotate the matrix to rotate
|
||||||
* @return the rotated matrix
|
|
||||||
*/
|
*/
|
||||||
void rot90(INDArray toRotate);
|
void rot90(INDArray toRotate);
|
||||||
|
|
||||||
|
@ -340,7 +339,6 @@ public interface NDArrayFactory {
|
||||||
*
|
*
|
||||||
* @param array the ndarray to shuffle
|
* @param array the ndarray to shuffle
|
||||||
* @param dimension the dimension to do the shuffle
|
* @param dimension the dimension to do the shuffle
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
void shuffle(INDArray array, Random rnd, int... dimension);
|
void shuffle(INDArray array, Random rnd, int... dimension);
|
||||||
|
|
||||||
|
@ -350,7 +348,6 @@ public interface NDArrayFactory {
|
||||||
*
|
*
|
||||||
* @param array the ndarray to shuffle
|
* @param array the ndarray to shuffle
|
||||||
* @param dimension the dimension to do the shuffle
|
* @param dimension the dimension to do the shuffle
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
void shuffle(Collection<INDArray> array, Random rnd, int... dimension);
|
void shuffle(Collection<INDArray> array, Random rnd, int... dimension);
|
||||||
|
|
||||||
|
@ -360,7 +357,6 @@ public interface NDArrayFactory {
|
||||||
*
|
*
|
||||||
* @param array the ndarray to shuffle
|
* @param array the ndarray to shuffle
|
||||||
* @param dimensions the dimensions to do the shuffle
|
* @param dimensions the dimensions to do the shuffle
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
void shuffle(List<INDArray> array, Random rnd, List<int[]> dimensions);
|
void shuffle(List<INDArray> array, Random rnd, List<int[]> dimensions);
|
||||||
|
|
||||||
|
@ -1370,9 +1366,9 @@ public interface NDArrayFactory {
|
||||||
INDArray createFromNpyFile(File file);
|
INDArray createFromNpyFile(File file);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a Map<String, INDArray> from given npz file.
|
* Create a {@code Map<String, INDArray>} from given npz file.
|
||||||
* @param file the file to create the map from
|
* @param file the file to create the map from
|
||||||
* @return Map<String, INDArray>
|
* @return {@code Map<String, INDArray>}
|
||||||
*/
|
*/
|
||||||
Map<String, INDArray> createFromNpzFile(File file) throws Exception;
|
Map<String, INDArray> createFromNpzFile(File file) throws Exception;
|
||||||
|
|
||||||
|
|
|
@ -1441,7 +1441,7 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* See {@link #createBuffer(DataType dataType, long length, boolean initialize) with default datatype.
|
* See {@link #createBuffer(DataType dataType, long length, boolean initialize)} with default datatype.
|
||||||
*/
|
*/
|
||||||
public static DataBuffer createBuffer(long length, boolean initialize) {
|
public static DataBuffer createBuffer(long length, boolean initialize) {
|
||||||
return createBuffer(Nd4j.dataType(), length, initialize);
|
return createBuffer(Nd4j.dataType(), length, initialize);
|
||||||
|
@ -2828,7 +2828,7 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...))
|
* @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...)}
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) {
|
public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) {
|
||||||
|
@ -3306,7 +3306,7 @@ public class Nd4j {
|
||||||
* Generate an array with random values generated according to a binomial distribution with the specified
|
* Generate an array with random values generated according to a binomial distribution with the specified
|
||||||
* number of trials and probability
|
* number of trials and probability
|
||||||
*
|
*
|
||||||
* @param nTrials Number of trials. Must be >= 0
|
* @param nTrials Number of trials. Must be >= 0
|
||||||
* @param p Probability. Must be in range 0 to 1
|
* @param p Probability. Must be in range 0 to 1
|
||||||
* @param shape Shape of the result array
|
* @param shape Shape of the result array
|
||||||
* @return Result array
|
* @return Result array
|
||||||
|
@ -3319,7 +3319,7 @@ public class Nd4j {
|
||||||
* Fill the target array with random values generated according to a binomial distribution with the specified
|
* Fill the target array with random values generated according to a binomial distribution with the specified
|
||||||
* number of trials and probability
|
* number of trials and probability
|
||||||
*
|
*
|
||||||
* @param nTrials Number of trials. Must be >= 0
|
* @param nTrials Number of trials. Must be >= 0
|
||||||
* @param p Probability. Must be in range 0 to 1
|
* @param p Probability. Must be in range 0 to 1
|
||||||
* @param target Result array
|
* @param target Result array
|
||||||
* @return Result array
|
* @return Result array
|
||||||
|
@ -3333,7 +3333,7 @@ public class Nd4j {
|
||||||
/**
|
/**
|
||||||
* Exponential distribution: P(x) = lambda * exp(-lambda * x)
|
* Exponential distribution: P(x) = lambda * exp(-lambda * x)
|
||||||
*
|
*
|
||||||
* @param lambda Must be > 0
|
* @param lambda Must be > 0
|
||||||
* @param shape Shape of the array to generate
|
* @param shape Shape of the array to generate
|
||||||
*/
|
*/
|
||||||
public static INDArray randomExponential(double lambda, long... shape) {
|
public static INDArray randomExponential(double lambda, long... shape) {
|
||||||
|
@ -3341,9 +3341,9 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Exponential distribution: P(x) = lambda * exp(-lambda * x)
|
* Exponential distribution: {@code P(x) = lambda * exp(-lambda * x)}
|
||||||
*
|
*
|
||||||
* @param lambda Must be > 0
|
* @param lambda Must be > 0
|
||||||
* @param target Array to hold the result
|
* @param target Array to hold the result
|
||||||
*/
|
*/
|
||||||
public static INDArray randomExponential(double lambda, INDArray target) {
|
public static INDArray randomExponential(double lambda, INDArray target) {
|
||||||
|
@ -3925,7 +3925,7 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* See {@link @see #create(int, int, int[], char)}
|
* See {@link Nd4j#create(int, int, int[], char)}
|
||||||
*/
|
*/
|
||||||
public static INDArray zeros(int rows, int columns, int[] stride) {
|
public static INDArray zeros(int rows, int columns, int[] stride) {
|
||||||
return create(rows, columns, stride, order());
|
return create(rows, columns, stride, order());
|
||||||
|
@ -4630,7 +4630,7 @@ public class Nd4j {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Concatenates two matrices vertically. Matrices must have identical numbers of columns.<br>
|
* Concatenates two matrices vertically. Matrices must have identical numbers of columns.<br>
|
||||||
* Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3]
|
* Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3]
|
||||||
*
|
*
|
||||||
* @param arrs Arrays to vstack
|
* @param arrs Arrays to vstack
|
||||||
*/
|
*/
|
||||||
|
@ -4646,7 +4646,7 @@ public class Nd4j {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Concatenates two matrices vertically. Matrices must have identical numbers of columns.<br>
|
* Concatenates two matrices vertically. Matrices must have identical numbers of columns.<br>
|
||||||
* Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3]
|
* Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3]
|
||||||
*
|
*
|
||||||
* @param arrs Arrays to vstack
|
* @param arrs Arrays to vstack
|
||||||
*/
|
*/
|
||||||
|
@ -5462,7 +5462,7 @@ public class Nd4j {
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
>>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1)
|
{@code >>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1)
|
||||||
array([[ 1, 2, 3],
|
array([[ 1, 2, 3],
|
||||||
[ 4, 5, 6],
|
[ 4, 5, 6],
|
||||||
[ 0, 8, 9],
|
[ 0, 8, 9],
|
||||||
|
@ -5473,6 +5473,7 @@ public class Nd4j {
|
||||||
mask = tri(*m.shape[-2:], k=k-1, dtype=bool)
|
mask = tri(*m.shape[-2:], k=k-1, dtype=bool)
|
||||||
|
|
||||||
return where(mask, zeros(1, m.dtype), m)
|
return where(mask, zeros(1, m.dtype), m)
|
||||||
|
}
|
||||||
|
|
||||||
* @param m source array
|
* @param m source array
|
||||||
* @param k to zero below the k-th diagonal
|
* @param k to zero below the k-th diagonal
|
||||||
|
@ -5517,8 +5518,8 @@ public class Nd4j {
|
||||||
* @param n number of rows in the array
|
* @param n number of rows in the array
|
||||||
* @param m number of columns in the array ( can be just equal to n)
|
* @param m number of columns in the array ( can be just equal to n)
|
||||||
* @param k The sub-diagonal at and below which the array is filled.
|
* @param k The sub-diagonal at and below which the array is filled.
|
||||||
`k` = 0 is the main diagonal, while `k` < 0 is below it,
|
`k` = 0 is the main diagonal, while `k` > 0 is below it,
|
||||||
and `k` > 0 is above. The default is 0.
|
and `k` > 0 is above. The default is 0.
|
||||||
* @return array with ones at and below the given diagonal and zeros elsewhere
|
* @return array with ones at and below the given diagonal and zeros elsewhere
|
||||||
*/
|
*/
|
||||||
public static INDArray tri(int n,int m,int k) {
|
public static INDArray tri(int n,int m,int k) {
|
||||||
|
|
|
@ -269,14 +269,14 @@ public abstract class Nd4jBackend {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructs a new exception with the specified cause and a detail
|
* Constructs a new exception with the specified cause and a detail
|
||||||
* message of <tt>(cause==null ? null : cause.toString())</tt> (which
|
* message of {@code (cause==null ? null : cause.toString())} (which
|
||||||
* typically contains the class and detail message of <tt>cause</tt>).
|
* typically contains the class and detail message of cause).
|
||||||
* This constructor is useful for exceptions that are little more than
|
* This constructor is useful for exceptions that are little more than
|
||||||
* wrappers for other throwables (for example, {@link
|
* wrappers for other throwables (for example, {@link
|
||||||
* PrivilegedActionException}).
|
* PrivilegedActionException}).
|
||||||
*
|
*
|
||||||
* @param cause the cause (which is saved for later retrieval by the
|
* @param cause the cause (which is saved for later retrieval by the
|
||||||
* {@link #getCause()} method). (A <tt>null</tt> value is
|
* {@link #getCause()} method). (A null value is
|
||||||
* permitted, and indicates that the cause is nonexistent or
|
* permitted, and indicates that the cause is nonexistent or
|
||||||
* unknown.)
|
* unknown.)
|
||||||
* @since 1.4
|
* @since 1.4
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -173,13 +173,13 @@ public class Indices {
|
||||||
/**
|
/**
|
||||||
* Fill in the missing indices to be the
|
* Fill in the missing indices to be the
|
||||||
* same length as the original shape.
|
* same length as the original shape.
|
||||||
* <p/>
|
* <p>
|
||||||
* Think of this as what fills in the indices for numpy or matlab:
|
* Think of this as what fills in the indices for numpy or matlab:
|
||||||
* Given a which is (4,3,2) in numpy:
|
* Given a which is (4,3,2) in numpy:
|
||||||
* <p/>
|
* <p>
|
||||||
* a[1:3] is filled in by the rest
|
* a[1:3] is filled in by the rest
|
||||||
* to give back the full slice
|
* to give back the full slice
|
||||||
* <p/>
|
* <p>
|
||||||
* This algorithm fills in that delta
|
* This algorithm fills in that delta
|
||||||
*
|
*
|
||||||
* @param shape the original shape
|
* @param shape the original shape
|
||||||
|
@ -244,7 +244,7 @@ public class Indices {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculate the shape for the given set of indices.
|
* Calculate the shape for the given set of indices.
|
||||||
* <p/>
|
* <p>
|
||||||
* The shape is defined as (for each dimension)
|
* The shape is defined as (for each dimension)
|
||||||
* the difference between the end index + 1 and
|
* the difference between the end index + 1 and
|
||||||
* the begin index
|
* the begin index
|
||||||
|
@ -344,12 +344,12 @@ public class Indices {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculate the shape for the given set of indices and offsets.
|
* Calculate the shape for the given set of indices and offsets.
|
||||||
* <p/>
|
* <p>
|
||||||
* The shape is defined as (for each dimension)
|
* The shape is defined as (for each dimension)
|
||||||
* the difference between the end index + 1 and
|
* the difference between the end index + 1 and
|
||||||
* the begin index
|
* the begin index
|
||||||
* <p/>
|
* <p>
|
||||||
* If specified, this will check for whether any of the indices are >= to end - 1
|
* If specified, this will check for whether any of the indices are >= to end - 1
|
||||||
* and if so, prune it down
|
* and if so, prune it down
|
||||||
*
|
*
|
||||||
* @param shape the original shape
|
* @param shape the original shape
|
||||||
|
|
|
@ -90,7 +90,6 @@ public class AdaBeliefUpdater implements GradientUpdater<AdaBelief> {
|
||||||
*
|
*
|
||||||
* @param gradient the gradient to get the update for
|
* @param gradient the gradient to get the update for
|
||||||
* @param iteration
|
* @param iteration
|
||||||
* @return the gradient
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
||||||
|
|
|
@ -32,6 +32,7 @@ import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class AdaDeltaUpdater implements GradientUpdater<AdaDelta> {
|
public class AdaDeltaUpdater implements GradientUpdater<AdaDelta> {
|
||||||
|
|
||||||
public static final String MSG_STATE = "msg";
|
public static final String MSG_STATE = "msg";
|
||||||
public static final String MSDX_STATE = "msdx";
|
public static final String MSDX_STATE = "msdx";
|
||||||
|
|
||||||
|
@ -41,15 +42,17 @@ public class AdaDeltaUpdater implements GradientUpdater<AdaDelta> {
|
||||||
private INDArray msdx; //E[delta x^2]_t by arxiv paper, algorithm 1
|
private INDArray msdx; //E[delta x^2]_t by arxiv paper, algorithm 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public AdaDeltaUpdater(AdaDelta config) {
|
public AdaDeltaUpdater(AdaDelta config) {
|
||||||
this.config = config;
|
this.config = config;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setState(Map<String, INDArray> stateMap, boolean initialize) {
|
public void setState(Map<String, INDArray> stateMap, boolean initialize) {
|
||||||
if(!stateMap.containsKey(MSG_STATE) || !stateMap.containsKey(MSDX_STATE) || stateMap.size() != 2){
|
if (!stateMap.containsKey(MSG_STATE) || !stateMap.containsKey(MSDX_STATE)
|
||||||
throw new IllegalStateException("State map should contain only keys [" + MSG_STATE + "," + MSDX_STATE + "] but has keys " + stateMap.keySet());
|
|| stateMap.size() != 2) {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"State map should contain only keys [" + MSG_STATE + "," + MSDX_STATE + "] but has keys "
|
||||||
|
+ stateMap.keySet());
|
||||||
}
|
}
|
||||||
this.msg = stateMap.get(MSG_STATE);
|
this.msg = stateMap.get(MSG_STATE);
|
||||||
this.msdx = stateMap.get(MSDX_STATE);
|
this.msdx = stateMap.get(MSDX_STATE);
|
||||||
|
@ -64,11 +67,14 @@ public class AdaDeltaUpdater implements GradientUpdater<AdaDelta> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
|
public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder,
|
||||||
if (!viewArray.isRowVector())
|
boolean initialize) {
|
||||||
|
if (!viewArray.isRowVector()) {
|
||||||
throw new IllegalArgumentException("Invalid input: expect row vector input");
|
throw new IllegalArgumentException("Invalid input: expect row vector input");
|
||||||
if (initialize)
|
}
|
||||||
|
if (initialize) {
|
||||||
viewArray.assign(0);
|
viewArray.assign(0);
|
||||||
|
}
|
||||||
long length = viewArray.length();
|
long length = viewArray.length();
|
||||||
this.msg = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2));
|
this.msg = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2));
|
||||||
this.msdx = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length));
|
this.msdx = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length));
|
||||||
|
@ -76,23 +82,22 @@ public class AdaDeltaUpdater implements GradientUpdater<AdaDelta> {
|
||||||
//Reshape to match the expected shape of the input gradient arrays
|
//Reshape to match the expected shape of the input gradient arrays
|
||||||
this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, gradientOrder == 'f');
|
this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, gradientOrder == 'f');
|
||||||
this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, gradientOrder == 'f');
|
this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, gradientOrder == 'f');
|
||||||
if (msg == null || msdx == null)
|
if (msg == null || msdx == null) {
|
||||||
throw new IllegalStateException("Could not correctly reshape gradient view arrays");
|
throw new IllegalStateException("Could not correctly reshape gradient view arrays");
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the updated gradient for the given gradient
|
* Get the updated gradient for the given gradient and also update the state of ada delta.
|
||||||
* and also update the state of ada delta.
|
|
||||||
*
|
*
|
||||||
* @param gradient the gradient to get the
|
* @param gradient the gradient to get the updated gradient for
|
||||||
* updated gradient for
|
|
||||||
* @param iteration
|
* @param iteration
|
||||||
* @return the update gradient
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
||||||
if (msg == null || msdx == null)
|
if (msg == null || msdx == null) {
|
||||||
throw new IllegalStateException("Updater has not been initialized with view state");
|
throw new IllegalStateException("Updater has not been initialized with view state");
|
||||||
|
}
|
||||||
|
|
||||||
double rho = config.getRho();
|
double rho = config.getRho();
|
||||||
double epsilon = config.getEpsilon();
|
double epsilon = config.getEpsilon();
|
||||||
|
@ -104,6 +109,7 @@ public class AdaDeltaUpdater implements GradientUpdater<AdaDelta> {
|
||||||
//Note: negative is applied in the DL4J step function: params -= update rather than params += update
|
//Note: negative is applied in the DL4J step function: params -= update rather than params += update
|
||||||
//Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2
|
//Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2
|
||||||
|
|
||||||
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho, epsilon));
|
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho,
|
||||||
|
epsilon));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,7 +92,6 @@ public class AdaMaxUpdater implements GradientUpdater<AdaMax> {
|
||||||
*
|
*
|
||||||
* @param gradient the gradient to get the update for
|
* @param gradient the gradient to get the update for
|
||||||
* @param iteration
|
* @param iteration
|
||||||
* @return the gradient
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
||||||
|
|
|
@ -93,7 +93,6 @@ public class AdamUpdater implements GradientUpdater<Adam> {
|
||||||
*
|
*
|
||||||
* @param gradient the gradient to get the update for
|
* @param gradient the gradient to get the update for
|
||||||
* @param iteration
|
* @param iteration
|
||||||
* @return the gradient
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
||||||
|
|
|
@ -48,7 +48,6 @@ public interface GradientUpdater<T extends IUpdater> {
|
||||||
*
|
*
|
||||||
* @param gradient the gradient to modify
|
* @param gradient the gradient to modify
|
||||||
* @param iteration
|
* @param iteration
|
||||||
* @return the modified gradient
|
|
||||||
*/
|
*/
|
||||||
void applyUpdater(INDArray gradient, int iteration, int epoch);
|
void applyUpdater(INDArray gradient, int iteration, int epoch);
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,7 +92,6 @@ public class NadamUpdater implements GradientUpdater<Nadam> {
|
||||||
*
|
*
|
||||||
* @param gradient the gradient to get the update for
|
* @param gradient the gradient to get the update for
|
||||||
* @param iteration
|
* @param iteration
|
||||||
* @return the gradient
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
||||||
|
|
|
@ -77,7 +77,6 @@ public class NesterovsUpdater implements GradientUpdater<Nesterovs> {
|
||||||
*
|
*
|
||||||
* @param gradient the gradient to get the update for
|
* @param gradient the gradient to get the update for
|
||||||
* @param iteration
|
* @param iteration
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
||||||
|
|
|
@ -152,12 +152,12 @@ public class MultiDimensionalMap<K, T, V> implements Serializable {
|
||||||
/**
|
/**
|
||||||
* Returns the value to which the specified key is mapped,
|
* Returns the value to which the specified key is mapped,
|
||||||
* or {@code null} if this map contains no mapping for the key.
|
* or {@code null} if this map contains no mapping for the key.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>More formally, if this map contains a mapping from a key
|
* <p>More formally, if this map contains a mapping from a key
|
||||||
* {@code k} to a value {@code v} such that {@code (key==null ? k==null :
|
* {@code k} to a value {@code v} such that {@code (key==null ? k==null :
|
||||||
* key.equals(k))}, then this method returns {@code v}; otherwise
|
* key.equals(k))}, then this method returns {@code v}; otherwise
|
||||||
* it returns {@code null}. (There can be at most one such mapping.)
|
* it returns {@code null}. (There can be at most one such mapping.)
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>If this map permits null values, then a return value of
|
* <p>If this map permits null values, then a return value of
|
||||||
* {@code null} does not <i>necessarily</i> indicate that the map
|
* {@code null} does not <i>necessarily</i> indicate that the map
|
||||||
* contains no mapping for the key; it's also possible that the map
|
* contains no mapping for the key; it's also possible that the map
|
||||||
|
@ -214,15 +214,15 @@ public class MultiDimensionalMap<K, T, V> implements Serializable {
|
||||||
* from key <tt>k</tt> to value <tt>v</tt> such that
|
* from key <tt>k</tt> to value <tt>v</tt> such that
|
||||||
* <code>(key==null ? k==null : key.equals(k))</code>, that mapping
|
* <code>(key==null ? k==null : key.equals(k))</code>, that mapping
|
||||||
* is removed. (The map can contain at most one such mapping.)
|
* is removed. (The map can contain at most one such mapping.)
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>Returns the value to which this map previously associated the key,
|
* <p>Returns the value to which this map previously associated the key,
|
||||||
* or <tt>null</tt> if the map contained no mapping for the key.
|
* or <tt>null</tt> if the map contained no mapping for the key.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>If this map permits null values, then a return value of
|
* <p>If this map permits null values, then a return value of
|
||||||
* <tt>null</tt> does not <i>necessarily</i> indicate that the map
|
* <tt>null</tt> does not <i>necessarily</i> indicate that the map
|
||||||
* contained no mapping for the key; it's also possible that the map
|
* contained no mapping for the key; it's also possible that the map
|
||||||
* explicitly mapped the key to <tt>null</tt>.
|
* explicitly mapped the key to <tt>null</tt>.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>The map will not contain a mapping for the specified key once the
|
* <p>The map will not contain a mapping for the specified key once the
|
||||||
* call returns.
|
* call returns.
|
||||||
*
|
*
|
||||||
|
|
|
@ -108,12 +108,12 @@ public class MultiDimensionalSet<K, V> implements Set<Pair<K, V>> {
|
||||||
* If this applyTransformToDestination makes any guarantees as to what order its elements
|
* If this applyTransformToDestination makes any guarantees as to what order its elements
|
||||||
* are returned by its iterator, this method must return the
|
* are returned by its iterator, this method must return the
|
||||||
* elements in the same order.
|
* elements in the same order.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>The returned array will be "safe" in that no references to it
|
* <p>The returned array will be "safe" in that no references to it
|
||||||
* are maintained by this applyTransformToDestination. (In other words, this method must
|
* are maintained by this applyTransformToDestination. (In other words, this method must
|
||||||
* allocate a new array even if this applyTransformToDestination is backed by an array).
|
* allocate a new array even if this applyTransformToDestination is backed by an array).
|
||||||
* The caller is thus free to modify the returned array.
|
* The caller is thus free to modify the returned array.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>This method acts as bridge between array-based and collection-based
|
* <p>This method acts as bridge between array-based and collection-based
|
||||||
* APIs.
|
* APIs.
|
||||||
*
|
*
|
||||||
|
@ -130,27 +130,27 @@ public class MultiDimensionalSet<K, V> implements Set<Pair<K, V>> {
|
||||||
* If the applyTransformToDestination fits in the specified array, it is returned therein.
|
* If the applyTransformToDestination fits in the specified array, it is returned therein.
|
||||||
* Otherwise, a new array is allocated with the runtime type of the
|
* Otherwise, a new array is allocated with the runtime type of the
|
||||||
* specified array and the size of this applyTransformToDestination.
|
* specified array and the size of this applyTransformToDestination.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>If this applyTransformToDestination fits in the specified array with room to spare
|
* <p>If this applyTransformToDestination fits in the specified array with room to spare
|
||||||
* (i.e., the array has more elements than this applyTransformToDestination), the element in
|
* (i.e., the array has more elements than this applyTransformToDestination), the element in
|
||||||
* the array immediately following the end of the applyTransformToDestination is applyTransformToDestination to
|
* the array immediately following the end of the applyTransformToDestination is applyTransformToDestination to
|
||||||
* <tt>null</tt>. (This is useful in determining the length of this
|
* <tt>null</tt>. (This is useful in determining the length of this
|
||||||
* applyTransformToDestination <i>only</i> if the caller knows that this applyTransformToDestination does not contain
|
* applyTransformToDestination <i>only</i> if the caller knows that this applyTransformToDestination does not contain
|
||||||
* any null elements.)
|
* any null elements.)
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>If this applyTransformToDestination makes any guarantees as to what order its elements
|
* <p>If this applyTransformToDestination makes any guarantees as to what order its elements
|
||||||
* are returned by its iterator, this method must return the elements
|
* are returned by its iterator, this method must return the elements
|
||||||
* in the same order.
|
* in the same order.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>Like the {@link #toArray()} method, this method acts as bridge between
|
* <p>Like the {@link #toArray()} method, this method acts as bridge between
|
||||||
* array-based and collection-based APIs. Further, this method allows
|
* array-based and collection-based APIs. Further, this method allows
|
||||||
* precise control over the runtime type of the output array, and may,
|
* precise control over the runtime type of the output array, and may,
|
||||||
* under certain circumstances, be used to save allocation costs.
|
* under certain circumstances, be used to save allocation costs.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>Suppose <tt>x</tt> is a applyTransformToDestination known to contain only strings.
|
* <p>Suppose <tt>x</tt> is a applyTransformToDestination known to contain only strings.
|
||||||
* The following code can be used to dump the applyTransformToDestination into a newly allocated
|
* The following code can be used to dump the applyTransformToDestination into a newly allocated
|
||||||
* array of <tt>String</tt>:
|
* array of <tt>String</tt>:
|
||||||
* <p/>
|
* <p>
|
||||||
* <pre>
|
* <pre>
|
||||||
* String[] y = x.toArray(new String[0]);</pre>
|
* String[] y = x.toArray(new String[0]);</pre>
|
||||||
*
|
*
|
||||||
|
@ -181,7 +181,7 @@ public class MultiDimensionalSet<K, V> implements Set<Pair<K, V>> {
|
||||||
* unchanged and returns <tt>false</tt>. In combination with the
|
* unchanged and returns <tt>false</tt>. In combination with the
|
||||||
* restriction on constructors, this ensures that sets never contain
|
* restriction on constructors, this ensures that sets never contain
|
||||||
* duplicate elements.
|
* duplicate elements.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>The stipulation above does not imply that sets must accept all
|
* <p>The stipulation above does not imply that sets must accept all
|
||||||
* elements; sets may refuse to add any particular element, including
|
* elements; sets may refuse to add any particular element, including
|
||||||
* <tt>null</tt>, and throw an exception, as described in the
|
* <tt>null</tt>, and throw an exception, as described in the
|
||||||
|
|
|
@ -204,9 +204,9 @@ public class ArrayUtil {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Credit to mikio braun from jblas
|
* Credit to mikio braun from jblas
|
||||||
* <p/>
|
* <p>
|
||||||
* Create a random permutation of the numbers 0, ..., size - 1.
|
* Create a random permutation of the numbers 0, ..., size - 1.
|
||||||
* <p/>
|
* <p>
|
||||||
* see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145
|
* see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145
|
||||||
*/
|
*/
|
||||||
public static int[] randomPermutation(int size) {
|
public static int[] randomPermutation(int size) {
|
||||||
|
|
|
@ -64,7 +64,7 @@ public class HelperUtils {
|
||||||
if("CUDA".equalsIgnoreCase(backend) && cudnnHelperClassName != null && !cudnnHelperClassName.isEmpty()) {
|
if("CUDA".equalsIgnoreCase(backend) && cudnnHelperClassName != null && !cudnnHelperClassName.isEmpty()) {
|
||||||
if(DL4JClassLoading.loadClassByName(cudnnHelperClassName) != null) {
|
if(DL4JClassLoading.loadClassByName(cudnnHelperClassName) != null) {
|
||||||
log.debug("Attempting to initialize cudnn helper {}",cudnnHelperClassName);
|
log.debug("Attempting to initialize cudnn helper {}",cudnnHelperClassName);
|
||||||
helperRet = (LayerHelper) DL4JClassLoading.<LayerHelper>createNewInstance(
|
helperRet = DL4JClassLoading.<LayerHelper>createNewInstance(
|
||||||
cudnnHelperClassName,
|
cudnnHelperClassName,
|
||||||
(Class<? super LayerHelper>) layerHelperSuperClass,
|
(Class<? super LayerHelper>) layerHelperSuperClass,
|
||||||
new Object[]{arguments});
|
new Object[]{arguments});
|
||||||
|
@ -76,7 +76,7 @@ public class HelperUtils {
|
||||||
ClassLoader classLoader = DL4JClassLoading.getDl4jClassloader();
|
ClassLoader classLoader = DL4JClassLoading.getDl4jClassloader();
|
||||||
DL4JClassLoading.setDl4jClassloaderFromClass(layerHelperSuperClass);
|
DL4JClassLoading.setDl4jClassloaderFromClass(layerHelperSuperClass);
|
||||||
try {
|
try {
|
||||||
helperRet = (LayerHelper) DL4JClassLoading.<LayerHelper>createNewInstance(
|
helperRet = DL4JClassLoading.<LayerHelper>createNewInstance(
|
||||||
cudnnHelperClassName,
|
cudnnHelperClassName,
|
||||||
(Class<? super LayerHelper>) layerHelperSuperClass,
|
(Class<? super LayerHelper>) layerHelperSuperClass,
|
||||||
arguments);
|
arguments);
|
||||||
|
@ -99,7 +99,7 @@ public class HelperUtils {
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if("CPU".equalsIgnoreCase(backend) && oneDnnClassName != null && !oneDnnClassName.isEmpty()) {
|
} else if("CPU".equalsIgnoreCase(backend) && oneDnnClassName != null && !oneDnnClassName.isEmpty()) {
|
||||||
helperRet = DL4JClassLoading.<LayerHelper>createNewInstance(
|
helperRet = DL4JClassLoading.createNewInstance(
|
||||||
oneDnnClassName,
|
oneDnnClassName,
|
||||||
arguments);
|
arguments);
|
||||||
log.trace("Created oneDNN helper: {}, layer {}", oneDnnClassName,layerName);
|
log.trace("Created oneDNN helper: {}, layer {}", oneDnnClassName,layerName);
|
||||||
|
|
|
@ -3,21 +3,27 @@ plugins {
|
||||||
id 'maven-publish'
|
id 'maven-publish'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
configurations.archives.artifacts.with { archives ->
|
configurations.archives.artifacts.with { archives ->
|
||||||
|
|
||||||
archives.each {
|
archives.each {
|
||||||
println(it.name)
|
println(it.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
//Todo clean this
|
//Todo clean this
|
||||||
api platform(project(":cavis-common-platform"))
|
api platform(project(":cavis-common-platform"))
|
||||||
api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise
|
//api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise
|
||||||
api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5"
|
api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5"
|
||||||
api 'org.slf4j:slf4j-simple:2.0.3'
|
api 'org.slf4j:slf4j-simple:2.0.3'
|
||||||
api 'org.slf4j:slf4j-api:2.0.3'
|
api 'org.slf4j:slf4j-api:2.0.3'
|
||||||
//api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86"
|
//TODO for the two below.. either platform specific uber jars or a single big one with all platforms
|
||||||
|
api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64"
|
||||||
|
//api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
|
||||||
|
// api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
|
||||||
|
//api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT'
|
||||||
rootProject.getAllprojects().each { Project sproj ->
|
rootProject.getAllprojects().each { Project sproj ->
|
||||||
if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")
|
if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")
|
||||||
&& !sproj.name.equals("Cavis")
|
&& !sproj.name.equals("Cavis")
|
||||||
|
|
|
@ -152,12 +152,12 @@ public class MultiDimensionalMap<K, T, V> implements Serializable {
|
||||||
/**
|
/**
|
||||||
* Returns the value to which the specified key is mapped,
|
* Returns the value to which the specified key is mapped,
|
||||||
* or {@code null} if this map contains no mapping for the key.
|
* or {@code null} if this map contains no mapping for the key.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>More formally, if this map contains a mapping from a key
|
* <p>More formally, if this map contains a mapping from a key
|
||||||
* {@code k} to a value {@code v} such that {@code (key==null ? k==null :
|
* {@code k} to a value {@code v} such that {@code (key==null ? k==null :
|
||||||
* key.equals(k))}, then this method returns {@code v}; otherwise
|
* key.equals(k))}, then this method returns {@code v}; otherwise
|
||||||
* it returns {@code null}. (There can be at most one such mapping.)
|
* it returns {@code null}. (There can be at most one such mapping.)
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>If this map permits null values, then a return value of
|
* <p>If this map permits null values, then a return value of
|
||||||
* {@code null} does not <i>necessarily</i> indicate that the map
|
* {@code null} does not <i>necessarily</i> indicate that the map
|
||||||
* contains no mapping for the key; it's also possible that the map
|
* contains no mapping for the key; it's also possible that the map
|
||||||
|
@ -214,15 +214,15 @@ public class MultiDimensionalMap<K, T, V> implements Serializable {
|
||||||
* from key <tt>k</tt> to value <tt>v</tt> such that
|
* from key <tt>k</tt> to value <tt>v</tt> such that
|
||||||
* <code>(key==null ? k==null : key.equals(k))</code>, that mapping
|
* <code>(key==null ? k==null : key.equals(k))</code>, that mapping
|
||||||
* is removed. (The map can contain at most one such mapping.)
|
* is removed. (The map can contain at most one such mapping.)
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>Returns the value to which this map previously associated the key,
|
* <p>Returns the value to which this map previously associated the key,
|
||||||
* or <tt>null</tt> if the map contained no mapping for the key.
|
* or <tt>null</tt> if the map contained no mapping for the key.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>If this map permits null values, then a return value of
|
* <p>If this map permits null values, then a return value of
|
||||||
* <tt>null</tt> does not <i>necessarily</i> indicate that the map
|
* <tt>null</tt> does not <i>necessarily</i> indicate that the map
|
||||||
* contained no mapping for the key; it's also possible that the map
|
* contained no mapping for the key; it's also possible that the map
|
||||||
* explicitly mapped the key to <tt>null</tt>.
|
* explicitly mapped the key to <tt>null</tt>.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>The map will not contain a mapping for the specified key once the
|
* <p>The map will not contain a mapping for the specified key once the
|
||||||
* call returns.
|
* call returns.
|
||||||
*
|
*
|
||||||
|
|
|
@ -108,12 +108,12 @@ public class MultiDimensionalSet<K, V> implements Set<Pair<K, V>> {
|
||||||
* If this applyTransformToDestination makes any guarantees as to what order its elements
|
* If this applyTransformToDestination makes any guarantees as to what order its elements
|
||||||
* are returned by its iterator, this method must return the
|
* are returned by its iterator, this method must return the
|
||||||
* elements in the same order.
|
* elements in the same order.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>The returned array will be "safe" in that no references to it
|
* <p>The returned array will be "safe" in that no references to it
|
||||||
* are maintained by this applyTransformToDestination. (In other words, this method must
|
* are maintained by this applyTransformToDestination. (In other words, this method must
|
||||||
* allocate a new array even if this applyTransformToDestination is backed by an array).
|
* allocate a new array even if this applyTransformToDestination is backed by an array).
|
||||||
* The caller is thus free to modify the returned array.
|
* The caller is thus free to modify the returned array.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>This method acts as bridge between array-based and collection-based
|
* <p>This method acts as bridge between array-based and collection-based
|
||||||
* APIs.
|
* APIs.
|
||||||
*
|
*
|
||||||
|
@ -130,27 +130,27 @@ public class MultiDimensionalSet<K, V> implements Set<Pair<K, V>> {
|
||||||
* If the applyTransformToDestination fits in the specified array, it is returned therein.
|
* If the applyTransformToDestination fits in the specified array, it is returned therein.
|
||||||
* Otherwise, a new array is allocated with the runtime type of the
|
* Otherwise, a new array is allocated with the runtime type of the
|
||||||
* specified array and the size of this applyTransformToDestination.
|
* specified array and the size of this applyTransformToDestination.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>If this applyTransformToDestination fits in the specified array with room to spare
|
* <p>If this applyTransformToDestination fits in the specified array with room to spare
|
||||||
* (i.e., the array has more elements than this applyTransformToDestination), the element in
|
* (i.e., the array has more elements than this applyTransformToDestination), the element in
|
||||||
* the array immediately following the end of the applyTransformToDestination is applyTransformToDestination to
|
* the array immediately following the end of the applyTransformToDestination is applyTransformToDestination to
|
||||||
* <tt>null</tt>. (This is useful in determining the length of this
|
* <tt>null</tt>. (This is useful in determining the length of this
|
||||||
* applyTransformToDestination <i>only</i> if the caller knows that this applyTransformToDestination does not contain
|
* applyTransformToDestination <i>only</i> if the caller knows that this applyTransformToDestination does not contain
|
||||||
* any null elements.)
|
* any null elements.)
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>If this applyTransformToDestination makes any guarantees as to what order its elements
|
* <p>If this applyTransformToDestination makes any guarantees as to what order its elements
|
||||||
* are returned by its iterator, this method must return the elements
|
* are returned by its iterator, this method must return the elements
|
||||||
* in the same order.
|
* in the same order.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>Like the {@link #toArray()} method, this method acts as bridge between
|
* <p>Like the {@link #toArray()} method, this method acts as bridge between
|
||||||
* array-based and collection-based APIs. Further, this method allows
|
* array-based and collection-based APIs. Further, this method allows
|
||||||
* precise control over the runtime type of the output array, and may,
|
* precise control over the runtime type of the output array, and may,
|
||||||
* under certain circumstances, be used to save allocation costs.
|
* under certain circumstances, be used to save allocation costs.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>Suppose <tt>x</tt> is a applyTransformToDestination known to contain only strings.
|
* <p>Suppose <tt>x</tt> is a applyTransformToDestination known to contain only strings.
|
||||||
* The following code can be used to dump the applyTransformToDestination into a newly allocated
|
* The following code can be used to dump the applyTransformToDestination into a newly allocated
|
||||||
* array of <tt>String</tt>:
|
* array of <tt>String</tt>:
|
||||||
* <p/>
|
* <p>
|
||||||
* <pre>
|
* <pre>
|
||||||
* String[] y = x.toArray(new String[0]);</pre>
|
* String[] y = x.toArray(new String[0]);</pre>
|
||||||
*
|
*
|
||||||
|
@ -181,7 +181,7 @@ public class MultiDimensionalSet<K, V> implements Set<Pair<K, V>> {
|
||||||
* unchanged and returns <tt>false</tt>. In combination with the
|
* unchanged and returns <tt>false</tt>. In combination with the
|
||||||
* restriction on constructors, this ensures that sets never contain
|
* restriction on constructors, this ensures that sets never contain
|
||||||
* duplicate elements.
|
* duplicate elements.
|
||||||
* <p/>
|
* <p>
|
||||||
* <p>The stipulation above does not imply that sets must accept all
|
* <p>The stipulation above does not imply that sets must accept all
|
||||||
* elements; sets may refuse to add any particular element, including
|
* elements; sets may refuse to add any particular element, including
|
||||||
* <tt>null</tt>, and throw an exception, as described in the
|
* <tt>null</tt>, and throw an exception, as described in the
|
||||||
|
|
|
@ -204,9 +204,9 @@ public class ArrayUtil {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Credit to mikio braun from jblas
|
* Credit to mikio braun from jblas
|
||||||
* <p/>
|
* <p>
|
||||||
* Create a random permutation of the numbers 0, ..., size - 1.
|
* Create a random permutation of the numbers 0, ..., size - 1.
|
||||||
* <p/>
|
* <p>
|
||||||
* see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145
|
* see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145
|
||||||
*/
|
*/
|
||||||
public static int[] randomPermutation(int size) {
|
public static int[] randomPermutation(int size) {
|
||||||
|
|
|
@ -148,7 +148,6 @@ include ':cavis-ui:cavis-ui-standalone'
|
||||||
include ':cavis-ui:cavis-ui-vertx'
|
include ':cavis-ui:cavis-ui-vertx'
|
||||||
include ':cavis-zoo'
|
include ':cavis-zoo'
|
||||||
include ':cavis-zoo:cavis-zoo-models'
|
include ':cavis-zoo:cavis-zoo-models'
|
||||||
|
|
||||||
include ':brutex-extended-tests'
|
include ':brutex-extended-tests'
|
||||||
include ':cavis-full'
|
include ':cavis-full'
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue