From 1c2ca75308a5780602b025bacfed588d9b927b46 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 21 Oct 2022 15:19:32 +0200 Subject: [PATCH] Fix javadoc and cleanup --- build.gradle | 19 + build_requirements.md | 6 +- .../autodiff/execution/input/Operands.java | 7 +- .../debugging/ExecDebuggingListener.java | 2 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 6 +- .../internal/AbstractDependencyTracker.java | 22 +- .../samediff/internal/AbstractSession.java | 6 +- .../nd4j/autodiff/samediff/ops/SDBitwise.java | 206 +- .../nd4j/autodiff/samediff/ops/SDImage.java | 6 +- .../nd4j/autodiff/samediff/ops/SDLoss.java | 767 +- .../nd4j/autodiff/samediff/ops/SDMath.java | 2609 ++-- .../org/nd4j/autodiff/samediff/ops/SDNN.java | 831 +- .../nd4j/autodiff/samediff/ops/SDRandom.java | 4 +- .../evaluation/classification/Evaluation.java | 16 +- .../classification/EvaluationBinary.java | 4 +- .../nd4j/linalg/api/blas/impl/BaseLapack.java | 2 +- .../api/buffer/factory/DataBufferFactory.java | 4 +- .../nd4j/linalg/api/memory/MemoryManager.java | 1 - .../api/memory/MemoryWorkspaceManager.java | 8 +- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 10928 ++++++++-------- .../api/ndarray/BaseShapeInfoProvider.java | 2 +- .../org/nd4j/linalg/api/ndarray/INDArray.java | 22 +- .../linalg/api/ndarray/ShapeInfoProvider.java | 2 +- .../org/nd4j/linalg/api/ops/OpContext.java | 3 +- .../java/org/nd4j/linalg/api/rng/Random.java | 4 +- .../rng/distribution/BaseDistribution.java | 405 +- .../api/rng/distribution/Distribution.java | 8 +- .../impl/BinomialDistribution.java | 10 +- .../impl/ConstantDistribution.java | 12 +- .../impl/LogNormalDistribution.java | 8 +- .../distribution/impl/NormalDistribution.java | 6 - .../impl/OrthogonalDistribution.java | 411 +- .../impl/SaddlePointExpansion.java | 2 - .../impl/TruncatedNormalDistribution.java | 12 +- .../impl/UniformDistribution.java | 10 +- .../linalg/checkutil/NDArrayCreationUtil.java | 4 +- .../linalg/dataset/AsyncDataSetIterator.java | 2 +- 
.../dataset/AsyncMultiDataSetIterator.java | 2 +- .../java/org/nd4j/linalg/dataset/DataSet.java | 1 - .../dataset/api/iterator/KFoldIterator.java | 3 +- .../api/iterator/TestDataSetIterator.java | 1 - .../RandomProjection.java | 4 +- .../nd4j/linalg/env/EnvironmentalAction.java | 1 - .../linalg/factory/BaseNDArrayFactory.java | 1 - .../org/nd4j/linalg/factory/BlasWrapper.java | 24 +- .../nd4j/linalg/factory/NDArrayFactory.java | 10 +- .../java/org/nd4j/linalg/factory/Nd4j.java | 27 +- .../org/nd4j/linalg/factory/Nd4jBackend.java | 6 +- .../org/nd4j/linalg/factory/ops/NDBase.java | 1998 +-- .../org/nd4j/linalg/indexing/Indices.java | 14 +- .../linalg/learning/AdaBeliefUpdater.java | 1 - .../nd4j/linalg/learning/AdaDeltaUpdater.java | 128 +- .../nd4j/linalg/learning/AdaMaxUpdater.java | 1 - .../org/nd4j/linalg/learning/AdamUpdater.java | 1 - .../nd4j/linalg/learning/GradientUpdater.java | 1 - .../nd4j/linalg/learning/NadamUpdater.java | 1 - .../linalg/learning/NesterovsUpdater.java | 1 - .../collection/MultiDimensionalMap.java | 10 +- .../collection/MultiDimensionalSet.java | 16 +- .../java/org/nd4j/common/util/ArrayUtil.java | 4 +- .../deeplearning4j/nn/layers/HelperUtils.java | 6 +- cavis-full/build.gradle | 5 +- cavis-native/cavis-native-lib/build.gradle | 2 +- .../collection/MultiDimensionalMap.java | 10 +- .../collection/MultiDimensionalSet.java | 16 +- .../java/org/nd4j/common/util/ArrayUtil.java | 4 +- settings.gradle | 1 - 67 files changed, 9896 insertions(+), 8781 deletions(-) diff --git a/build.gradle b/build.gradle index cd5911461..fc9167f30 100644 --- a/build.gradle +++ b/build.gradle @@ -56,6 +56,7 @@ configurations.all { } + allprojects { Project proj -> apply plugin: 'com.google.osdetector' @@ -162,3 +163,21 @@ allprojects { Project proj -> } } } + + +task aggregatedJavadocs(type: Javadoc, description: 'Generate javadocs from all child projects as if it was a single project', group: 'Documentation') { + subprojects.each { proj -> + 
proj.tasks.withType(Javadoc).each { javadocTask -> + logger.quiet("Adding javadoc for project " + proj.name) + source += javadocTask.source + classpath += javadocTask.classpath + excludes += javadocTask.excludes + includes += javadocTask.includes + } + } + destinationDir = file("$buildDir/docs/javadoc") + title = "$project.name $version API" + options.author true + options.links 'http://docs.oracle.com/javase/8/docs/api/' + options.addStringOption('Xdoclint:none', '-quiet') +} \ No newline at end of file diff --git a/build_requirements.md b/build_requirements.md index 602190b95..77d54050b 100644 --- a/build_requirements.md +++ b/build_requirements.md @@ -141,4 +141,8 @@ groupId:artifactId:packaging:classifier:version In your case it should work with -edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0 \ No newline at end of file +edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0 + + +Native cpu code under linux needs libc6-dev +/lib/x86_64-linux-gnu/libm.so.6: version `GLIBC_2.29' not found \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/execution/input/Operands.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/execution/input/Operands.java index 2ea351b38..648581291 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/execution/input/Operands.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/execution/input/Operands.java @@ -90,7 +90,7 @@ public class Operands { /** * This method returns array identified its numeric id - * @param name + * @param id * @return */ public INDArray getById(int id) { @@ -99,7 +99,8 @@ public class Operands { /** * This method returns array identified its numeric id and index - * @param name + * @param id + * @param index * @return */ public INDArray getById(int id, int index) { @@ -121,7 +122,7 @@ public class Operands { } /** - * This method returns contents of this entity as collection of key->value pairs + * This method returns contents of 
this entity as collection of key->value pairs * @return */ public Collection> asCollection() { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java index ab4423020..748a8a6b8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java @@ -50,7 +50,7 @@ public class ExecDebuggingListener extends BaseListener { /** * @param printMode Print mode, see {@link PrintMode} - * @param maxIterations Maximum number of iterations to print. <= 0 for "all iterations" + * @param maxIterations Maximum number of iterations to print. <= 0 for "all iterations" * @param logIter If true: prefix iteration/epoch, such as "(iter=1,epoch=0,op=3)" to the output */ public ExecDebuggingListener(PrintMode printMode, int maxIterations, boolean logIter){ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index dc064e515..1584b5977 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -573,7 +573,7 @@ public class SameDiff extends SDBaseOps { } /** - * Get the function by the {@link DifferentialFunction#getOwnName()} + * Get the function by the {@link org.nd4j.autodiff.functions.DifferentialFunction#getOwnName()} * * @param id the id of the function * @return the function for the given id if it exists @@ -1348,9 +1348,9 @@ public class SameDiff extends SDBaseOps { /** * Get the names of variables (if any) that have been marked as loss variables to be minimized.
* Variables can be marked as loss variables in a few different ways:
- * (a) Losses are automatically added when creating loss functions via {@link #sd()}
+ * (a) Losses are automatically added when creating loss functions via {@link SameDiff#sd}
* (b) Via {@link #setLossVariables(String...)}, @link #addLossVariable(String)} or {@link SDVariable#markAsLoss()}
- * (c) Via {@link TrainingConfig#setLossVariables(List)}
+ * (c) Via {@link org.nd4j.autodiff.samediff.TrainingConfig#setLossVariables(List)}
*/ public List getLossVariables() { return Collections.unmodifiableList(this.lossVariables); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java index 9c3b5d917..7a3d8b01a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java @@ -54,12 +54,12 @@ public abstract class AbstractDependencyTracker { } /** - * @return A new map where the dependents (i.e., Y in "X -> Y") are the key + * @return A new map where the dependents (i.e., Y in "X -> Y") are the key */ protected abstract Map newTMap(); /** - * @return A new set where the dependents (i.e., Y in "X -> Y") are the key + * @return A new set where the dependents (i.e., Y in "X -> Y") are the key */ protected abstract Set newTSet(); @@ -103,7 +103,7 @@ public abstract class AbstractDependencyTracker { /** * Mark the specified value as satisfied. - * For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true) + * For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true) * call, both of these dependencies are considered satisfied. 
* * @param x Value to mark @@ -191,7 +191,7 @@ public abstract class AbstractDependencyTracker { } /** - * Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)} + * Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)} * or {@link #addOrDependency(Object, Object, Object)} * * @param y Dependent to check @@ -207,7 +207,7 @@ public abstract class AbstractDependencyTracker { } /** - * Get all dependencies x, for x -> y, and (x1 or x2) -> y + * Get all dependencies x, for x -> y, and (x1 or x2) -> y * * @param y Dependent to get dependencies for * @return List of dependencies @@ -223,7 +223,7 @@ public abstract class AbstractDependencyTracker { } /** - * Add a dependency: y depends on x, as in x -> y + * Add a dependency: y depends on x, as in x -> y * * @param y The dependent * @param x The dependee that is required for Y @@ -302,7 +302,7 @@ public abstract class AbstractDependencyTracker { /** - * Remove a dependency (x -> y) + * Remove a dependency (x -> y) * * @param y The dependent that currently requires X * @param x The dependee that is no longer required for Y @@ -357,7 +357,7 @@ public abstract class AbstractDependencyTracker { } /** - * Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y
+ * Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y
* If either x1 or x2 (or both) are marked satisfied via {@link #markSatisfied(Object, boolean)} then the * dependency is considered satisfied * @@ -382,16 +382,16 @@ public abstract class AbstractDependencyTracker { } /** - * @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y) + * @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y) */ public boolean hasNewAllSatisfied() { return !allSatisfiedQueue.isEmpty(); } /** - * Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)} + * Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)} * Throws an exception if {@link #hasNewAllSatisfied()} returns false.
- * Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value; + * Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value; * the value is considered "processed" at this point. * * @return The next new "all satisfied dependent" diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index d00efcba7..ce29242a8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -487,7 +487,7 @@ public abstract class AbstractSession { } /** - * Add the control dependency from Op -> variable + * Add the control dependency from Op -> variable * * @param es Execution step for the variable * @param v Variable @@ -542,7 +542,7 @@ public abstract class AbstractSession { /** * Update the descendant dependencies - * So if the graph structure is X -> A, then add all (X,Y,Z,...) -> A to the dependency tracker + * So if the graph structure is X -> A, then add all (X,Y,Z,...) -> A to the dependency tracker * This is for a specific frame and iteration, for both sides of the dependency (in and out) * * @param justExecuted The execution step that has just completed @@ -621,7 +621,7 @@ public abstract class AbstractSession { /** * Suppose operation X has just been executed. 
- * For X -> someOp, add all dependencies for someOp, i.e., all Z -> someOp + * For X -> someOp, add all dependencies for someOp, i.e., all Z -> someOp * (which includes X, but may not only be X) * * @param opName Name of the op diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java index 00102c498..38e99641b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java @@ -28,15 +28,15 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; public class SDBitwise extends SDOps { + public SDBitwise(SameDiff sameDiff) { super(sameDiff); } /** * Bitwise AND operation. Supports broadcasting.
- * - * Inputs must satisfy the following constraints:
- * Must be same types: isSameType(x, y)
+ *

+ * Inputs must satisfy the following constraints:
Must be same types: isSameType(x, y)
* Must have broadcastable shapes: isBroadcastableShapes(x, y)
* * @param x First input array (INT type) @@ -47,147 +47,155 @@ public class SDBitwise extends SDOps { SDValidation.validateInteger("and", "x", x); SDValidation.validateInteger("and", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd, x, y).outputVariable(); } /** * Bitwise AND operation. Supports broadcasting.
- * - * Inputs must satisfy the following constraints:
- * Must be same types: isSameType(x, y)
+ *

+ * Inputs must satisfy the following constraints:
Must be same types: isSameType(x, y)
* Must have broadcastable shapes: isBroadcastableShapes(x, y)
* * @param name name May be null. Name for the output variable - * @param x First input array (INT type) - * @param y Second input array (INT type) + * @param x First input array (INT type) + * @param y Second input array (INT type) * @return output Bitwise AND array (INT type) */ public SDVariable and(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("and", "x", x); SDValidation.validateInteger("and", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
+ * Roll integer bits to the left, i.e. {@code var << 4 | var >> (32 - 4)}
* - * @param x Input 1 (INT type) + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitRotl(SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitRotl", "x", x); SDValidation.validateInteger("bitRotl", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x, + shift).outputVariable(); } /** - * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
+ * Roll integer bits to the left, i.e. {@code var << 4 | var >> (32 - 4)}
* - * @param name name May be null. Name for the output variable - * @param x Input 1 (INT type) + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitRotl", "x", x); SDValidation.validateInteger("bitRotl", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
+ * Roll integer bits to the right, i.e. {@code var >> 4 | var << (32 - 4)}
* - * @param x Input 1 (INT type) + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitRotr(SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitRotr", "x", x); SDValidation.validateInteger("bitRotr", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x, + shift).outputVariable(); } /** - * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
+ * Roll integer bits to the right, i.e. {@code var >> 4 | var << (32 - 4)}
* - * @param name name May be null. Name for the output variable - * @param x Input 1 (INT type) + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitRotr", "x", x); SDValidation.validateInteger("bitRotr", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Shift integer bits to the left, i.e. var << 4
+ * Shift integer bits to the left, i.e. {@code var << 4}
* - * @param x Input 1 (INT type) + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitShift(SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitShift", "x", x); SDValidation.validateInteger("bitShift", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, + shift).outputVariable(); } /** - * Shift integer bits to the left, i.e. var << 4
+ * Shift integer bits to the left, i.e. {@code var << 4}
* - * @param name name May be null. Name for the output variable - * @param x Input 1 (INT type) + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitShift", "x", x); SDValidation.validateInteger("bitShift", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Shift integer bits to the right, i.e. var >> 4
+ * Shift integer bits to the right, i.e. {@code var >> 4}
* - * @param x Input 1 (INT type) + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitShiftRight(SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitShiftRight", "x", x); SDValidation.validateInteger("bitShiftRight", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, + shift).outputVariable(); } /** - * Shift integer bits to the right, i.e. var >> 4
+ * Shift integer bits to the right, i.e. {@code var >> 4}
* - * @param name name May be null. Name for the output variable - * @param x Input 1 (INT type) + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitShiftRight", "x", x); SDValidation.validateInteger("bitShiftRight", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Bitwise Hamming distance reduction over all elements of both input arrays.
- * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
- * - * Inputs must satisfy the following constraints:
- * Must be same types: isSameType(x, y)
+ * Bitwise Hamming distance reduction over all elements of both input arrays.
For example, if + * x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at + * positions 0 and 1)
+ *

+ * Inputs must satisfy the following constraints:
Must be same types: isSameType(x, y)
* * @param x First input array. (INT type) * @param y Second input array. (INT type) @@ -197,26 +205,28 @@ public class SDBitwise extends SDOps { SDValidation.validateInteger("bitsHammingDistance", "x", x); SDValidation.validateInteger("bitsHammingDistance", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd, x, + y).outputVariable(); } /** - * Bitwise Hamming distance reduction over all elements of both input arrays.
- * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
- * - * Inputs must satisfy the following constraints:
- * Must be same types: isSameType(x, y)
+ * Bitwise Hamming distance reduction over all elements of both input arrays.
For example, if + * x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at + * positions 0 and 1)
+ *

+ * Inputs must satisfy the following constraints:
Must be same types: isSameType(x, y)
* * @param name name May be null. Name for the output variable - * @param x First input array. (INT type) - * @param y Second input array. (INT type) + * @param x First input array. (INT type) + * @param y Second input array. (INT type) * @return output bitwise Hamming distance (INT type) */ public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("bitsHammingDistance", "x", x); SDValidation.validateInteger("bitsHammingDistance", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -230,27 +240,28 @@ public class SDBitwise extends SDOps { public SDVariable leftShift(SDVariable x, SDVariable y) { SDValidation.validateInteger("leftShift", "x", x); SDValidation.validateInteger("leftShift", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, y).outputVariable(); } /** * Bitwise left shift operation. Supports broadcasting.
* * @param name name May be null. Name for the output variable - * @param x Input to be bit shifted (INT type) - * @param y Amount to shift elements of x array (INT type) + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) * @return output Bitwise shifted input x (INT type) */ public SDVariable leftShift(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("leftShift", "x", x); SDValidation.validateInteger("leftShift", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Bitwise left cyclical shift operation. Supports broadcasting.
- * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":
+ * Bitwise left cyclical shift operation. Supports broadcasting.
Unlike + * {@link SDBitwise#leftShift(INDArray, INDArray)} the bits will "wrap around":
* {@code leftShiftCyclic(01110000, 2) -> 11000001}
* * @param x Input to be bit shifted (INT type) @@ -260,31 +271,32 @@ public class SDBitwise extends SDOps { public SDVariable leftShiftCyclic(SDVariable x, SDVariable y) { SDValidation.validateInteger("leftShiftCyclic", "x", x); SDValidation.validateInteger("leftShiftCyclic", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x, + y).outputVariable(); } /** - * Bitwise left cyclical shift operation. Supports broadcasting.
- * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":
+ * Bitwise left cyclical shift operation. Supports broadcasting.
Unlike + * {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":
* {@code leftShiftCyclic(01110000, 2) -> 11000001}
* * @param name name May be null. Name for the output variable - * @param x Input to be bit shifted (INT type) - * @param y Amount to shift elements of x array (INT type) + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) * @return output Bitwise cyclic shifted input x (INT type) */ public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("leftShiftCyclic", "x", x); SDValidation.validateInteger("leftShiftCyclic", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Bitwise OR operation. Supports broadcasting.
- * - * Inputs must satisfy the following constraints:
- * Must be same types: isSameType(x, y)
+ *

+ * Inputs must satisfy the following constraints:
Must be same types: isSameType(x, y)
* Must have broadcastable shapes: isBroadcastableShapes(x, y)
* * @param x First input array (INT type) @@ -295,26 +307,26 @@ public class SDBitwise extends SDOps { SDValidation.validateInteger("or", "x", x); SDValidation.validateInteger("or", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd, x, y).outputVariable(); } /** * Bitwise OR operation. Supports broadcasting.
- * - * Inputs must satisfy the following constraints:
- * Must be same types: isSameType(x, y)
+ *

+ * Inputs must satisfy the following constraints:
Must be same types: isSameType(x, y)
* Must have broadcastable shapes: isBroadcastableShapes(x, y)
* * @param name name May be null. Name for the output variable - * @param x First input array (INT type) - * @param y First input array (INT type) + * @param x First input array (INT type) + * @param y First input array (INT type) * @return output Bitwise OR array (INT type) */ public SDVariable or(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("or", "x", x); SDValidation.validateInteger("or", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -328,27 +340,28 @@ public class SDBitwise extends SDOps { public SDVariable rightShift(SDVariable x, SDVariable y) { SDValidation.validateInteger("rightShift", "x", x); SDValidation.validateInteger("rightShift", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, y).outputVariable(); } /** * Bitwise right shift operation. Supports broadcasting.
* * @param name name May be null. Name for the output variable - * @param x Input to be bit shifted (INT type) - * @param y Amount to shift elements of x array (INT type) + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) * @return output Bitwise shifted input x (INT type) */ public SDVariable rightShift(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("rightShift", "x", x); SDValidation.validateInteger("rightShift", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Bitwise right cyclical shift operation. Supports broadcasting.
- * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":
+ * Bitwise right cyclical shift operation. Supports broadcasting.
Unlike + * {@link #rightShift(SDVariable, SDVariable)} the bits will "wrap around":<br>
* {@code rightShiftCyclic(00001110, 2) -> 10000011}
* * @param x Input to be bit shifted (INT type) @@ -358,31 +371,32 @@ public class SDBitwise extends SDOps { public SDVariable rightShiftCyclic(SDVariable x, SDVariable y) { SDValidation.validateInteger("rightShiftCyclic", "x", x); SDValidation.validateInteger("rightShiftCyclic", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x, + y).outputVariable(); } /** - * Bitwise right cyclical shift operation. Supports broadcasting.
- * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":
+ * Bitwise right cyclical shift operation. Supports broadcasting.
Unlike + * {@link #rightShift(SDVariable, SDVariable)} the bits will "wrap around":<br>
* {@code rightShiftCyclic(00001110, 2) -> 10000011}
* * @param name name May be null. Name for the output variable - * @param x Input to be bit shifted (INT type) - * @param y Amount to shift elements of x array (INT type) + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) * @return output Bitwise cyclic shifted input x (INT type) */ public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("rightShiftCyclic", "x", x); SDValidation.validateInteger("rightShiftCyclic", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Bitwise XOR operation (exclusive OR). Supports broadcasting.
- * - * Inputs must satisfy the following constraints:
- * Must be same types: isSameType(x, y)
+ *

+ * Inputs must satisfy the following constraints:
Must be same types: isSameType(x, y)
* Must have broadcastable shapes: isBroadcastableShapes(x, y)
* * @param x First input array (INT type) @@ -393,26 +407,26 @@ public class SDBitwise extends SDOps { SDValidation.validateInteger("xor", "x", x); SDValidation.validateInteger("xor", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd, x, y).outputVariable(); } /** * Bitwise XOR operation (exclusive OR). Supports broadcasting.
- * - * Inputs must satisfy the following constraints:
- * Must be same types: isSameType(x, y)
+ *

+ * Inputs must satisfy the following constraints:
Must be same types: isSameType(x, y)
* Must have broadcastable shapes: isBroadcastableShapes(x, y)
* * @param name name May be null. Name for the output variable - * @param x First input array (INT type) - * @param y First input array (INT type) + * @param x First input array (INT type) + * @param y First input array (INT type) * @return output Bitwise XOR array (INT type) */ public SDVariable xor(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("xor", "x", x); SDValidation.validateInteger("xor", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index 6317e0941..558d095db 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -355,7 +355,8 @@ public class SDImage extends SDOps { * @param maxOutSize scalar representing the maximum number of boxes to be selected * @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU * @param scoreThreshold threshold for deciding when to remove boxes based on score - * @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type) + * @return output vectort of shape [M] representing the selected indices from the boxes tensor, + * where M <= max_output_size (NUMERIC type) */ public SDVariable nonMaxSuppression(SDVariable boxes, SDVariable scores, int maxOutSize, double iouThreshold, double scoreThreshold) { @@ -373,7 +374,8 @@ public class SDImage extends SDOps { * @param maxOutSize scalar 
representing the maximum number of boxes to be selected * @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU * @param scoreThreshold threshold for deciding when to remove boxes based on score - * @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type) + * @return output vectort of shape [M] representing the selected indices from the boxes tensor, + * where M <= max_output_size (NUMERIC type) */ public SDVariable nonMaxSuppression(String name, SDVariable boxes, SDVariable scores, int maxOutSize, double iouThreshold, double scoreThreshold) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java index c6fef378e..5f6b76c94 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java @@ -26,6 +26,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; public class SDLoss extends SDOps { + public SDLoss(SameDiff sameDiff) { super(sameDiff); } @@ -33,10 +34,11 @@ public class SDLoss extends SDOps { /** * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output loss variable (NUMERIC type) */ public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, @@ -44,7 +46,8 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("absoluteDifference", "label", label); SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); SDValidation.validateNumerical("absoluteDifference", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd, label, + predictions, weights, lossReduce).outputVariable(); out.markAsLoss(); return out; } @@ -52,11 +55,12 @@ public class SDLoss extends SDOps { /** * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output loss variable (NUMERIC type) */ public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions, @@ -64,7 +68,8 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("absoluteDifference", "label", label); SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); SDValidation.validateNumerical("absoluteDifference", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd, label, + predictions, weights, lossReduce).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } @@ -72,9 +77,9 @@ public class SDLoss extends SDOps { /** * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output loss variable (NUMERIC type) */ public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, @@ -82,7 +87,9 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("absoluteDifference", "label", label); SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); SDValidation.validateNumerical("absoluteDifference", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd, label, + predictions, weights, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return out; } @@ -90,10 +97,10 @@ public class SDLoss extends SDOps { /** * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output loss variable (NUMERIC type) */ public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions, @@ -101,23 +108,28 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("absoluteDifference", "label", label); SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); SDValidation.validateNumerical("absoluteDifference", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd, label, + predictions, weights, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
- * equivalent to cosine distance when both the predictions and labels are normalized.
- * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
- * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
- * along the cosine distance dimension (with keepDims=true).
+ * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or + * {@code 1 - sum_i label[i] * prediction[i]}, which is
equivalent to cosine distance when + * both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to + * have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, + * SDVariable, boolean, int...)
along the cosine distance dimension (with keepDims=true).
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param dimension Dimension to perform the cosine distance over + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param dimension Dimension to perform the cosine distance over * @return output Cosine distance loss (NUMERIC type) */ public SDVariable cosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, @@ -125,24 +137,28 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("cosineDistance", "label", label); SDValidation.validateNumerical("cosineDistance", "predictions", predictions); SDValidation.validateNumerical("cosineDistance", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, lossReduce, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd, label, + predictions, weights, lossReduce, dimension).outputVariable(); out.markAsLoss(); return out; } /** - * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
- * equivalent to cosine distance when both the predictions and labels are normalized.
- * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
- * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
- * along the cosine distance dimension (with keepDims=true).
+ * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or + * {@code 1 - sum_i label[i] * prediction[i]}, which is
equivalent to cosine distance when + * both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to + * have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, + * SDVariable, boolean, int...)
along the cosine distance dimension (with keepDims=true).
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param dimension Dimension to perform the cosine distance over + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param dimension Dimension to perform the cosine distance over * @return output Cosine distance loss (NUMERIC type) */ public SDVariable cosineDistance(String name, SDVariable label, SDVariable predictions, @@ -150,22 +166,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("cosineDistance", "label", label); SDValidation.validateNumerical("cosineDistance", "predictions", predictions); SDValidation.validateNumerical("cosineDistance", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, lossReduce, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd, label, + predictions, weights, lossReduce, dimension).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
- * equivalent to cosine distance when both the predictions and labels are normalized.
- * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
- * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
- * along the cosine distance dimension (with keepDims=true).
+ * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or + * {@code 1 - sum_i label[i] * prediction[i]}, which is
equivalent to cosine distance when + * both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to + * have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, + * SDVariable, boolean, int...)
along the cosine distance dimension (with keepDims=true).
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) - * @param dimension Dimension to perform the cosine distance over + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param dimension Dimension to perform the cosine distance over * @return output Cosine distance loss (NUMERIC type) */ public SDVariable cosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, @@ -173,23 +192,27 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("cosineDistance", "label", label); SDValidation.validateNumerical("cosineDistance", "predictions", predictions); SDValidation.validateNumerical("cosineDistance", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd, label, + predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + dimension).outputVariable(); out.markAsLoss(); return out; } /** - * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
- * equivalent to cosine distance when both the predictions and labels are normalized.
- * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
- * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
- * along the cosine distance dimension (with keepDims=true).
+ * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or + * {@code 1 - sum_i label[i] * prediction[i]}, which is
equivalent to cosine distance when + * both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to + * have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, + * SDVariable, boolean, int...)
along the cosine distance dimension (with keepDims=true).
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) - * @param dimension Dimension to perform the cosine distance over + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param dimension Dimension to perform the cosine distance over * @return output Cosine distance loss (NUMERIC type) */ public SDVariable cosineDistance(String name, SDVariable label, SDVariable predictions, @@ -197,20 +220,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("cosineDistance", "label", label); SDValidation.validateNumerical("cosineDistance", "predictions", predictions); SDValidation.validateNumerical("cosineDistance", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd, label, + predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + dimension).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Hinge loss: a loss function used for training classifiers.
- * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}
- * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.
+ * Hinge loss: a loss function used for training classifiers.
Implements + * {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting + * to {-1,1}
from the user specified {0,1}. Note that Labels should be provided with values + * {0,1}.
* - * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) + * (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output Loss variable (NUMERIC type) */ public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights, @@ -218,21 +246,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("hingeLoss", "label", label); SDValidation.validateNumerical("hingeLoss", "predictions", predictions); SDValidation.validateNumerical("hingeLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd, label, predictions, + weights, lossReduce).outputVariable(); out.markAsLoss(); return out; } /** - * Hinge loss: a loss function used for training classifiers.
- * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}
- * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.
+ * Hinge loss: a loss function used for training classifiers.
Implements + * {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting + * to {-1,1}
from the user specified {0,1}. Note that Labels should be provided with values + * {0,1}.
* - * @param name name May be null. Name for the output variable - * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) + * (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output Loss variable (NUMERIC type) */ public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions, @@ -240,39 +272,45 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("hingeLoss", "label", label); SDValidation.validateNumerical("hingeLoss", "predictions", predictions); SDValidation.validateNumerical("hingeLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd, label, predictions, + weights, lossReduce).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Hinge loss: a loss function used for training classifiers.
- * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}
- * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.
+ * Hinge loss: a loss function used for training classifiers.
Implements + * {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting + * to {-1,1}
from the user specified {0,1}. Note that Labels should be provided with values + * {0,1}.
* - * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) + * (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output Loss variable (NUMERIC type) */ public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights) { SDValidation.validateNumerical("hingeLoss", "label", label); SDValidation.validateNumerical("hingeLoss", "predictions", predictions); SDValidation.validateNumerical("hingeLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd, label, predictions, + weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return out; } /** - * Hinge loss: a loss function used for training classifiers.
- * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}
- * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.
+ * Hinge loss: a loss function used for training classifiers.
Implements + * {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting + * to {-1,1}
from the user specified {0,1}. Note that Labels should be provided with values + * {0,1}.
* - * @param name name May be null. Name for the output variable - * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) + * (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output Loss variable (NUMERIC type) */ public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions, @@ -280,25 +318,27 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("hingeLoss", "label", label); SDValidation.validateNumerical("hingeLoss", "predictions", predictions); SDValidation.validateNumerical("hingeLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd, label, predictions, + weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
- * though is less sensitive to outliers than squared error.
+ * Huber loss function, used for robust regression. It is similar both squared error loss and + * absolute difference loss,
though is less sensitive to outliers than squared error.
* Huber loss implements:
*


* {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
* {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
*

* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param delta Loss function delta value + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param delta Loss function delta value * @return output Huber loss (NUMERIC type) */ public SDVariable huberLoss(SDVariable label, SDVariable predictions, SDVariable weights, @@ -306,26 +346,28 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("huberLoss", "label", label); SDValidation.validateNumerical("huberLoss", "predictions", predictions); SDValidation.validateNumerical("huberLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, lossReduce, delta).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd, label, predictions, + weights, lossReduce, delta).outputVariable(); out.markAsLoss(); return out; } /** - * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
- * though is less sensitive to outliers than squared error.
+ * Huber loss function, used for robust regression. It is similar both squared error loss and + * absolute difference loss,
though is less sensitive to outliers than squared error.
* Huber loss implements:
*

* {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
* {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
*

* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param delta Loss function delta value + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param delta Loss function delta value * @return output Huber loss (NUMERIC type) */ public SDVariable huberLoss(String name, SDVariable label, SDVariable predictions, @@ -333,24 +375,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("huberLoss", "label", label); SDValidation.validateNumerical("huberLoss", "predictions", predictions); SDValidation.validateNumerical("huberLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, lossReduce, delta).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd, label, predictions, + weights, lossReduce, delta).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
- * though is less sensitive to outliers than squared error.
+ * Huber loss function, used for robust regression. It is similar both squared error loss and + * absolute difference loss,
though is less sensitive to outliers than squared error.
* Huber loss implements:
*

* {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
* {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
*

* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param delta Loss function delta value + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param delta Loss function delta value * @return output Huber loss (NUMERIC type) */ public SDVariable huberLoss(SDVariable label, SDVariable predictions, SDVariable weights, @@ -358,25 +401,27 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("huberLoss", "label", label); SDValidation.validateNumerical("huberLoss", "predictions", predictions); SDValidation.validateNumerical("huberLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd, label, predictions, + weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + delta).outputVariable(); out.markAsLoss(); return out; } /** - * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
- * though is less sensitive to outliers than squared error.
+ * Huber loss function, used for robust regression. It is similar both squared error loss and + * absolute difference loss,
though is less sensitive to outliers than squared error.
* Huber loss implements:
*

* {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
* {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
*

* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param delta Loss function delta value + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param delta Loss function delta value * @return output Huber loss (NUMERIC type) */ public SDVariable huberLoss(String name, SDVariable label, SDVariable predictions, @@ -384,7 +429,9 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("huberLoss", "label", label); SDValidation.validateNumerical("huberLoss", "predictions", predictions); SDValidation.validateNumerical("huberLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd, label, predictions, + weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + delta).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } @@ -397,7 +444,7 @@ public class SDLoss extends SDOps { */ public SDVariable l2Loss(SDVariable var) { SDValidation.validateNumerical("l2Loss", "var", var); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.L2Loss(sd,var).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.L2Loss(sd, var).outputVariable(); out.markAsLoss(); return out; } @@ -406,25 +453,28 @@ public class SDLoss extends SDOps { * L2 loss: 1/2 * sum(x^2)
* * @param name name May be null. Name for the output variable - * @param var Variable to calculate L2 loss of (NUMERIC type) + * @param var Variable to calculate L2 loss of (NUMERIC type) * @return output L2 loss (NUMERIC type) */ public SDVariable l2Loss(String name, SDVariable var) { SDValidation.validateNumerical("l2Loss", "var", var); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.L2Loss(sd,var).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.L2Loss(sd, var).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
- * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. + * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * + * log(1-predictions[i] + epsilon))}
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param epsilon epsilon + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param epsilon epsilon * @return output Log loss (NUMERIC type) */ public SDVariable logLoss(SDVariable label, SDVariable predictions, SDVariable weights, @@ -432,21 +482,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("logLoss", "label", label); SDValidation.validateNumerical("logLoss", "predictions", predictions); SDValidation.validateNumerical("logLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, weights, lossReduce, epsilon).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd, label, predictions, weights, + lossReduce, epsilon).outputVariable(); out.markAsLoss(); return out; } /** - * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
- * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. + * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * + * log(1-predictions[i] + epsilon))}
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param epsilon epsilon + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param epsilon epsilon * @return output Log loss (NUMERIC type) */ public SDVariable logLoss(String name, SDVariable label, SDVariable predictions, @@ -454,53 +508,61 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("logLoss", "label", label); SDValidation.validateNumerical("logLoss", "predictions", predictions); SDValidation.validateNumerical("logLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, weights, lossReduce, epsilon).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd, label, predictions, weights, + lossReduce, epsilon).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
- * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. + * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * + * log(1-predictions[i] + epsilon))}
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) * @return output Log loss (NUMERIC type) */ public SDVariable logLoss(SDVariable label, SDVariable predictions) { SDValidation.validateNumerical("logLoss", "label", label); SDValidation.validateNumerical("logLoss", "predictions", predictions); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, null, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd, label, predictions, null, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); out.markAsLoss(); return out; } /** - * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
- * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. + * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * + * log(1-predictions[i] + epsilon))}
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) * @return output Log loss (NUMERIC type) */ public SDVariable logLoss(String name, SDVariable label, SDVariable predictions) { SDValidation.validateNumerical("logLoss", "label", label); SDValidation.validateNumerical("logLoss", "predictions", predictions); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, null, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd, label, predictions, null, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Log poisson loss: a loss function used for training classifiers.
- * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * Log poisson loss: a loss function used for training classifiers.
Implements + * {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
* - * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param full Boolean flag. true for logPoissonFull, false for logPoisson * @return output Loss variable (NUMERIC type) */ public SDVariable logPoisson(SDVariable label, SDVariable predictions, SDVariable weights, @@ -508,21 +570,23 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("logPoisson", "label", label); SDValidation.validateNumerical("logPoisson", "predictions", predictions); SDValidation.validateNumerical("logPoisson", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, lossReduce, full).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd, label, predictions, + weights, lossReduce, full).outputVariable(); out.markAsLoss(); return out; } /** - * Log poisson loss: a loss function used for training classifiers.
- * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * Log poisson loss: a loss function used for training classifiers.
Implements + * {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
* - * @param name name May be null. Name for the output variable - * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param full Boolean flag. true for logPoissonFull, false for logPoisson * @return output Loss variable (NUMERIC type) */ public SDVariable logPoisson(String name, SDVariable label, SDVariable predictions, @@ -530,19 +594,20 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("logPoisson", "label", label); SDValidation.validateNumerical("logPoisson", "predictions", predictions); SDValidation.validateNumerical("logPoisson", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, lossReduce, full).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd, label, predictions, + weights, lossReduce, full).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Log poisson loss: a loss function used for training classifiers.
- * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * Log poisson loss: a loss function used for training classifiers.
Implements + * {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
* - * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param full Boolean flag. true for logPoissonFull, false for logPoisson * @return output Loss variable (NUMERIC type) */ public SDVariable logPoisson(SDVariable label, SDVariable predictions, SDVariable weights, @@ -550,20 +615,22 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("logPoisson", "label", label); SDValidation.validateNumerical("logPoisson", "predictions", predictions); SDValidation.validateNumerical("logPoisson", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, full).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd, label, predictions, + weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + full).outputVariable(); out.markAsLoss(); return out; } /** - * Log poisson loss: a loss function used for training classifiers.
- * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * Log poisson loss: a loss function used for training classifiers.
Implements + * {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
* - * @param name name May be null. Name for the output variable - * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param full Boolean flag. true for logPoissonFull, false for logPoisson * @return output Loss variable (NUMERIC type) */ public SDVariable logPoisson(String name, SDVariable label, SDVariable predictions, @@ -571,21 +638,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("logPoisson", "label", label); SDValidation.validateNumerical("logPoisson", "predictions", predictions); SDValidation.validateNumerical("logPoisson", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, full).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd, label, predictions, + weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + full).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Mean pairwise squared error.
- * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
- * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * Mean pairwise squared error.
MPWSE loss calculates the difference between pairs of + * consecutive elements in the predictions and labels arrays.
For example, if predictions = + * [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
* {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either + * null, scalar, or have shape [batchSize] (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output Loss variable, scalar output (NUMERIC type) */ public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions, @@ -593,22 +664,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd, label, + predictions, weights, lossReduce).outputVariable(); out.markAsLoss(); return out; } /** - * Mean pairwise squared error.
- * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
- * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * Mean pairwise squared error.
MPWSE loss calculates the difference between pairs of + * consecutive elements in the predictions and labels arrays.
For example, if predictions = + * [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
* {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either + * null, scalar, or have shape [batchSize] (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output Loss variable, scalar output (NUMERIC type) */ public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions, @@ -616,20 +690,22 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd, label, + predictions, weights, lossReduce).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Mean pairwise squared error.
- * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
- * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * Mean pairwise squared error.
MPWSE loss calculates the difference between pairs of + * consecutive elements in the predictions and labels arrays.
For example, if predictions = + * [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
* {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either + * null, scalar, or have shape [batchSize] (NUMERIC type) * @return output Loss variable, scalar output (NUMERIC type) */ public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions, @@ -637,21 +713,24 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd, label, + predictions, weights, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return out; } /** - * Mean pairwise squared error.
- * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
- * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * Mean pairwise squared error.
MPWSE loss calculates the difference between pairs of + * consecutive elements in the predictions and labels arrays.
For example, if predictions = + * [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
* {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either + * null, scalar, or have shape [batchSize] (NUMERIC type) * @return output Loss variable, scalar output (NUMERIC type) */ public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions, @@ -659,20 +738,24 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd, label, + predictions, weights, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
- * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
- * this is the mean squared error loss function.
+ * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., + * squared error on a per-element basis.
When averaged (using + * {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the + * default))
this is the mean squared error loss function.
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output Loss variable (NUMERIC type) */ public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, @@ -680,21 +763,24 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanSquaredError", "label", label); SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd, label, + predictions, weights, lossReduce).outputVariable(); out.markAsLoss(); return out; } /** - * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
- * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
- * this is the mean squared error loss function.
+ * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., + * squared error on a per-element basis.
When averaged (using + * {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the + * default))
this is the mean squared error loss function.
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output Loss variable (NUMERIC type) */ public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions, @@ -702,39 +788,44 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanSquaredError", "label", label); SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd, label, + predictions, weights, lossReduce).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
- * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
- * this is the mean squared error loss function.
+ * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., + * squared error on a per-element basis.
When averaged (using + * {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the + * default))
this is the mean squared error loss function.
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output Loss variable (NUMERIC type) */ public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights) { SDValidation.validateNumerical("meanSquaredError", "label", label); SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd, label, + predictions, weights, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return out; } /** - * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
- * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
- * this is the mean squared error loss function.
+ * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., + * squared error on a per-element basis.
When averaged (using + * {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the + * default))
this is the mean squared error loss function.
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output Loss variable (NUMERIC type) */ public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions, @@ -742,30 +833,35 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanSquaredError", "label", label); SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd, label, + predictions, weights, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
- * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
- * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
- * Implements:
- * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
- * though this is done in a mathematically equivalent but more numerical stable form.
+ * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input + * "pre-sigmoid preductions")
and implements the binary cross entropy loss function. This + * implementation is numerically more stable than using
standard (but separate) sigmoid + * activation function and log loss (binary cross entropy) loss function.
Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * + * log(1-sigmoid(logits[i])))}
though this is done in a mathematically equivalent but more + * numerically stable form.<br>
*
- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*

* {@code numClasses = labels.size(1);
* label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
*

* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictionLogits Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param labelSmoothing Label smoothing value. Default value: 0 + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 * @return output Loss variable (NUMERIC type) */ public SDVariable sigmoidCrossEntropy(SDVariable label, SDVariable predictionLogits, @@ -773,31 +869,35 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, lossReduce, labelSmoothing).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd, label, + predictionLogits, weights, lossReduce, labelSmoothing).outputVariable(); out.markAsLoss(); return out; } /** - * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
- * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
- * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
- * Implements:
- * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
- * though this is done in a mathematically equivalent but more numerical stable form.
+ * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input + * "pre-sigmoid preductions")
and implements the binary cross entropy loss function. This + * implementation is numerically more stable than using
standard (but separate) sigmoid + * activation function and log loss (binary cross entropy) loss function.
Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * + * log(1-sigmoid(logits[i])))}
though this is done in a mathematically equivalent but more + * numerically stable form.<br>
*
- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*

* {@code numClasses = labels.size(1);
* label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
*

* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictionLogits Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param labelSmoothing Label smoothing value. Default value: 0 + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 * @return output Loss variable (NUMERIC type) */ public SDVariable sigmoidCrossEntropy(String name, SDVariable label, SDVariable predictionLogits, @@ -805,28 +905,31 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, lossReduce, labelSmoothing).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd, label, + predictionLogits, weights, lossReduce, labelSmoothing).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
- * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
- * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
- * Implements:
- * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
- * though this is done in a mathematically equivalent but more numerical stable form.
+ * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input + * "pre-sigmoid preductions")
and implements the binary cross entropy loss function. This + * implementation is numerically more stable than using
standard (but separate) sigmoid + * activation function and log loss (binary cross entropy) loss function.
Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * + * log(1-sigmoid(logits[i])))}
though this is done in a mathematically equivalent but more + * numerically stable form.<br>
*
- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*

* {@code numClasses = labels.size(1);
* label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
*

* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictionLogits Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) * @return output Loss variable (NUMERIC type) */ public SDVariable sigmoidCrossEntropy(SDVariable label, SDVariable predictionLogits, @@ -834,29 +937,33 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd, label, + predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + 0.0).outputVariable(); out.markAsLoss(); return out; } /** - * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
- * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
- * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
- * Implements:
- * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
- * though this is done in a mathematically equivalent but more numerical stable form.
+ * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input + * "pre-sigmoid preductions")
and implements the binary cross entropy loss function. This + * implementation is numerically more stable than using
standard (but separate) sigmoid + * activation function and log loss (binary cross entropy) loss function.
Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * + * log(1-sigmoid(logits[i])))}
though this is done in a mathematically equivalent but more + * numerically stable form.<br>
*
- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*

* {@code numClasses = labels.size(1);
* label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
*

* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictionLogits Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) * @return output Loss variable (NUMERIC type) */ public SDVariable sigmoidCrossEntropy(String name, SDVariable label, SDVariable predictionLogits, @@ -864,28 +971,33 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd, label, + predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + 0.0).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Applies the softmax activation function to the input, then implement multi-class cross entropy:
- * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
- * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;
- * otherwise, the output is a scalar.
+ * Applies the softmax activation function to the input, then implement multi-class cross + * entropy:
{@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
If + * {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, + * numClasses] predictions/labels;<br>
otherwise, the output is a scalar.
*


- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*


* {@code numClasses = labels.size(1);
* oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
*

* - * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param oneHotLabels Label array. Should be one-hot per example and same shape as + * predictions (for example, [mb, nOut]) (NUMERIC type) * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param labelSmoothing Label smoothing value. Default value: 0 + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 * @return output Loss variable (NUMERIC type) */ public SDVariable softmaxCrossEntropy(SDVariable oneHotLabels, SDVariable logitPredictions, @@ -893,29 +1005,33 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd, oneHotLabels, + logitPredictions, weights, lossReduce, labelSmoothing).outputVariable(); out.markAsLoss(); return out; } /** - * Applies the softmax activation function to the input, then implement multi-class cross entropy:
- * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
- * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;
- * otherwise, the output is a scalar.
+ * Applies the softmax activation function to the input, then implement multi-class cross + * entropy:
{@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
If + * {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, + * numClasses] predictions/labels;<br>
otherwise, the output is a scalar.
*


- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*


* {@code numClasses = labels.size(1);
* oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
*

* - * @param name name May be null. Name for the output variable - * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param oneHotLabels Label array. Should be one-hot per example and same shape as + * predictions (for example, [mb, nOut]) (NUMERIC type) * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param labelSmoothing Label smoothing value. Default value: 0 + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. 
Default value: 0 * @return output Loss variable (NUMERIC type) */ public SDVariable softmaxCrossEntropy(String name, SDVariable oneHotLabels, @@ -924,26 +1040,29 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd, oneHotLabels, + logitPredictions, weights, lossReduce, labelSmoothing).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Applies the softmax activation function to the input, then implement multi-class cross entropy:
- * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
- * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;
- * otherwise, the output is a scalar.
+ * Applies the softmax activation function to the input, then implement multi-class cross + * entropy:
{@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
If + * {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, + * numClasses] predictions/labels;<br>
otherwise, the output is a scalar.
*


- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*


* {@code numClasses = labels.size(1);
* oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
*

* - * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param oneHotLabels Label array. Should be one-hot per example and same shape as + * predictions (for example, [mb, nOut]) (NUMERIC type) * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) * @return output Loss variable (NUMERIC type) */ public SDVariable softmaxCrossEntropy(SDVariable oneHotLabels, SDVariable logitPredictions, @@ -951,27 +1070,31 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd, oneHotLabels, + logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + 0.0).outputVariable(); out.markAsLoss(); return out; } /** - * Applies the softmax activation function to the input, then implement multi-class cross entropy:
- * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
- * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;
- * otherwise, the output is a scalar.
+ * Applies the softmax activation function to the input, then implement multi-class cross + * entropy:
{@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
If + * {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, + * numClasses] predictions/labels;<br>
otherwise, the output is a scalar.
*


- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*


* {@code numClasses = labels.size(1);
* oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
*

* - * @param name name May be null. Name for the output variable - * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param oneHotLabels Label array. Should be one-hot per example and same shape as + * predictions (for example, [mb, nOut]) (NUMERIC type) * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) * @return output Loss variable (NUMERIC type) */ public SDVariable softmaxCrossEntropy(String name, SDVariable oneHotLabels, @@ -979,14 +1102,16 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd, oneHotLabels, + logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + 0.0).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels variable
- * is represented as an integer array instead of the equivalent one-hot array.
+ * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels + * variable
is represented as an integer array instead of the equivalent one-hot array.
* i.e., if logits are rank N, then labels have rank N-1
* * @param logits Logits array ("pre-softmax activations") (NUMERIC type) @@ -996,17 +1121,18 @@ public class SDLoss extends SDOps { public SDVariable sparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels) { SDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits); SDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits(sd,logits, labels).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits( + sd, logits, labels).outputVariable(); out.markAsLoss(); return out; } /** - * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels variable
- * is represented as an integer array instead of the equivalent one-hot array.
+ * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels + * variable
is represented as an integer array instead of the equivalent one-hot array.
* i.e., if logits are rank N, then labels have rank N-1
* - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param logits Logits array ("pre-softmax activations") (NUMERIC type) * @param labels Labels array. Must be an integer type. (INT type) * @return output Softmax cross entropy (NUMERIC type) @@ -1014,7 +1140,8 @@ public class SDLoss extends SDOps { public SDVariable sparseSoftmaxCrossEntropy(String name, SDVariable logits, SDVariable labels) { SDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits); SDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits(sd,logits, labels).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits( + sd, logits, labels).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } @@ -1023,7 +1150,7 @@ public class SDLoss extends SDOps { * Weighted cross entropy loss with logits
* * @param targets targets array (NUMERIC type) - * @param inputs input array (NUMERIC type) + * @param inputs input array (NUMERIC type) * @param weights eights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output Loss variable (NUMERIC type) */ @@ -1032,7 +1159,8 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "targets", targets); SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "inputs", inputs); SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(sd,targets, inputs, weights).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(sd, targets, + inputs, weights).outputVariable(); out.markAsLoss(); return out; } @@ -1040,9 +1168,9 @@ public class SDLoss extends SDOps { /** * Weighted cross entropy loss with logits
* - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param targets targets array (NUMERIC type) - * @param inputs input array (NUMERIC type) + * @param inputs input array (NUMERIC type) * @param weights eights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output Loss variable (NUMERIC type) */ @@ -1051,7 +1179,8 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "targets", targets); SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "inputs", inputs); SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(sd,targets, inputs, weights).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(sd, targets, + inputs, weights).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 5c3579396..bbef06cfb 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -29,6 +29,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.indexing.conditions.Condition; public class SDMath extends SDOps { + public SDMath(SameDiff sameDiff) { super(sameDiff); } @@ -36,53 +37,60 @@ public class SDMath extends SDOps { /** * Clips tensor values to a maximum average L2-norm.
* - * @param x Input variable (NUMERIC type) - * @param clipValue Value for clipping + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable clipByAvgNorm(SDVariable x, double clipValue, int... dimensions) { SDValidation.validateNumerical("ClipByAvgNorm", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd,x, clipValue, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd, x, clipValue, + dimensions).outputVariable(); } /** * Clips tensor values to a maximum average L2-norm.
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param clipValue Value for clipping + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable clipByAvgNorm(String name, SDVariable x, double clipValue, int... dimensions) { SDValidation.validateNumerical("ClipByAvgNorm", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd,x, clipValue, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd, x, + clipValue, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Looks up ids in a list of embedding tensors.
* - * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ public SDVariable embeddingLookup(SDVariable x, SDVariable indices, PartitionMode PartitionMode) { SDValidation.validateNumerical("EmbeddingLookup", "x", x); SDValidation.validateInteger("EmbeddingLookup", "indices", indices); - return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd, x, indices, + PartitionMode).outputVariable(); } /** * Looks up ids in a list of embedding tensors.
* - * @param name name May be null. Name for the output variable - * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ @@ -90,35 +98,39 @@ public class SDMath extends SDOps { PartitionMode PartitionMode) { SDValidation.validateNumerical("EmbeddingLookup", "x", x); SDValidation.validateInteger("EmbeddingLookup", "indices", indices); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd, x, + indices, PartitionMode).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Return array of max elements indices with along tensor dimensions
* - * @param x Input tensor (NUMERIC type) + * @param x Input tensor (NUMERIC type) * @param dataType Data type * @return output Array max elements indices with along dimensions. (INT type) */ public SDVariable mergeMaxIndex(SDVariable[] x, DataType dataType) { SDValidation.validateNumerical("MergeMaxIndex", "x", x); - Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); - return new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, dataType).outputVariable(); + Preconditions.checkArgument(x.length >= 1, + "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd, x, dataType).outputVariable(); } /** * Return array of max elements indices with along tensor dimensions
* - * @param name name May be null. Name for the output variable - * @param x Input tensor (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) * @param dataType Data type * @return output Array max elements indices with along dimensions. (INT type) */ public SDVariable mergeMaxIndex(String name, SDVariable[] x, DataType dataType) { SDValidation.validateNumerical("MergeMaxIndex", "x", x); - Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, dataType).outputVariable(); + Preconditions.checkArgument(x.length >= 1, + "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd, x, + dataType).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -130,21 +142,25 @@ public class SDMath extends SDOps { */ public SDVariable mergeMaxIndex(SDVariable... x) { SDValidation.validateNumerical("MergeMaxIndex", "x", x); - Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); - return new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, DataType.INT).outputVariable(); + Preconditions.checkArgument(x.length >= 1, + "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd, x, + DataType.INT).outputVariable(); } /** * Return array of max elements indices with along tensor dimensions
* * @param name name May be null. Name for the output variable - * @param x Input tensor (NUMERIC type) + * @param x Input tensor (NUMERIC type) * @return output Array max elements indices with along dimensions. (INT type) */ public SDVariable mergeMaxIndex(String name, SDVariable... x) { SDValidation.validateNumerical("MergeMaxIndex", "x", x); - Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, DataType.INT).outputVariable(); + Preconditions.checkArgument(x.length >= 1, + "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd, x, + DataType.INT).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -156,19 +172,19 @@ public class SDMath extends SDOps { */ public SDVariable abs(SDVariable x) { SDValidation.validateNumerical("abs", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(sd, x).outputVariable(); } /** * Elementwise absolute value operation: out = abs(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable abs(String name, SDVariable x) { SDValidation.validateNumerical("abs", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -180,19 +196,20 @@ public class SDMath extends SDOps { */ public SDVariable acos(SDVariable x) { SDValidation.validateNumerical("acos", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(sd, x).outputVariable(); } /** * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable acos(String name, SDVariable x) { SDValidation.validateNumerical("acos", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -204,28 +221,30 @@ public class SDMath extends SDOps { */ public SDVariable acosh(SDVariable x) { SDValidation.validateNumerical("acosh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(sd, x).outputVariable(); } /** * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable acosh(String name, SDVariable x) { SDValidation.validateNumerical("acosh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise addition operation, out = x + y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -234,169 +253,204 @@ public class SDMath extends SDOps { public SDVariable add(SDVariable x, SDVariable y) { SDValidation.validateNumerical("add", "x", x); SDValidation.validateNumerical("add", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd, x, + y).outputVariable(); } /** * Pairwise addition operation, out = x + y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable add(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("add", "x", x); SDValidation.validateNumerical("add", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar add operation, out = in + scalar
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable add(SDVariable x, double value) { SDValidation.validateNumerical("add", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd, x, value).outputVariable(); } /** * Scalar add operation, out = in + scalar
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable add(String name, SDVariable x, double value) { SDValidation.validateNumerical("add", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
+ * Absolute max array reduction operation, optionally along specified dimensions: out = + * max(abs(x))
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable amax(SDVariable in, int... dimensions) { SDValidation.validateNumerical("amax", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(sd, in, dimensions).outputVariable(); } /** - * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
+ * Absolute max array reduction operation, optionally along specified dimensions: out = + * max(abs(x))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable amax(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("amax", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
+ * Absolute mean array reduction operation, optionally along specified dimensions: out = + * mean(abs(x))
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable amean(SDVariable in, int... dimensions) { SDValidation.validateNumerical("amean", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(sd, in, + dimensions).outputVariable(); } /** - * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
+ * Absolute mean array reduction operation, optionally along specified dimensions: out = + * mean(abs(x))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable amean(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("amean", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
+ * Absolute min array reduction operation, optionally along specified dimensions: out = + * min(abs(x))
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable amin(SDVariable in, int... dimensions) { SDValidation.validateNumerical("amin", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(sd, in, dimensions).outputVariable(); } /** - * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
+ * Absolute min array reduction operation, optionally along specified dimensions: out = + * min(abs(x))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable amin(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("amin", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Boolean AND operation: elementwise (x != 0) && (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * Boolean AND operation: {@code elementwise (x != 0) && (y != 0)}
If x and y arrays have + * equal shape, the output shape is the same as these inputs.
Note: supports broadcasting if x + * and y have different shapes and are broadcastable.
Returns an array with values 1 where + * condition is satisfied, or value 0 otherwise.
* * @param x Input 1 (BOOL type) * @param y Input 2 (BOOL type) - * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL + * type) */ public SDVariable and(SDVariable x, SDVariable y) { SDValidation.validateBool("and", "x", x); SDValidation.validateBool("and", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(sd, x, y).outputVariable(); } /** - * Boolean AND operation: elementwise (x != 0) && (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * Boolean AND operation: {@code elementwise (x != 0) && (y != 0)}
If x and y arrays have + * equal shape, the output shape is the same as these inputs.
Note: supports broadcasting if x + * and y have different shapes and are broadcastable.
Returns an array with values 1 where + * condition is satisfied, or value 0 otherwise.
* * @param name name May be null. Name for the output variable - * @param x Input 1 (BOOL type) - * @param y Input 2 (BOOL type) - * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL + * type) */ public SDVariable and(String name, SDVariable x, SDVariable y) { SDValidation.validateBool("and", "x", x); SDValidation.validateBool("and", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -408,19 +462,20 @@ public class SDMath extends SDOps { */ public SDVariable asin(SDVariable x) { SDValidation.validateNumerical("asin", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(sd, x).outputVariable(); } /** * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable asin(String name, SDVariable x) { SDValidation.validateNumerical("asin", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -432,47 +487,57 @@ public class SDMath extends SDOps { */ public SDVariable asinh(SDVariable x) { SDValidation.validateNumerical("asinh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(sd, x).outputVariable(); } /** * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable asinh(String name, SDVariable x) { SDValidation.validateNumerical("asinh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
+ * Absolute sum array reduction operation, optionally along specified dimensions: out = + * sum(abs(x))
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable asum(SDVariable in, int... dimensions) { SDValidation.validateNumerical("asum", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(sd, in, dimensions).outputVariable(); } /** - * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
+ * Absolute sum array reduction operation, optionally along specified dimensions: out = + * sum(abs(x))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable asum(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("asum", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -484,25 +549,26 @@ public class SDMath extends SDOps { */ public SDVariable atan(SDVariable x) { SDValidation.validateNumerical("atan", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(sd, x).outputVariable(); } /** * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable atan(String name, SDVariable x) { SDValidation.validateNumerical("atan", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
- * Similar to atan(y/x) but sigts of x and y are used to determine the location of the result
+ * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
Similar to + * atan(y/x) but signs of x and y are used to determine the location of the result
* * @param y Input Y variable (NUMERIC type) * @param x Input X variable (NUMERIC type) @@ -511,22 +577,23 @@ public class SDMath extends SDOps { public SDVariable atan2(SDVariable y, SDVariable x) { SDValidation.validateNumerical("atan2", "y", y); SDValidation.validateNumerical("atan2", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(sd,y, x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(sd, y, x).outputVariable(); } /** - * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
- * Similar to atan(y/x) but sigts of x and y are used to determine the location of the result
+ * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
Similar to + * atan(y/x) but signs of x and y are used to determine the location of the result
* * @param name name May be null. Name for the output variable - * @param y Input Y variable (NUMERIC type) - * @param x Input X variable (NUMERIC type) + * @param y Input Y variable (NUMERIC type) + * @param x Input X variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable atan2(String name, SDVariable y, SDVariable x) { SDValidation.validateNumerical("atan2", "y", y); SDValidation.validateNumerical("atan2", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(sd,y, x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(sd, y, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -538,220 +605,236 @@ public class SDMath extends SDOps { */ public SDVariable atanh(SDVariable x) { SDValidation.validateNumerical("atanh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(sd, x).outputVariable(); } /** * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable atanh(String name, SDVariable x) { SDValidation.validateNumerical("atanh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Bit shift operation
* - * @param x input (NUMERIC type) + * @param x input (NUMERIC type) * @param shift shift value (NUMERIC type) * @return output shifted output (NUMERIC type) */ public SDVariable bitShift(SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShift", "x", x); SDValidation.validateNumerical("bitShift", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, + shift).outputVariable(); } /** * Bit shift operation
* - * @param name name May be null. Name for the output variable - * @param x input (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x input (NUMERIC type) * @param shift shift value (NUMERIC type) * @return output shifted output (NUMERIC type) */ public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShift", "x", x); SDValidation.validateNumerical("bitShift", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Right bit shift operation
* - * @param x Input tensor (NUMERIC type) + * @param x Input tensor (NUMERIC type) * @param shift shift argument (NUMERIC type) * @return output shifted output (NUMERIC type) */ public SDVariable bitShiftRight(SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShiftRight", "x", x); SDValidation.validateNumerical("bitShiftRight", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, + shift).outputVariable(); } /** * Right bit shift operation
* - * @param name name May be null. Name for the output variable - * @param x Input tensor (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) * @param shift shift argument (NUMERIC type) * @return output shifted output (NUMERIC type) */ public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShiftRight", "x", x); SDValidation.validateNumerical("bitShiftRight", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Cyclic bit shift operation
* - * @param x Input tensor (NUMERIC type) + * @param x Input tensor (NUMERIC type) * @param shift shift argy=ument (NUMERIC type) * @return output shifted output (NUMERIC type) */ public SDVariable bitShiftRotl(SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShiftRotl", "x", x); SDValidation.validateNumerical("bitShiftRotl", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x, + shift).outputVariable(); } /** * Cyclic bit shift operation
* - * @param name name May be null. Name for the output variable - * @param x Input tensor (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) * @param shift shift argy=ument (NUMERIC type) * @return output shifted output (NUMERIC type) */ public SDVariable bitShiftRotl(String name, SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShiftRotl", "x", x); SDValidation.validateNumerical("bitShiftRotl", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Cyclic right shift operation
* - * @param x Input tensor (NUMERIC type) + * @param x Input tensor (NUMERIC type) * @param shift Shift argument (NUMERIC type) * @return output Shifted output (NUMERIC type) */ public SDVariable bitShiftRotr(SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShiftRotr", "x", x); SDValidation.validateNumerical("bitShiftRotr", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x, + shift).outputVariable(); } /** * Cyclic right shift operation
* - * @param name name May be null. Name for the output variable - * @param x Input tensor (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) * @param shift Shift argument (NUMERIC type) * @return output Shifted output (NUMERIC type) */ public SDVariable bitShiftRotr(String name, SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShiftRotr", "x", x); SDValidation.validateNumerical("bitShiftRotr", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise ceiling function: out = ceil(x).
- * Rounds each value up to the nearest integer value (if not already an integer)
+ * Element-wise ceiling function: out = ceil(x).
Rounds each value up to the nearest integer + * value (if not already an integer)
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable ceil(SDVariable x) { SDValidation.validateNumerical("ceil", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(sd, x).outputVariable(); } /** - * Element-wise ceiling function: out = ceil(x).
- * Rounds each value up to the nearest integer value (if not already an integer)
+ * Element-wise ceiling function: out = ceil(x).
Rounds each value up to the nearest integer + * value (if not already an integer)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable ceil(String name, SDVariable x) { SDValidation.validateNumerical("ceil", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Clipping by L2 norm, optionally along dimension(s)
- * if l2Norm(x,dimension) < clipValue, then input is returned unmodifed
- * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according
- * to the corresponding l2Norm along the specified dimensions
+ * Clipping by L2 norm, optionally along dimension(s)
if l2Norm(x,dimension) < clipValue, + * then input is returned unmodified
Otherwise, out[i] = in[i] * clipValue / l2Norm(in, + * dimensions) where each value is clipped according
to the corresponding l2Norm along the + * specified dimensions
* - * @param x Input variable (NUMERIC type) - * @param clipValue Clipping value (maximum l2 norm) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param clipValue Clipping value (maximum l2 norm) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable clipByNorm(SDVariable x, double clipValue, int... dimensions) { SDValidation.validateNumerical("clipByNorm", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(sd,x, clipValue, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(sd, x, clipValue, + dimensions).outputVariable(); } /** - * Clipping by L2 norm, optionally along dimension(s)
- * if l2Norm(x,dimension) < clipValue, then input is returned unmodifed
- * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according
- * to the corresponding l2Norm along the specified dimensions
+ * Clipping by L2 norm, optionally along dimension(s)
if l2Norm(x,dimension) < clipValue, + * then input is returned unmodified
Otherwise, out[i] = in[i] * clipValue / l2Norm(in, + * dimensions) where each value is clipped according
to the corresponding l2Norm along the + * specified dimensions
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param clipValue Clipping value (maximum l2 norm) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param clipValue Clipping value (maximum l2 norm) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable clipByNorm(String name, SDVariable x, double clipValue, int... dimensions) { SDValidation.validateNumerical("clipByNorm", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(sd,x, clipValue, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(sd, x, clipValue, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise clipping function:
- * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
- * out[i] = clipValueMin if in[i] < clipValueMin
- * out[i] = clipValueMax if in[i] > clipValueMax
+ * Element-wise clipping function:
out[i] = in[i] if in[i] >= clipValueMin and in[i] <= + * clipValueMax
out[i] = clipValueMin if in[i] < clipValueMin
out[i] = clipValueMax if + * in[i] > clipValueMax
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param clipValueMin Minimum value for clipping * @param clipValueMax Maximum value for clipping * @return output Output variable (NUMERIC type) */ public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) { SDValidation.validateNumerical("clipByValue", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(sd,x, clipValueMin, clipValueMax).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(sd, x, clipValueMin, + clipValueMax).outputVariable(); } /** * Element-wise clipping function:
- * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
- * out[i] = clipValueMin if in[i] < clipValueMin
- * out[i] = clipValueMax if in[i] > clipValueMax
+ * {@code out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax out[i] = clipValueMin + * if in[i] < clipValueMin out[i] = clipValueMax if in[i] > clipValueMax} * - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param clipValueMin Minimum value for clipping * @param clipValueMax Maximum value for clipping * @return output Output variable (NUMERIC type) @@ -759,40 +842,40 @@ public class SDMath extends SDOps { public SDVariable clipByValue(String name, SDVariable x, double clipValueMin, double clipValueMax) { SDValidation.validateNumerical("clipByValue", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(sd,x, clipValueMin, clipValueMax).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(sd, x, + clipValueMin, clipValueMax).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
- * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
- * [1, 0, 0]
- * [0, 1, 1]
- * [0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values. This version assumes the + * number of classes is 1 + max(max(labels), max(pred))
For example, if labels = [0, 1, 1] and + * predicted = [0, 2, 1] then output is:
[1, 0, 0]
[0, 1, 1]
[0, 0, 0]
* - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length + * as labels (NUMERIC type) * @param dataType Data type * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, DataType dataType) { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); - return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, dataType).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + dataType).outputVariable(); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
- * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
- * [1, 0, 0]
- * [0, 1, 1]
- * [0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values. This version assumes the + * number of classes is 1 + max(max(labels), max(pred))
For example, if labels = [0, 1, 1] and + * predicted = [0, 2, 1] then output is:
[1, 0, 0]
[0, 1, 1]
[0, 0, 0]
* - * @param name name May be null. Name for the output variable - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length + * as labels (NUMERIC type) * @param dataType Data type * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ @@ -800,42 +883,40 @@ public class SDMath extends SDOps { DataType dataType) { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, dataType).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + dataType).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values.
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
- * [1, 0, 0, 0]
- * [0, 1, 1, 0]
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values.
For example, if labels = + * [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
[1, 0, 0, 0]
[0, 1, + * 1, 0]
[0, 0, 0, 0]
[0, 0, 0, 0]
* - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same + * length as labels (NUMERIC type) * @param numClasses Number of classes * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, int numClasses) { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); - return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, numClasses).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + numClasses).outputVariable(); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values.
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
- * [1, 0, 0, 0]
- * [0, 1, 1, 0]
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values.
For example, if labels = + * [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
[1, 0, 0, 0]
[0, 1, + * 1, 0]
[0, 0, 0, 0]
[0, 0, 0, 0]
* - * @param name name May be null. Name for the output variable - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same + * length as labels (NUMERIC type) * @param numClasses Number of classes * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ @@ -843,42 +924,46 @@ public class SDMath extends SDOps { int numClasses) { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, numClasses).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + numClasses).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3]
- * [1, 0, 0]
- * [0, 3, 2]
- * [0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values. This version assumes the + * number of classes is 1 + max(max(labels), max(pred))
For example, if labels = [0, 1, 1], + * predicted = [0, 2, 1] and weights = [1, 2, 3]
[1, 0, 0]
[0, 3, 2]
[0, 0, 0]
* - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) - * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length + * as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the + * weight/contribution of each prediction. Must be same length as both labels and + * predictions arrays (NUMERIC type) * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); SDValidation.validateNumerical("confusionMatrix", "weights", weights); - return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + weights).outputVariable(); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3]
- * [1, 0, 0]
- * [0, 3, 2]
- * [0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values. This version assumes the + * number of classes is 1 + max(max(labels), max(pred))
For example, if labels = [0, 1, 1], + * predicted = [0, 2, 1] and weights = [1, 2, 3]
[1, 0, 0]
[0, 3, 2]
[0, 0, 0]
* - * @param name name May be null. Name for the output variable - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) - * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length + * as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the + * weight/contribution of each prediction. Must be same length as both labels and + * predictions arrays (NUMERIC type) * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, @@ -886,23 +971,24 @@ public class SDMath extends SDOps { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); SDValidation.validateNumerical("confusionMatrix", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + weights).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values.
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
- * [1, 0, 0, 0]
- * [0, 3, 2, 0]
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values.
For example, if labels = + * [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
[1, 0, 0, 0]
+ * [0, 3, 2, 0]
[0, 0, 0, 0]
[0, 0, 0, 0]
* - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) - * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) - * @param numClasses + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same + * length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the + * weight/contribution of each prediction. Must be same length as both labels + * and predictions arrays (NUMERIC type) + * @param numClasses * @return output Output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights, @@ -910,23 +996,24 @@ public class SDMath extends SDOps { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); SDValidation.validateNumerical("confusionMatrix", "weights", weights); - return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights, numClasses).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, weights, + numClasses).outputVariable(); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values.
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
- * [1, 0, 0, 0]
- * [0, 3, 2, 0]
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values.
For example, if labels = + * [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
[1, 0, 0, 0]
+ * [0, 3, 2, 0]
[0, 0, 0, 0]
[0, 0, 0, 0]
* - * @param name name May be null. Name for the output variable - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) - * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) - * @param numClasses + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same + * length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the + * weight/contribution of each prediction. Must be same length as both labels + * and predictions arrays (NUMERIC type) + * @param numClasses * @return output Output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, @@ -934,7 +1021,8 @@ public class SDMath extends SDOps { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); SDValidation.validateNumerical("confusionMatrix", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights, numClasses).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + weights, numClasses).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -946,19 +1034,19 @@ public class SDMath extends SDOps { */ public SDVariable cos(SDVariable x) { SDValidation.validateNumerical("cos", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(sd,x).outputVariable(); 
+ return new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(sd, x).outputVariable(); } /** * Elementwise cosine operation: out = cos(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable cos(String name, SDVariable x) { SDValidation.validateNumerical("cos", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -970,153 +1058,185 @@ public class SDMath extends SDOps { */ public SDVariable cosh(SDVariable x) { SDValidation.validateNumerical("cosh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(sd, x).outputVariable(); } /** * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable cosh(String name, SDVariable x) { SDValidation.validateNumerical("cosh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Cosine distance reduction operation. The output contains the cosine distance for each
- * tensor/subset along the specified dimensions:
- * out = 1.0 - cosineSimilarity(x,y)
+ * tensor/subset along the specified dimensions:
out = 1.0 - cosineSimilarity(x,y)
* - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable cosineDistance(SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("cosineDistance", "x", x); SDValidation.validateNumerical("cosineDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(sd, x, y, + dimensions).outputVariable(); } /** * Cosine distance reduction operation. The output contains the cosine distance for each
- * tensor/subset along the specified dimensions:
- * out = 1.0 - cosineSimilarity(x,y)
+ * tensor/subset along the specified dimensions:
out = 1.0 - cosineSimilarity(x,y)
* - * @param name name May be null. Name for the output variable - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable cosineDistance(String name, SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("cosineDistance", "x", x); SDValidation.validateNumerical("cosineDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(sd, x, y, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset
- * along the specified dimensions:
- * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2)
+ * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for + * each tensor/subset
along the specified dimensions:
out = (sum_i x[i] * y[i]) / ( + * sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2)
* - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable cosineSimilarity(SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("cosineSimilarity", "x", x); SDValidation.validateNumerical("cosineSimilarity", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(sd, x, y, + dimensions).outputVariable(); } /** - * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset
- * along the specified dimensions:
- * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2)
+ * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for + * each tensor/subset
along the specified dimensions:
out = (sum_i x[i] * y[i]) / ( + * sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2)
* - * @param name name May be null. Name for the output variable - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable cosineSimilarity(String name, SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("cosineSimilarity", "x", x); SDValidation.validateNumerical("cosineSimilarity", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(sd, x, y, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
+ * Count non zero array reduction operation, optionally along specified dimensions: out = count(x + * != 0)
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable countNonZero(SDVariable in, int... dimensions) { SDValidation.validateNumerical("countNonZero", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(sd, in, + dimensions).outputVariable(); } /** - * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
+ * Count non zero array reduction operation, optionally along specified dimensions: out = count(x + * != 0)
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable countNonZero(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("countNonZero", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
+ * Count zero array reduction operation, optionally along specified dimensions: out = count(x == + * 0)
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable countZero(SDVariable in, int... dimensions) { SDValidation.validateNumerical("countZero", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(sd, in, + dimensions).outputVariable(); } /** - * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
+ * Count zero array reduction operation, optionally along specified dimensions: out = count(x == + * 0)
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable countZero(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("countZero", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta).
- * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3
+ * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| + * sin(theta).
Can take rank 1 or above inputs (of equal shapes), but note that the last + * dimension must have dimension 3
* * @param a First input (NUMERIC type) * @param b Second input (NUMERIC type) @@ -1125,22 +1245,23 @@ public class SDMath extends SDOps { public SDVariable cross(SDVariable a, SDVariable b) { SDValidation.validateNumerical("cross", "a", a); SDValidation.validateNumerical("cross", "b", b); - return new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.Cross(sd, a, b).outputVariable(); } /** - * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta).
- * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3
+ * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| + * sin(theta).
Can take rank 1 or above inputs (of equal shapes), but note that the last + * dimension must have dimension 3
* * @param name name May be null. Name for the output variable - * @param a First input (NUMERIC type) - * @param b Second input (NUMERIC type) + * @param a First input (NUMERIC type) + * @param b Second input (NUMERIC type) * @return output Element-wise cross product (NUMERIC type) */ public SDVariable cross(String name, SDVariable a, SDVariable b) { SDValidation.validateNumerical("cross", "a", a); SDValidation.validateNumerical("cross", "b", b); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Cross(sd, a, b).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1152,102 +1273,91 @@ public class SDMath extends SDOps { */ public SDVariable cube(SDVariable x) { SDValidation.validateNumerical("cube", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(sd, x).outputVariable(); } /** * Element-wise cube function: out = x^3
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable cube(String name, SDVariable x) { SDValidation.validateNumerical("cube", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
- * For example, if input = [1,2,3], then output is given by:
- * [ 1, 0, 0]
- * [ 0, 2, 0]
- * [ 0, 0, 3]
+ * Returns an output variable with diagonal values equal to the specified values; off-diagonal + * values will be set to 0
For example, if input = [1,2,3], then output is given by:
[ 1, + * 0, 0]
[ 0, 2, 0]
[ 0, 0, 3]
*
- * Higher input ranks are also supported: if input has shape [a,...,R-1] then output[i,...,k,i,...,k] = input[i,...,k].
- * i.e., for input rank R, output has rank 2R
+ * Higher input ranks are also supported: if input has shape [a,...,R-1] then + * output[i,...,k,i,...,k] = input[i,...,k].
i.e., for input rank R, output has rank 2R
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable diag(SDVariable x) { SDValidation.validateNumerical("diag", "x", x); - return new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.Diag(sd, x).outputVariable(); } /** - * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
- * For example, if input = [1,2,3], then output is given by:
- * [ 1, 0, 0]
- * [ 0, 2, 0]
- * [ 0, 0, 3]
+ * Returns an output variable with diagonal values equal to the specified values; off-diagonal + * values will be set to 0
For example, if input = [1,2,3], then output is given by:
[ 1, + * 0, 0]
[ 0, 2, 0]
[ 0, 0, 3]
*
- * Higher input ranks are also supported: if input has shape [a,...,R-1] then output[i,...,k,i,...,k] = input[i,...,k].
- * i.e., for input rank R, output has rank 2R
+ * Higher input ranks are also supported: if input has shape [a,...,R-1] then + * output[i,...,k,i,...,k] = input[i,...,k].
i.e., for input rank R, output has rank 2R
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable diag(String name, SDVariable x) { SDValidation.validateNumerical("diag", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Diag(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Extract the diagonal part from the input array.
- * If input is
- * [ 1, 0, 0]
- * [ 0, 2, 0]
- * [ 0, 0, 3]
- * then output is [1, 2, 3].
- * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k]
+ * Extract the diagonal part from the input array.
If input is
[ 1, 0, 0]
[ 0, 2, + * 0]
[ 0, 0, 3]
then output is [1, 2, 3].
Supports higher dimensions: in general, + * out[i,...,k] = in[i,...,k,i,...,k]
* * @param x Input variable (NUMERIC type) * @return output Diagonal part of the input (NUMERIC type) */ public SDVariable diagPart(SDVariable x) { SDValidation.validateNumerical("diagPart", "x", x); - return new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd, x).outputVariable(); } /** - * Extract the diagonal part from the input array.
- * If input is
- * [ 1, 0, 0]
- * [ 0, 2, 0]
- * [ 0, 0, 3]
- * then output is [1, 2, 3].
- * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k]
+ * Extract the diagonal part from the input array.
If input is
[ 1, 0, 0]
[ 0, 2, + * 0]
[ 0, 0, 3]
then output is [1, 2, 3].
Supports higher dimensions: in general, + * out[i,...,k] = in[i,...,k,i,...,k]
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Diagonal part of the input (NUMERIC type) */ public SDVariable diagPart(String name, SDVariable x) { SDValidation.validateNumerical("diagPart", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise division operation, out = x / y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -1256,79 +1366,91 @@ public class SDMath extends SDOps { public SDVariable div(SDVariable x, SDVariable y) { SDValidation.validateNumerical("div", "x", x); SDValidation.validateNumerical("div", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd, x, + y).outputVariable(); } /** * Pairwise division operation, out = x / y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable div(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("div", "x", x); SDValidation.validateNumerical("div", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar division operation, out = in / scalar
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable div(SDVariable x, double value) { SDValidation.validateNumerical("div", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd, x, value).outputVariable(); } /** * Scalar division operation, out = in / scalar
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable div(String name, SDVariable x, double value) { SDValidation.validateNumerical("div", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Entropy reduction: -sum(x * log(x))
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable entropy(SDVariable in, int... dimensions) { SDValidation.validateNumerical("entropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(sd, in, + dimensions).outputVariable(); } /** * Entropy reduction: -sum(x * log(x))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable entropy(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("entropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1340,19 +1462,19 @@ public class SDMath extends SDOps { */ public SDVariable erf(SDVariable x) { SDValidation.validateNumerical("erf", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(sd, x).outputVariable(); } /** * Element-wise Gaussian error function - out = erf(in)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable erf(String name, SDVariable x) { SDValidation.validateNumerical("erf", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1364,55 +1486,62 @@ public class SDMath extends SDOps { */ public SDVariable erfc(SDVariable x) { SDValidation.validateNumerical("erfc", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(sd, x).outputVariable(); } /** * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable erfc(String name, SDVariable x) { SDValidation.validateNumerical("erfc", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each
- * tensor/subset along the specified dimensions:
- * out = sqrt( sum_i (x[i] - y[i])^2 )
+ * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the + * Euclidean distance for each
tensor/subset along the specified dimensions:
out = sqrt( + * sum_i (x[i] - y[i])^2 )
* - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable euclideanDistance(SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("euclideanDistance", "x", x); SDValidation.validateNumerical("euclideanDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(sd, x, y, + dimensions).outputVariable(); } /** - * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each
- * tensor/subset along the specified dimensions:
- * out = sqrt( sum_i (x[i] - y[i])^2 )
+ * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the + * Euclidean distance for each
tensor/subset along the specified dimensions:
out = sqrt( + * sum_i (x[i] - y[i])^2 )
* - * @param name name May be null. Name for the output variable - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable euclideanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("euclideanDistance", "x", x); SDValidation.validateNumerical("euclideanDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(sd, x, y, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1424,19 +1553,19 @@ public class SDMath extends SDOps { */ public SDVariable exp(SDVariable x) { SDValidation.validateNumerical("exp", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(sd, x).outputVariable(); } /** * Elementwise exponent function: out = exp(x) = 2.71828...^x
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable exp(String name, SDVariable x) { SDValidation.validateNumerical("exp", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1448,19 +1577,20 @@ public class SDMath extends SDOps { */ public SDVariable expm1(SDVariable x) { SDValidation.validateNumerical("expm1", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(sd, x).outputVariable(); } /** * Elementwise 1.0 - exponent function: out = 1.0 - exp(x) = 1.0 - 2.71828...^x
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable expm1(String name, SDVariable x) { SDValidation.validateNumerical("expm1", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1471,7 +1601,7 @@ public class SDMath extends SDOps { * @return output Identity matrix (NUMERIC type) */ public SDVariable eye(int rows) { - return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows).outputVariable(); } /** @@ -1482,7 +1612,7 @@ public class SDMath extends SDOps { * @return output Identity matrix (NUMERIC type) */ public SDVariable eye(String name, int rows) { - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1494,7 +1624,7 @@ public class SDMath extends SDOps { * @return output (NUMERIC type) */ public SDVariable eye(int rows, int cols) { - return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows, cols).outputVariable(); } /** @@ -1506,13 +1636,12 @@ public class SDMath extends SDOps { * @return output (NUMERIC type) */ public SDVariable eye(String name, int rows, int cols) { - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows, cols).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Generate an identity 
matrix with the specified number of rows and columns
- * Example:
+ * Generate an identity matrix with the specified number of rows and columns
Example:
*


* {@code INDArray eye = eye(3,2)
* eye:
@@ -1521,20 +1650,22 @@ public class SDMath extends SDOps { * [ 0, 0]}
*

* - * @param rows Number of rows - * @param cols Number of columns - * @param dataType Data type - * @param dimensions (Size: AtLeast(min=0)) + * @param rows Number of rows + * @param cols Number of columns + * @param dataType Data type + * @param dimensions (Size: AtLeast(min=0)) * @return output Identity matrix (NUMERIC type) */ public SDVariable eye(int rows, int cols, DataType dataType, int... dimensions) { - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols, dataType, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows, cols, dataType, + dimensions).outputVariable(); } /** - * Generate an identity matrix with the specified number of rows and columns
- * Example:
+ * Generate an identity matrix with the specified number of rows and columns
Example:
*

* {@code INDArray eye = eye(3,2)
* eye:
@@ -1543,16 +1674,19 @@ public class SDMath extends SDOps { * [ 0, 0]}
*

* - * @param name name May be null. Name for the output variable - * @param rows Number of rows - * @param cols Number of columns - * @param dataType Data type - * @param dimensions (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param rows Number of rows + * @param cols Number of columns + * @param dataType Data type + * @param dimensions (Size: AtLeast(min=0)) * @return output Identity matrix (NUMERIC type) */ public SDVariable eye(String name, int rows, int cols, DataType dataType, int... dimensions) { - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols, dataType, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows, cols, dataType, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1566,7 +1700,7 @@ public class SDMath extends SDOps { public SDVariable eye(SDVariable rows, SDVariable cols) { SDValidation.validateInteger("eye", "rows", rows); SDValidation.validateInteger("eye", "cols", cols); - return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows, cols).outputVariable(); } /** @@ -1580,7 +1714,7 @@ public class SDMath extends SDOps { public SDVariable eye(String name, SDVariable rows, SDVariable cols) { SDValidation.validateInteger("eye", "rows", rows); SDValidation.validateInteger("eye", "cols", cols); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows, cols).outputVariable(); return 
sd.updateVariableNameAndReference(out, name); } @@ -1592,7 +1726,7 @@ public class SDMath extends SDOps { */ public SDVariable eye(SDVariable rows) { SDValidation.validateInteger("eye", "rows", rows); - return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows).outputVariable(); } /** @@ -1604,138 +1738,149 @@ public class SDMath extends SDOps { */ public SDVariable eye(String name, SDVariable rows) { SDValidation.validateInteger("eye", "rows", rows); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * First index reduction operation.
- * Returns a variable that contains the index of the first element that matches the specified condition (for each
- * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable,
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
- * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
- * keepDims = true: [a,1,c]
- * keepDims = false: [a,c]
+ * First index reduction operation.
Returns a variable that contains the index of the first + * element that matches the specified condition (for each
slice along the specified + * dimensions)
Note that if keepDims = true, the output variable has the same rank as the + * input variable,
with the reduced dimensions having size 1. This can be useful for later + * broadcast operations (such as subtracting
the mean along a dimension).
Example: if + * input has shape [a,b,c] and dimensions=[1] then output has shape:
keepDims = true: + * [a,1,c]
keepDims = false: [a,c]
* - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable firstIndex(SDVariable in, Condition condition, int... dimensions) { SDValidation.validateNumerical("firstIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, false, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd, in, false, condition, + dimensions).outputVariable(); } /** - * First index reduction operation.
- * Returns a variable that contains the index of the first element that matches the specified condition (for each
- * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable,
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
- * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
- * keepDims = true: [a,1,c]
- * keepDims = false: [a,c]
+ * First index reduction operation.
Returns a variable that contains the index of the first + * element that matches the specified condition (for each
slice along the specified + * dimensions)
Note that if keepDims = true, the output variable has the same rank as the + * input variable,
with the reduced dimensions having size 1. This can be useful for later + * broadcast operations (such as subtracting
the mean along a dimension).
Example: if + * input has shape [a,b,c] and dimensions=[1] then output has shape:
keepDims = true: + * [a,1,c]
keepDims = false: [a,c]
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable firstIndex(String name, SDVariable in, Condition condition, int... dimensions) { SDValidation.validateNumerical("firstIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, false, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd, in, false, + condition, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * First index reduction operation.
- * Returns a variable that contains the index of the first element that matches the specified condition (for each
- * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable,
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
- * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
- * keepDims = true: [a,1,c]
- * keepDims = false: [a,c]
+ * First index reduction operation.
Returns a variable that contains the index of the first + * element that matches the specified condition (for each
slice along the specified + * dimensions)
Note that if keepDims = true, the output variable has the same rank as the + * input variable,
with the reduced dimensions having size 1. This can be useful for later + * broadcast operations (such as subtracting
the mean along a dimension).
Example: if + * input has shape [a,b,c] and dimensions=[1] then output has shape:
keepDims = true: + * [a,1,c]
keepDims = false: [a,c]
* - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable firstIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("firstIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd, in, keepDims, condition, + dimensions).outputVariable(); } /** - * First index reduction operation.
- * Returns a variable that contains the index of the first element that matches the specified condition (for each
- * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable,
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
- * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
- * keepDims = true: [a,1,c]
- * keepDims = false: [a,c]
+ * First index reduction operation.
Returns a variable that contains the index of the first + * element that matches the specified condition (for each
slice along the specified + * dimensions)
Note that if keepDims = true, the output variable has the same rank as the + * input variable,
with the reduced dimensions having size 1. This can be useful for later + * broadcast operations (such as subtracting
the mean along a dimension).
Example: if + * input has shape [a,b,c] and dimensions=[1] then output has shape:
keepDims = true: + * [a,1,c]
keepDims = false: [a,c]
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable firstIndex(String name, SDVariable in, Condition condition, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("firstIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd, in, keepDims, + condition, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise floor function: out = floor(x).
- * Rounds each value down to the nearest integer value (if not already an integer)
+ * Element-wise floor function: out = floor(x).
Rounds each value down to the nearest integer + * value (if not already an integer)
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable floor(SDVariable x) { SDValidation.validateNumerical("floor", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(sd, x).outputVariable(); } /** - * Element-wise floor function: out = floor(x).
- * Rounds each value down to the nearest integer value (if not already an integer)
+ * Element-wise floor function: out = floor(x).
Rounds each value down to the nearest integer + * value (if not already an integer)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable floor(String name, SDVariable x) { SDValidation.validateNumerical("floor", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise floor division operation, out = floor(x / y)
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -1744,34 +1889,38 @@ public class SDMath extends SDOps { public SDVariable floorDiv(SDVariable x, SDVariable y) { SDValidation.validateNumerical("floorDiv", "x", x); SDValidation.validateNumerical("floorDiv", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd, x, + y).outputVariable(); } /** * Pairwise floor division operation, out = floor(x / y)
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable floorDiv(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("floorDiv", "x", x); SDValidation.validateNumerical("floorDiv", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd, + x, y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise Modulus division operation
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -1780,509 +1929,564 @@ public class SDMath extends SDOps { public SDVariable floorMod(SDVariable x, SDVariable y) { SDValidation.validateNumerical("floorMod", "x", x); SDValidation.validateNumerical("floorMod", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd, x, + y).outputVariable(); } /** * Pairwise Modulus division operation
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable floorMod(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("floorMod", "x", x); SDValidation.validateNumerical("floorMod", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd, + x, y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar floor modulus operation
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable floorMod(SDVariable x, double value) { SDValidation.validateNumerical("floorMod", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd, x, value).outputVariable(); } /** * Scalar floor modulus operation
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable floorMod(String name, SDVariable x, double value) { SDValidation.validateNumerical("floorMod", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Hamming distance reduction operation. The output contains the cosine distance for each
- * tensor/subset along the specified dimensions:
- * out = count( x[i] != y[i] )
+ * tensor/subset along the specified dimensions:
out = count( x[i] != y[i] )
* - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable hammingDistance(SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("hammingDistance", "x", x); SDValidation.validateNumerical("hammingDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(sd, x, y, + dimensions).outputVariable(); } /** * Hamming distance reduction operation. The output contains the cosine distance for each
- * tensor/subset along the specified dimensions:
- * out = count( x[i] != y[i] )
+ * tensor/subset along the specified dimensions:
out = count( x[i] != y[i] )
* - * @param name name May be null. Name for the output variable - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable hammingDistance(String name, SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("hammingDistance", "x", x); SDValidation.validateNumerical("hammingDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(sd, x, y, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Index of the max absolute value: argmax(abs(in))
- * see argmax(String, INDArray, boolean, int...)
+ * Index of the max absolute value: argmax(abs(in))
see argmax(String, INDArray, boolean, + * int...)
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable iamax(SDVariable in, int... dimensions) { SDValidation.validateNumerical("iamax", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd, in, false, + dimensions).outputVariable(); } /** - * Index of the max absolute value: argmax(abs(in))
- * see argmax(String, INDArray, boolean, int...)
+ * Index of the max absolute value: argmax(abs(in))
see argmax(String, INDArray, boolean, + * int...)
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable iamax(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("iamax", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd, in, false, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Index of the max absolute value: argmax(abs(in))
- * see argmax(String, INDArray, boolean, int...)
+ * Index of the max absolute value: argmax(abs(in))
see argmax(String, INDArray, boolean, + * int...)
* - * @param in Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("iamax", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd, in, keepDims, + dimensions).outputVariable(); } /** - * Index of the max absolute value: argmax(abs(in))
- * see argmax(String, INDArray, boolean, int...)
+ * Index of the max absolute value: argmax(abs(in))
see argmax(String, INDArray, boolean, + * int...)
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("iamax", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd, in, keepDims, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Index of the min absolute value: argmin(abs(in))
- * see argmin(String, INDArray, boolean, int...)
+ * Index of the min absolute value: argmin(abs(in))
see argmin(String, INDArray, boolean, + * int...)
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable iamin(SDVariable in, int... dimensions) { SDValidation.validateNumerical("iamin", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd, in, false, + dimensions).outputVariable(); } /** - * Index of the min absolute value: argmin(abs(in))
- * see argmin(String, INDArray, boolean, int...)
+ * Index of the min absolute value: argmin(abs(in))
see argmin(String, INDArray, boolean, + * int...)
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable iamin(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("iamin", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd, in, false, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Index of the min absolute value: argmin(abs(in))
- * see argmin(String, INDArray, boolean, int...)
+ * Index of the min absolute value: argmin(abs(in))
see argmin(String, INDArray, boolean, + * int...)
* - * @param in Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("iamin", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd, in, keepDims, + dimensions).outputVariable(); } /** - * Index of the min absolute value: argmin(abs(in))
- * see argmin(String, INDArray, boolean, int...)
+ * Index of the min absolute value: argmin(abs(in))
see argmin(String, INDArray, boolean, + * int...)
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("iamin", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd, in, keepDims, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Is finite operation: elementwise isFinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is finite operation: elementwise isFinite(x)
Returns an array with the same shape/size as + * the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isFinite(SDVariable x) { SDValidation.validateNumerical("isFinite", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(sd, x).outputVariable(); } /** - * Is finite operation: elementwise isFinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is finite operation: elementwise isFinite(x)
Returns an array with the same shape/size as + * the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isFinite(String name, SDVariable x) { SDValidation.validateNumerical("isFinite", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Is infinite operation: elementwise isInfinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is infinite operation: elementwise isInfinite(x)
Returns an array with the same shape/size + * as the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isInfinite(SDVariable x) { SDValidation.validateNumerical("isInfinite", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(sd, x).outputVariable(); } /** - * Is infinite operation: elementwise isInfinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is infinite operation: elementwise isInfinite(x)
Returns an array with the same shape/size + * as the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isInfinite(String name, SDVariable x) { SDValidation.validateNumerical("isInfinite", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Is maximum operation: elementwise x == max(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is maximum operation: elementwise x == max(x)
Returns an array with the same shape/size as + * the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isMax(SDVariable x) { SDValidation.validateNumerical("isMax", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(sd, x).outputVariable(); } /** - * Is maximum operation: elementwise x == max(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is maximum operation: elementwise x == max(x)
Returns an array with the same shape/size as + * the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isMax(String name, SDVariable x) { SDValidation.validateNumerical("isMax", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Is Not a Number operation: elementwise isNaN(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is Not a Number operation: elementwise isNaN(x)
Returns an array with the same shape/size + * as the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isNaN(SDVariable x) { SDValidation.validateNumerical("isNaN", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(sd, x).outputVariable(); } /** - * Is Not a Number operation: elementwise isNaN(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is Not a Number operation: elementwise isNaN(x)
Returns an array with the same shape/size + * as the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isNaN(String name, SDVariable x) { SDValidation.validateNumerical("isNaN", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Is the array non decreasing?
- * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared
- * in 'c' (row major) order
+ * Is the array non decreasing?
An array is non-decreasing if for every valid i, x[i] <= + * x[i+1]. For Rank 2+ arrays, values are compared
in 'c' (row major) order
* * @param x Input variable (NUMERIC type) * @return output Scalar variable with value 1 if non-decreasing, or 0 otherwise (NUMERIC type) */ public SDVariable isNonDecreasing(SDVariable x) { SDValidation.validateNumerical("isNonDecreasing", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(sd, + x).outputVariable(); } /** - * Is the array non decreasing?
- * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared
- * in 'c' (row major) order
+ * Is the array non decreasing?
An array is non-decreasing if for every valid i, x[i] <= + * x[i+1]. For Rank 2+ arrays, values are compared
in 'c' (row major) order
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Scalar variable with value 1 if non-decreasing, or 0 otherwise (NUMERIC type) */ public SDVariable isNonDecreasing(String name, SDVariable x) { SDValidation.validateNumerical("isNonDecreasing", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Is the array strictly increasing?
- * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared
- * in 'c' (row major) order
+ * Is the array strictly increasing?
An array is strictly increasing if for every valid i, + * x[i] < x[i+1]. For Rank 2+ arrays, values are compared
in 'c' (row major) order
* * @param x Input variable (NUMERIC type) - * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC type) + * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC + * type) */ public SDVariable isStrictlyIncreasing(SDVariable x) { SDValidation.validateNumerical("isStrictlyIncreasing", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(sd, + x).outputVariable(); } /** - * Is the array strictly increasing?
- * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared
- * in 'c' (row major) order
+ * Is the array strictly increasing?
An array is strictly increasing if for every valid i, + * x[i] < x[i+1]. For Rank 2+ arrays, values are compared
in 'c' (row major) order
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC + * type) */ public SDVariable isStrictlyIncreasing(String name, SDVariable x) { SDValidation.validateNumerical("isStrictlyIncreasing", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Jaccard similarity reduction operation. The output contains the Jaccard distance for each
- * tensor along the specified dimensions.
+ * tensor along the specified dimensions.
* - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable jaccardDistance(SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("jaccardDistance", "x", x); SDValidation.validateNumerical("jaccardDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(sd, x, y, + dimensions).outputVariable(); } /** * Jaccard similarity reduction operation. The output contains the Jaccard distance for each
- * tensor along the specified dimensions.
+ * tensor along the specified dimensions.
* - * @param name name May be null. Name for the output variable - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable jaccardDistance(String name, SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("jaccardDistance", "x", x); SDValidation.validateNumerical("jaccardDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(sd, x, y, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Last index reduction operation.
- * Returns a variable that contains the index of the last element that matches the specified condition (for each
- * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable,
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
- * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
- * keepDims = true: [a,1,c]
- * keepDims = false: [a,c]
+ * Last index reduction operation.
Returns a variable that contains the index of the last + * element that matches the specified condition (for each
slice along the specified + * dimensions)
Note that if keepDims = true, the output variable has the same rank as the + * input variable,
with the reduced dimensions having size 1. This can be useful for later + * broadcast operations (such as subtracting
the mean along a dimension).
Example: if + * input has shape [a,b,c] and dimensions=[1] then output has shape:
keepDims = true: + * [a,1,c]
keepDims = false: [a,c]
* - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable lastIndex(SDVariable in, Condition condition, int... dimensions) { SDValidation.validateNumerical("lastIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, false, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd, in, false, condition, + dimensions).outputVariable(); } /** - * Last index reduction operation.
- * Returns a variable that contains the index of the last element that matches the specified condition (for each
- * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable,
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
- * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
- * keepDims = true: [a,1,c]
- * keepDims = false: [a,c]
+ * Last index reduction operation.
Returns a variable that contains the index of the last + * element that matches the specified condition (for each
slice along the specified + * dimensions)
Note that if keepDims = true, the output variable has the same rank as the + * input variable,
with the reduced dimensions having size 1. This can be useful for later + * broadcast operations (such as subtracting
the mean along a dimension).
Example: if + * input has shape [a,b,c] and dimensions=[1] then output has shape:
keepDims = true: + * [a,1,c]
keepDims = false: [a,c]
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable lastIndex(String name, SDVariable in, Condition condition, int... dimensions) { SDValidation.validateNumerical("lastIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, false, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd, in, false, condition, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Last index reduction operation.
- * Returns a variable that contains the index of the last element that matches the specified condition (for each
- * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable,
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
- * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
- * keepDims = true: [a,1,c]
- * keepDims = false: [a,c]
+ * Last index reduction operation.
Returns a variable that contains the index of the last + * element that matches the specified condition (for each
slice along the specified + * dimensions)
Note that if keepDims = true, the output variable has the same rank as the + * input variable,
with the reduced dimensions having size 1. This can be useful for later + * broadcast operations (such as subtracting
the mean along a dimension).
Example: if + * input has shape [a,b,c] and dimensions=[1] then output has shape:
keepDims = true: + * [a,1,c]
keepDims = false: [a,c]
* - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable lastIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("lastIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd, in, keepDims, condition, + dimensions).outputVariable(); } /** - * Last index reduction operation.
- * Returns a variable that contains the index of the last element that matches the specified condition (for each
- * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable,
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
- * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
- * keepDims = true: [a,1,c]
- * keepDims = false: [a,c]
+ * Last index reduction operation.
Returns a variable that contains the index of the last + * element that matches the specified condition (for each
slice along the specified + * dimensions)
Note that if keepDims = true, the output variable has the same rank as the + * input variable,
with the reduced dimensions having size 1. This can be useful for later + * broadcast operations (such as subtracting
the mean along a dimension).
Example: if + * input has shape [a,b,c] and dimensions=[1] then output has shape:
keepDims = true: + * [a,1,c]
keepDims = false: [a,c]
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable lastIndex(String name, SDVariable in, Condition condition, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("lastIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. 
Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd, in, keepDims, + condition, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -2295,20 +2499,21 @@ public class SDMath extends SDOps { public SDVariable[] listDiff(SDVariable x, SDVariable y) { SDValidation.validateNumerical("listDiff", "x", x); SDValidation.validateNumerical("listDiff", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(sd,x, y).outputVariables(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(sd, x, y).outputVariables(); } /** * Calculates difference between inputs X and Y.
* * @param names names May be null. Arrays of names for the output variables. - * @param x Input variable X (NUMERIC type) - * @param y Input variable Y (NUMERIC type) + * @param x Input variable X (NUMERIC type) + * @param y Input variable Y (NUMERIC type) */ public SDVariable[] listDiff(String[] names, SDVariable x, SDVariable y) { SDValidation.validateNumerical("listDiff", "x", x); SDValidation.validateNumerical("listDiff", "y", y); - SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(sd,x, y).outputVariables(); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(sd, x, + y).outputVariables(); return sd.updateVariableNamesAndReferences(out, names); } @@ -2320,45 +2525,45 @@ public class SDMath extends SDOps { */ public SDVariable log(SDVariable x) { SDValidation.validateNumerical("log", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(sd, x).outputVariable(); } /** * Element-wise logarithm function (base e - natural logarithm): out = log(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable log(String name, SDVariable x) { SDValidation.validateNumerical("log", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Element-wise logarithm function (with specified base): out = log_{base}(x)
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param base Logarithm base * @return output Output variable (NUMERIC type) */ public SDVariable log(SDVariable x, double base) { SDValidation.validateNumerical("log", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.LogX(sd,x, base).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.LogX(sd, x, base).outputVariable(); } /** * Element-wise logarithm function (with specified base): out = log_{base}(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param base Logarithm base * @return output Output variable (NUMERIC type) */ public SDVariable log(String name, SDVariable x, double base) { SDValidation.validateNumerical("log", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.LogX(sd,x, base).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.LogX(sd, x, base).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -2370,178 +2575,202 @@ public class SDMath extends SDOps { */ public SDVariable log1p(SDVariable x) { SDValidation.validateNumerical("log1p", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(sd, x).outputVariable(); } /** * Elementwise natural logarithm function: out = log_e (1 + x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable log1p(String name, SDVariable x) { SDValidation.validateNumerical("log1p", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Log entropy reduction: log(-sum(x * log(x)))
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable logEntropy(SDVariable in, int... dimensions) { SDValidation.validateNumerical("logEntropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(sd, in, + dimensions).outputVariable(); } /** * Log entropy reduction: log(-sum(x * log(x)))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable logEntropy(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("logEntropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Log-sum-exp reduction (optionally along dimension).
- * Computes log(sum(exp(x))
+ * Log-sum-exp reduction (optionally along dimension).
Computes log(sum(exp(x))
* - * @param input Input variable (NUMERIC type) + * @param input Input variable (NUMERIC type) * @param dimensions Optional dimensions to reduce along (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable logSumExp(SDVariable input, int... dimensions) { SDValidation.validateNumerical("logSumExp", "input", input); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(sd,input, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(sd, input, + dimensions).outputVariable(); } /** - * Log-sum-exp reduction (optionally along dimension).
- * Computes log(sum(exp(x))
+ * Log-sum-exp reduction (optionally along dimension).
Computes log(sum(exp(x))
* - * @param name name May be null. Name for the output variable - * @param input Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) * @param dimensions Optional dimensions to reduce along (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable logSumExp(String name, SDVariable input, int... dimensions) { SDValidation.validateNumerical("logSumExp", "input", input); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(sd,input, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(sd, input, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each
- * tensor/subset along the specified dimensions:
- * out = sum_i abs(x[i]-y[i])
+ * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the + * Manhattan distance for each
tensor/subset along the specified dimensions:
out = sum_i + * abs(x[i]-y[i])
* - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable manhattanDistance(SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("manhattanDistance", "x", x); SDValidation.validateNumerical("manhattanDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(sd, x, y, + dimensions).outputVariable(); } /** - * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each
- * tensor/subset along the specified dimensions:
- * out = sum_i abs(x[i]-y[i])
+ * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the + * Manhattan distance for each
tensor/subset along the specified dimensions:
out = sum_i + * abs(x[i]-y[i])
* - * @param name name May be null. Name for the output variable - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable manhattanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("manhattanDistance", "x", x); SDValidation.validateNumerical("manhattanDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(sd, x, y, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
- * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each
- * shape [m,m] sub-matrix.
+ * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
For + * higher dimensional input with shape [..., m, m] the matrix determinant is returned for each + *
shape [m,m] sub-matrix.
* * @param in Input (NUMERIC type) * @return output Matrix determinant variable (NUMERIC type) */ public SDVariable matrixDeterminant(SDVariable in) { SDValidation.validateNumerical("matrixDeterminant", "in", in); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(sd,in).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(sd, + in).outputVariable(); } /** - * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
- * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each
- * shape [m,m] sub-matrix.
+ * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
For + * higher dimensional input with shape [..., m, m] the matrix determinant is returned for each + *
shape [m,m] sub-matrix.
* * @param name name May be null. Name for the output variable - * @param in Input (NUMERIC type) + * @param in Input (NUMERIC type) * @return output Matrix determinant variable (NUMERIC type) */ public SDVariable matrixDeterminant(String name, SDVariable in) { SDValidation.validateNumerical("matrixDeterminant", "in", in); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(sd,in).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(sd, + in).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
- * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each
- * shape [m,m] sub-matrix.
+ * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
For higher + * dimensional input with shape [..., m, m] the matrix inverse is returned for each
shape + * [m,m] sub-matrix.
* * @param in Input (NUMERIC type) * @return output Matrix inverse variable (NUMERIC type) */ public SDVariable matrixInverse(SDVariable in) { SDValidation.validateNumerical("matrixInverse", "in", in); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(sd,in).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(sd, + in).outputVariable(); } /** - * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
- * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each
- * shape [m,m] sub-matrix.
+ * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
For higher + * dimensional input with shape [..., m, m] the matrix inverse is returned for each
shape + * [m,m] sub-matrix.
* * @param name name May be null. Name for the output variable - * @param in Input (NUMERIC type) + * @param in Input (NUMERIC type) * @return output Matrix inverse variable (NUMERIC type) */ public SDVariable matrixInverse(String name, SDVariable in) { SDValidation.validateNumerical("matrixInverse", "in", in); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(sd,in).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(sd, + in).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise max operation, out = max(x, y)
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x First input variable, x (NUMERIC type) * @param y Second input variable, y (NUMERIC type) @@ -2550,144 +2779,158 @@ public class SDMath extends SDOps { public SDVariable max(SDVariable x, SDVariable y) { SDValidation.validateNumerical("max", "x", x); SDValidation.validateNumerical("max", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd, x, y).outputVariable(); } /** * Pairwise max operation, out = max(x, y)
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x First input variable, x (NUMERIC type) - * @param y Second input variable, y (NUMERIC type) + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) * @return out Output (NUMERIC type) */ public SDVariable max(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("max", "x", x); SDValidation.validateNumerical("max", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
- * out = sum_i in[i]
+ * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise + * addition:
out = sum_i in[i]
* * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mergeAdd(SDVariable... inputs) { SDValidation.validateNumerical("mergeAdd", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable(); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd, + inputs).outputVariable(); } /** - * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
- * out = sum_i in[i]
+ * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise + * addition:
out = sum_i in[i]
* - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mergeAdd(String name, SDVariable... inputs) { SDValidation.validateNumerical("mergeAdd", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable(); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd, + inputs).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation:
- * out = mean_i in[i]
+ * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise + * mean operation:
out = mean_i in[i]
* * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mergeAvg(SDVariable... inputs) { SDValidation.validateNumerical("mergeAvg", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - return new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable(); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd, inputs).outputVariable(); } /** - * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation:
- * out = mean_i in[i]
+ * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise + * mean operation:
out = mean_i in[i]
* - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mergeAvg(String name, SDVariable... inputs) { SDValidation.validateNumerical("mergeAvg", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable(); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd, inputs).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation:
- * out = max_i in[i]
+ * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise + * maximum operation:
out = max_i in[i]
* * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mergeMax(SDVariable... inputs) { SDValidation.validateNumerical("mergeMax", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - return new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable(); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd, inputs).outputVariable(); } /** - * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation:
- * out = max_i in[i]
+ * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise + * maximum operation:
out = max_i in[i]
* - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mergeMax(String name, SDVariable... inputs) { SDValidation.validateNumerical("mergeMax", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable(); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd, inputs).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Broadcasts parameters for evaluation on an N-D grid.
* - * @param inputs (NUMERIC type) - * @param cartesian + * @param inputs (NUMERIC type) + * @param cartesian */ public SDVariable[] meshgrid(SDVariable[] inputs, boolean cartesian) { SDValidation.validateNumerical("meshgrid", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 0, "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); - return new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(sd,inputs, cartesian).outputVariables(); + Preconditions.checkArgument(inputs.length >= 0, + "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(sd, inputs, cartesian).outputVariables(); } /** * Broadcasts parameters for evaluation on an N-D grid.
* - * @param names names May be null. Arrays of names for the output variables. - * @param inputs (NUMERIC type) - * @param cartesian + * @param names names May be null. Arrays of names for the output variables. + * @param inputs (NUMERIC type) + * @param cartesian */ public SDVariable[] meshgrid(String[] names, SDVariable[] inputs, boolean cartesian) { SDValidation.validateNumerical("meshgrid", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 0, "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); - SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(sd,inputs, cartesian).outputVariables(); + Preconditions.checkArgument(inputs.length >= 0, + "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(sd, inputs, + cartesian).outputVariables(); return sd.updateVariableNamesAndReferences(out, names); } /** * Pairwise max operation, out = min(x, y)
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x First input variable, x (NUMERIC type) * @param y Second input variable, y (NUMERIC type) @@ -2696,34 +2939,37 @@ public class SDMath extends SDOps { public SDVariable min(SDVariable x, SDVariable y) { SDValidation.validateNumerical("min", "x", x); SDValidation.validateNumerical("min", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd, x, y).outputVariable(); } /** * Pairwise max operation, out = min(x, y)
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x First input variable, x (NUMERIC type) - * @param y Second input variable, y (NUMERIC type) + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) * @return out Output (NUMERIC type) */ public SDVariable min(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("min", "x", x); SDValidation.validateNumerical("min", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise modulus (remainder) operation, out = x % y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -2732,60 +2978,69 @@ public class SDMath extends SDOps { public SDVariable mod(SDVariable x, SDVariable y) { SDValidation.validateNumerical("mod", "x", x); SDValidation.validateNumerical("mod", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd, x, + y).outputVariable(); } /** * Pairwise modulus (remainder) operation, out = x % y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mod(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("mod", "x", x); SDValidation.validateNumerical("mod", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Calculate the mean and (population) variance for the input variable, for the specified axis
+ * Calculate the mean and (population) variance for the input variable, for the specified + * axis
* * @param input Input to calculate moments for (NUMERIC type) - * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) + * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) */ public SDVariable[] moments(SDVariable input, int... axes) { SDValidation.validateNumerical("moments", "input", input); - Preconditions.checkArgument(axes.length >= 0, "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); - return new org.nd4j.linalg.api.ops.impl.reduce.Moments(sd,input, axes).outputVariables(); + Preconditions.checkArgument(axes.length >= 0, + "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); + return new org.nd4j.linalg.api.ops.impl.reduce.Moments(sd, input, axes).outputVariables(); } /** - * Calculate the mean and (population) variance for the input variable, for the specified axis
+ * Calculate the mean and (population) variance for the input variable, for the specified + * axis
* * @param names names May be null. Arrays of names for the output variables. * @param input Input to calculate moments for (NUMERIC type) - * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) + * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) */ public SDVariable[] moments(String[] names, SDVariable input, int... axes) { SDValidation.validateNumerical("moments", "input", input); - Preconditions.checkArgument(axes.length >= 0, "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); - SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.Moments(sd,input, axes).outputVariables(); + Preconditions.checkArgument(axes.length >= 0, + "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.Moments(sd, input, + axes).outputVariables(); return sd.updateVariableNamesAndReferences(out, names); } /** * Pairwise multiplication operation, out = x * y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -2794,51 +3049,56 @@ public class SDMath extends SDOps { public SDVariable mul(SDVariable x, SDVariable y) { SDValidation.validateNumerical("mul", "x", x); SDValidation.validateNumerical("mul", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd, x, + y).outputVariable(); } /** * Pairwise multiplication operation, out = x * y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mul(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("mul", "x", x); SDValidation.validateNumerical("mul", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar multiplication operation, out = in * scalar
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable mul(SDVariable x, double value) { SDValidation.validateNumerical("mul", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd, x, + value).outputVariable(); } /** * Scalar multiplication operation, out = in * scalar
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable mul(String name, SDVariable x, double value) { SDValidation.validateNumerical("mul", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -2850,113 +3110,127 @@ public class SDMath extends SDOps { */ public SDVariable neg(SDVariable x) { SDValidation.validateNumerical("neg", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(sd, x).outputVariable(); } /** * Elementwise negative operation: out = -x
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable neg(String name, SDVariable x) { SDValidation.validateNumerical("neg", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Calculate the mean and variance from the sufficient statistics
* - * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics (NUMERIC type) - * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC type) - * @param variances Variaance sufficient statistics: this is the squared sum of all data values (NUMERIC type) - * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) + * @param counts Rank 0 (scalar) value with the total number of values used to calculate the + * sufficient statistics (NUMERIC type) + * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC + * type) + * @param variances Variaance sufficient statistics: this is the squared sum of all data values + * (NUMERIC type) + * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for + * numerical stability) */ public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift) { SDValidation.validateNumerical("normalizeMoments", "counts", counts); SDValidation.validateNumerical("normalizeMoments", "means", means); SDValidation.validateNumerical("normalizeMoments", "variances", variances); - return new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(sd,counts, means, variances, shift).outputVariables(); + return new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(sd, counts, means, variances, + shift).outputVariables(); } /** * Calculate the mean and variance from the sufficient statistics
* - * @param names names May be null. Arrays of names for the output variables. - * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics (NUMERIC type) - * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC type) - * @param variances Variaance sufficient statistics: this is the squared sum of all data values (NUMERIC type) - * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) + * @param names names May be null. Arrays of names for the output variables. + * @param counts Rank 0 (scalar) value with the total number of values used to calculate the + * sufficient statistics (NUMERIC type) + * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC + * type) + * @param variances Variaance sufficient statistics: this is the squared sum of all data values + * (NUMERIC type) + * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for + * numerical stability) */ public SDVariable[] normalizeMoments(String[] names, SDVariable counts, SDVariable means, SDVariable variances, double shift) { SDValidation.validateNumerical("normalizeMoments", "counts", counts); SDValidation.validateNumerical("normalizeMoments", "means", means); SDValidation.validateNumerical("normalizeMoments", "variances", variances); - SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(sd,counts, means, variances, shift).outputVariables(); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(sd, counts, means, + variances, shift).outputVariables(); return sd.updateVariableNamesAndReferences(out, names); } /** - * Boolean OR operation: elementwise (x != 0) || (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * Boolean OR operation: elementwise (x != 0) || (y != 0)
If x and y arrays have equal shape, + * the output shape is the same as these inputs.
Note: supports broadcasting if x and y have + * different shapes and are broadcastable.
Returns an array with values 1 where condition is + * satisfied, or value 0 otherwise.
* * @param x Input 1 (BOOL type) * @param y Input 2 (BOOL type) - * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL + * type) */ public SDVariable or(SDVariable x, SDVariable y) { SDValidation.validateBool("or", "x", x); SDValidation.validateBool("or", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(sd, x, y).outputVariable(); } /** - * Boolean OR operation: elementwise (x != 0) || (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * Boolean OR operation: elementwise (x != 0) || (y != 0)
If x and y arrays have equal shape, + * the output shape is the same as these inputs.
Note: supports broadcasting if x and y have + * different shapes and are broadcastable.
Returns an array with values 1 where condition is + * satisfied, or value 0 otherwise.
* * @param name name May be null. Name for the output variable - * @param x Input 1 (BOOL type) - * @param y Input 2 (BOOL type) - * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL + * type) */ public SDVariable or(String name, SDVariable x, SDVariable y) { SDValidation.validateBool("or", "x", x); SDValidation.validateBool("or", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Element-wise power function: out = x^value
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable pow(SDVariable x, double value) { SDValidation.validateNumerical("pow", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.Pow(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.Pow(sd, x, value).outputVariable(); } /** * Element-wise power function: out = x^value
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable pow(String name, SDVariable x, double value) { SDValidation.validateNumerical("pow", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Pow(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Pow(sd, x, value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -2970,58 +3244,61 @@ public class SDMath extends SDOps { public SDVariable pow(SDVariable x, SDVariable y) { SDValidation.validateNumerical("pow", "x", x); SDValidation.validateNumerical("pow", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sd, x, y).outputVariable(); } /** * Element-wise (broadcastable) power function: out = x[i]^y[i]
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Power (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Power (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable pow(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("pow", "x", x); SDValidation.validateNumerical("pow", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Rational Tanh Approximation elementwise function, as described in the paper:
- * Compact Convolutional Neural Network Cascade for Face Detection
- * This is a faster Tanh approximation
+ * Rational Tanh Approximation elementwise function, as described in the paper:
Compact + * Convolutional Neural Network Cascade for Face Detection
This is a faster Tanh + * approximation
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable rationalTanh(SDVariable x) { SDValidation.validateNumerical("rationalTanh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd, x).outputVariable(); } /** - * Rational Tanh Approximation elementwise function, as described in the paper:
- * Compact Convolutional Neural Network Cascade for Face Detection
- * This is a faster Tanh approximation
+ * Rational Tanh Approximation elementwise function, as described in the paper:
Compact + * Convolutional Neural Network Cascade for Face Detection
This is a faster Tanh + * approximation
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable rationalTanh(String name, SDVariable x) { SDValidation.validateNumerical("rationalTanh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise reverse division operation, out = y / x
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -3030,51 +3307,56 @@ public class SDMath extends SDOps { public SDVariable rdiv(SDVariable x, SDVariable y) { SDValidation.validateNumerical("rdiv", "x", x); SDValidation.validateNumerical("rdiv", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd, x, + y).outputVariable(); } /** * Pairwise reverse division operation, out = y / x
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable rdiv(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("rdiv", "x", x); SDValidation.validateNumerical("rdiv", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar reverse division operation, out = scalar / in
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable rdiv(SDVariable x, double value) { SDValidation.validateNumerical("rdiv", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd, x, + value).outputVariable(); } /** * Scalar reverse division operation, out = scalar / in
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable rdiv(String name, SDVariable x, double value) { SDValidation.validateNumerical("rdiv", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3086,19 +3368,20 @@ public class SDMath extends SDOps { */ public SDVariable reciprocal(SDVariable x) { SDValidation.validateNumerical("reciprocal", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(sd, x).outputVariable(); } /** * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable reciprocal(String name, SDVariable x) { SDValidation.validateNumerical("reciprocal", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3110,45 +3393,46 @@ public class SDMath extends SDOps { */ public SDVariable rectifiedTanh(SDVariable x) { SDValidation.validateNumerical("rectifiedTanh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd, x).outputVariable(); } /** * Rectified tanh operation: max(0, tanh(in))
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable rectifiedTanh(String name, SDVariable x) { SDValidation.validateNumerical("rectifiedTanh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise round function: out = round(x).
- * Rounds (up or down depending on value) to the nearest integer value.
+ * Element-wise round function: out = round(x).
Rounds (up or down depending on value) to the + * nearest integer value.
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable round(SDVariable x) { SDValidation.validateNumerical("round", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Round(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Round(sd, x).outputVariable(); } /** - * Element-wise round function: out = round(x).
- * Rounds (up or down depending on value) to the nearest integer value.
+ * Element-wise round function: out = round(x).
Rounds (up or down depending on value) to the + * nearest integer value.
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable round(String name, SDVariable x) { SDValidation.validateNumerical("round", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Round(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Round(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3160,28 +3444,30 @@ public class SDMath extends SDOps { */ public SDVariable rsqrt(SDVariable x) { SDValidation.validateNumerical("rsqrt", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(sd, x).outputVariable(); } /** * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable rsqrt(String name, SDVariable x) { SDValidation.validateNumerical("rsqrt", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise reverse subtraction operation, out = y - x
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -3190,153 +3476,152 @@ public class SDMath extends SDOps { public SDVariable rsub(SDVariable x, SDVariable y) { SDValidation.validateNumerical("rsub", "x", x); SDValidation.validateNumerical("rsub", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd, x, + y).outputVariable(); } /** * Pairwise reverse subtraction operation, out = y - x
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable rsub(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("rsub", "x", x); SDValidation.validateNumerical("rsub", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar reverse subtraction operation, out = scalar - in
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable rsub(SDVariable x, double value) { SDValidation.validateNumerical("rsub", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd, x, + value).outputVariable(); } /** * Scalar reverse subtraction operation, out = scalar - in
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable rsub(String name, SDVariable x, double value) { SDValidation.validateNumerical("rsub", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Set the diagonal value to the specified values
- * If input is
- * [ a, b, c]
- * [ d, e, f]
- * [ g, h, i]
- * and diag = [ 1, 2, 3] then output is
- * [ 1, b, c]
- * [ d, 2, f]
- * [ g, h, 3]
+ * Set the diagonal value to the specified values
If input is
[ a, b, c]
[ d, e, + * f]
[ g, h, i]
and diag = [ 1, 2, 3] then output is
[ 1, b, c]
[ d, 2, f]
[ + * g, h, 3]
* - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param diag Diagonal (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable setDiag(SDVariable in, SDVariable diag) { SDValidation.validateNumerical("setDiag", "in", in); SDValidation.validateNumerical("setDiag", "diag", diag); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(sd,in, diag).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(sd, in, + diag).outputVariable(); } /** - * Set the diagonal value to the specified values
- * If input is
- * [ a, b, c]
- * [ d, e, f]
- * [ g, h, i]
- * and diag = [ 1, 2, 3] then output is
- * [ 1, b, c]
- * [ d, 2, f]
- * [ g, h, 3]
+ * Set the diagonal value to the specified values
If input is
[ a, b, c]
[ d, e, + * f]
[ g, h, i]
and diag = [ 1, 2, 3] then output is
[ 1, b, c]
[ d, 2, f]
[ + * g, h, 3]
* * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param diag Diagonal (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable setDiag(String name, SDVariable in, SDVariable diag) { SDValidation.validateNumerical("setDiag", "in", in); SDValidation.validateNumerical("setDiag", "diag", diag); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(sd,in, diag).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(sd, in, + diag).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Shannon Entropy reduction: -sum(x * log2(x))
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable shannonEntropy(SDVariable in, int... dimensions) { SDValidation.validateNumerical("shannonEntropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(sd, in, + dimensions).outputVariable(); } /** * Shannon Entropy reduction: -sum(x * log2(x))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable shannonEntropy(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("shannonEntropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise sign (signum) function:
- * out = -1 if in < 0
- * out = 0 if in = 0
- * out = 1 if in > 0
+ * Element-wise sign (signum) function:
out = -1 if in < 0
out = 0 if in = 0
out = 1 + * if in > 0
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sign(SDVariable x) { SDValidation.validateNumerical("sign", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(sd, x).outputVariable(); } /** - * Element-wise sign (signum) function:
- * out = -1 if in < 0
- * out = 0 if in = 0
- * out = 1 if in > 0
+ * Element-wise sign (signum) function:
out = -1 if in < 0
out = 0 if in = 0
out = 1 + * if in > 0
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sign(String name, SDVariable x) { SDValidation.validateNumerical("sign", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3348,19 +3633,19 @@ public class SDMath extends SDOps { */ public SDVariable sin(SDVariable x) { SDValidation.validateNumerical("sin", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(sd, x).outputVariable(); } /** * Elementwise sine operation: out = sin(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sin(String name, SDVariable x) { SDValidation.validateNumerical("sin", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3372,19 +3657,20 @@ public class SDMath extends SDOps { */ public SDVariable sinh(SDVariable x) { SDValidation.validateNumerical("sinh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(sd, x).outputVariable(); } /** * Elementwise sinh (hyperbolic sine) operation: out = sinh(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sinh(String name, SDVariable x) { SDValidation.validateNumerical("sinh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3396,19 +3682,20 @@ public class SDMath extends SDOps { */ public SDVariable sqrt(SDVariable x) { SDValidation.validateNumerical("sqrt", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(sd, x).outputVariable(); } /** * Element-wise square root function: out = sqrt(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sqrt(String name, SDVariable x) { SDValidation.validateNumerical("sqrt", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3420,28 +3707,30 @@ public class SDMath extends SDOps { */ public SDVariable square(SDVariable x) { SDValidation.validateNumerical("square", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Square(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Square(sd, x).outputVariable(); } /** * Element-wise square function: out = x^2
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable square(String name, SDVariable x) { SDValidation.validateNumerical("square", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Square(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Square(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise squared difference operation.
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -3450,25 +3739,28 @@ public class SDMath extends SDOps { public SDVariable squaredDifference(SDVariable x, SDVariable y) { SDValidation.validateNumerical("squaredDifference", "x", x); SDValidation.validateNumerical("squaredDifference", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(sd, + x, y).outputVariable(); } /** * Pairwise squared difference operation.
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable squaredDifference(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("squaredDifference", "x", x); SDValidation.validateNumerical("squaredDifference", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp( + sd, x, y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3480,20 +3772,23 @@ public class SDMath extends SDOps { * with mean and stdev being calculated along the given dimension.
*


* For example: given x as a mini batch of the shape [numExamples, exampleLength]:
- *


    - *
  • use dimension 1 too use the statistics (mean, stdev) for each example

  • - *
  • use dimension 0 if you want to use the statistics for each column across all examples

  • - *
  • use dimensions 0,1 if you want to use the statistics across all columns and examples

  • + *
      + *
    • use dimension 1 to use the statistics (mean, stdev) for each example
    • + *
    • use dimension 0 if you want to use the statistics for each column across all examples
    • + *
    • use dimensions 0,1 if you want to use the statistics across all columns and examples
    • *

    * - * @param x Input variable (NUMERIC type) - * @param dimensions (Size: AtLeast(min=1)) + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public SDVariable standardize(SDVariable x, int... dimensions) { SDValidation.validateNumerical("standardize", "x", x); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(sd,x, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(sd, x, + dimensions).outputVariable(); } /** @@ -3504,60 +3799,60 @@ public class SDMath extends SDOps { * with mean and stdev being calculated along the given dimension.
    *


    * For example: given x as a mini batch of the shape [numExamples, exampleLength]:
    - *


      - *
    • use dimension 1 too use the statistics (mean, stdev) for each example

    • - *
    • use dimension 0 if you want to use the statistics for each column across all examples

    • - *
    • use dimensions 0,1 if you want to use the statistics across all columns and examples

    • + *
        + *
      • use dimension 1 to use the statistics (mean, stdev) for each example
      • + *
      • use dimension 0 if you want to use the statistics for each column across all examples
      • + *
      • use dimensions 0,1 if you want to use the statistics across all columns and examples
      • *

      * - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param dimensions (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public SDVariable standardize(String name, SDVariable x, int... dimensions) { SDValidation.validateNumerical("standardize", "x", x); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(sd,x, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(sd, x, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Elementwise step function:
      - * out(x) = 1 if x >= cutoff
      - * out(x) = 0 otherwise
      + * Elementwise step function:
      {@code out(x) = 1 if x >= cutoff
      out(x) = 0 otherwise}
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable step(SDVariable x, double value) { SDValidation.validateNumerical("step", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.Step(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.Step(sd, x, value).outputVariable(); } /** - * Elementwise step function:
      - * out(x) = 1 if x >= cutoff
      - * out(x) = 0 otherwise
      + * Elementwise step function:
      {@code out(x) = 1 if x >= cutoff
      out(x) = 0 otherwise} * - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable step(String name, SDVariable x, double value) { SDValidation.validateNumerical("step", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Step(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Step(sd, x, value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise subtraction operation, out = x - y
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      * * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -3566,51 +3861,55 @@ public class SDMath extends SDOps { public SDVariable sub(SDVariable x, SDVariable y) { SDValidation.validateNumerical("sub", "x", x); SDValidation.validateNumerical("sub", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd, x, + y).outputVariable(); } /** * Pairwise subtraction operation, out = x - y
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sub(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("sub", "x", x); SDValidation.validateNumerical("sub", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar subtraction operation, out = in - scalar
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable sub(SDVariable x, double value) { SDValidation.validateNumerical("sub", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd, x, value).outputVariable(); } /** * Scalar subtraction operation, out = in - scalar
      * - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable sub(String name, SDVariable x, double value) { SDValidation.validateNumerical("sub", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3622,19 +3921,19 @@ public class SDMath extends SDOps { */ public SDVariable tan(SDVariable x) { SDValidation.validateNumerical("tan", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(sd, x).outputVariable(); } /** * Elementwise tangent operation: out = tan(x)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable tan(String name, SDVariable x) { SDValidation.validateNumerical("tan", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3646,105 +3945,111 @@ public class SDMath extends SDOps { */ public SDVariable tanh(SDVariable x) { SDValidation.validateNumerical("tanh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd, x).outputVariable(); } /** * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable tanh(String name, SDVariable x) { SDValidation.validateNumerical("tanh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Matrix trace operation
      - * For rank 2 matrices, the output is a scalar vith the trace - i.e., sum of the main diagonal.
      - * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
      + * Matrix trace operation
      For rank 2 matrices, the output is a scalar with the trace - i.e., + sum of the main diagonal.
      For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
      * * @param in Input variable (NUMERIC type) * @return output Trace (NUMERIC type) */ public SDVariable trace(SDVariable in) { SDValidation.validateNumerical("trace", "in", in); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(sd,in).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(sd, in).outputVariable(); } /** - * Matrix trace operation
      - * For rank 2 matrices, the output is a scalar vith the trace - i.e., sum of the main diagonal.
      - * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
      + * Matrix trace operation
      For rank 2 matrices, the output is a scalar with the trace - i.e., + sum of the main diagonal.
      For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
      * * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @return output Trace (NUMERIC type) */ public SDVariable trace(String name, SDVariable in) { SDValidation.validateNumerical("trace", "in", in); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(sd,in).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(sd, + in).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
      + * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
      If x and y arrays + * have equal shape, the output shape is the same as these inputs.
      Note: supports broadcasting + * if x and y have different shapes and are broadcastable.
      Returns an array with values 1 + * where condition is satisfied, or value 0 otherwise.
      * * @param x Input 1 (BOOL type) * @param y Input 2 (BOOL type) - * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL + * type) */ public SDVariable xor(SDVariable x, SDVariable y) { SDValidation.validateBool("xor", "x", x); SDValidation.validateBool("xor", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(sd, x, y).outputVariable(); } /** - * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
      + * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
      If x and y arrays + * have equal shape, the output shape is the same as these inputs.
      Note: supports broadcasting + * if x and y have different shapes and are broadcastable.
      Returns an array with values 1 + * where condition is satisfied, or value 0 otherwise.
      * * @param name name May be null. Name for the output variable - * @param x Input 1 (BOOL type) - * @param y Input 2 (BOOL type) - * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL + * type) */ public SDVariable xor(String name, SDVariable x, SDVariable y) { SDValidation.validateBool("xor", "x", x); SDValidation.validateBool("xor", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
      + * Full array zero fraction array reduction operation, optionally along specified dimensions: out + * = (count(x == 0) / length(x))
      * * @param input Input variable (NUMERIC type) * @return output Reduced array of rank 0 (scalar) (NUMERIC type) */ public SDVariable zeroFraction(SDVariable input) { SDValidation.validateNumerical("zeroFraction", "input", input); - return new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(sd,input).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(sd, input).outputVariable(); } /** - * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
      + * Full array zero fraction array reduction operation, optionally along specified dimensions: out + * = (count(x == 0) / length(x))
      * - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param input Input variable (NUMERIC type) * @return output Reduced array of rank 0 (scalar) (NUMERIC type) */ public SDVariable zeroFraction(String name, SDVariable input) { SDValidation.validateNumerical("zeroFraction", "input", input); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(sd,input).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(sd, + input).outputVariable(); return sd.updateVariableNameAndReference(out, name); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 846291e47..b617d2865 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -27,47 +27,53 @@ import org.nd4j.common.base.Preconditions; import org.nd4j.enums.PadMode; public class SDNN extends SDOps { + public SDNN(SameDiff sameDiff) { super(sameDiff); } /** - * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
      + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which + * selects only the negative part of the activation. Note that as a result this non-linearity + * doubles the depth of the activations.
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable cReLU(SDVariable x) { SDValidation.validateNumerical("CReLU", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd, x).outputVariable(); } /** - * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
      + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which + * selects only the negative part of the activation. Note that as a result this non-linearity + * doubles the depth of the activations.
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable cReLU(String name, SDVariable x) { SDValidation.validateNumerical("CReLU", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Neural network batch normalization operation.
      - * For details, see https://arxiv.org/abs/1502.03167
      + * Neural network batch normalization operation.
      For details, see https://arxiv.org/abs/1502.03167
      * - * @param input Input variable. (NUMERIC type) - * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param input Input variable. (NUMERIC type) + * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) * @param variance Variance value. For 1d axis, this should match input.size(axis) (NUMERIC type) - * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) - * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) - * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) - * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations. - * For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC - * For 1d/RNN activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) + * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) + * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format + * activations. For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC For 1d/RNN + * activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) * @return output variable for batch normalization (NUMERIC type) */ public SDVariable batchNorm(SDVariable input, SDVariable mean, SDVariable variance, @@ -77,24 +83,26 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("batchNorm", "variance", variance); SDValidation.validateNumerical("batchNorm", "gamma", gamma); SDValidation.validateNumerical("batchNorm", "beta", beta); - Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. 
Expected: axis.length >= 1, got %s", axis.length); - return new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sd,input, mean, variance, gamma, beta, epsilon, axis).outputVariable(); + Preconditions.checkArgument(axis.length >= 1, + "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sd, input, mean, variance, + gamma, beta, epsilon, axis).outputVariable(); } /** - * Neural network batch normalization operation.
      - * For details, see https://arxiv.org/abs/1502.03167
      + * Neural network batch normalization operation.
      For details, see https://arxiv.org/abs/1502.03167
      * - * @param name name May be null. Name for the output variable - * @param input Input variable. (NUMERIC type) - * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param input Input variable. (NUMERIC type) + * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) * @param variance Variance value. For 1d axis, this should match input.size(axis) (NUMERIC type) - * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) - * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) - * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) - * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations. - * For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC - * For 1d/RNN activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) + * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) + * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format + * activations. 
For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC For 1d/RNN + * activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) * @return output variable for batch normalization (NUMERIC type) */ public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, SDVariable variance, @@ -104,73 +112,82 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("batchNorm", "variance", variance); SDValidation.validateNumerical("batchNorm", "gamma", gamma); SDValidation.validateNumerical("batchNorm", "beta", beta); - Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sd,input, mean, variance, gamma, beta, epsilon, axis).outputVariable(); + Preconditions.checkArgument(axis.length >= 1, + "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sd, input, mean, + variance, gamma, beta, epsilon, axis).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector
      + * Bias addition operation: a special case of addition, typically used with CNN 4D activations and + * a 1D bias vector
      * * @param input 4d input variable (NUMERIC type) - * @param bias 1d bias (NUMERIC type) - * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. - * Unused for 2d inputs + * @param bias 1d bias (NUMERIC type) + * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; + * nchw=false - [minibatch, height, width, channels]. Unused for 2d inputs * @return output Output variable, after applying bias add operation (NUMERIC type) */ public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { SDValidation.validateNumerical("biasAdd", "input", input); SDValidation.validateNumerical("biasAdd", "bias", bias); - return new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(sd,input, bias, nchw).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(sd, input, bias, + nchw).outputVariable(); } /** - * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector
      + * Bias addition operation: a special case of addition, typically used with CNN 4D activations and + * a 1D bias vector
      * - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param input 4d input variable (NUMERIC type) - * @param bias 1d bias (NUMERIC type) - * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. - * Unused for 2d inputs + * @param bias 1d bias (NUMERIC type) + * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; + * nchw=false - [minibatch, height, width, channels]. Unused for 2d inputs * @return output Output variable, after applying bias add operation (NUMERIC type) */ public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolean nchw) { SDValidation.validateNumerical("biasAdd", "input", input); SDValidation.validateNumerical("biasAdd", "bias", bias); - SDVariable out = new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(sd,input, bias, nchw).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(sd, input, bias, + nchw).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * This operation performs dot product attention on the given timeseries input with the given queries
      - * out = sum(similarity(k_i, q) * v_i)
      + * This operation performs dot product attention on the given timeseries input with the given + * queries
      out = sum(similarity(k_i, q) * v_i)
      *
      * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q
      *
      - * Optionally with normalization step:
      - * similarity(k, q) = softmax(k * q / sqrt(size(q))
      + * Optionally with normalization step:
similarity(k, q) = softmax(k * q / sqrt(size(q)))<br>
      *
      * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
      *
      - * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
      - * be 3D but can have queryCount = 1
      + * Note: This supports multiple queries at once, if only one query is available the queries vector + * still has to
      be 3D but can have queryCount = 1
      *
      - * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for
      - * both.
      + * Note: keys and values usually is the same array. If you want to use it as the same array, + * simply pass it for
      both.
      *
      - * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
      - * output rank will depend on the input rank.
      + * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them + * doesn't work. The
      output rank will depend on the input rank.
      * - * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] - * or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) - * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] - * or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) - * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] - * or 4D array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) - * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) - * @param scaled normalization, false -> do not apply normalization, true -> apply normalization - * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount], - * (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] or 4D + * array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] or 4D array + * of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] or 4D + * array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, + * timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply + * normalization + * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or + * [batchSize, numHeads, featureValues, queryCount], (optionally) Attention Weights of shape + * [batchSize, 
timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC + * type) */ public SDVariable dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled) { @@ -178,40 +195,44 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("dotProductAttention", "keys", keys); SDValidation.validateNumerical("dotProductAttention", "values", values); SDValidation.validateNumerical("dotProductAttention", "mask", mask); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(sd,queries, keys, values, mask, scaled, false).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(sd, queries, keys, + values, mask, scaled, false).outputVariable(); } /** - * This operation performs dot product attention on the given timeseries input with the given queries
      - * out = sum(similarity(k_i, q) * v_i)
      + * This operation performs dot product attention on the given timeseries input with the given + * queries
      out = sum(similarity(k_i, q) * v_i)
      *
      * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q
      *
      - * Optionally with normalization step:
      - * similarity(k, q) = softmax(k * q / sqrt(size(q))
      + * Optionally with normalization step:
similarity(k, q) = softmax(k * q / sqrt(size(q)))<br>
      *
      * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
      *
      - * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
      - * be 3D but can have queryCount = 1
      + * Note: This supports multiple queries at once, if only one query is available the queries vector + * still has to
      be 3D but can have queryCount = 1
      *
      - * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for
      - * both.
      + * Note: keys and values usually is the same array. If you want to use it as the same array, + * simply pass it for
      both.
      *
      - * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
      - * output rank will depend on the input rank.
      + * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them + * doesn't work. The
      output rank will depend on the input rank.
      * - * @param name name May be null. Name for the output variable - * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] - * or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) - * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] - * or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) - * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] - * or 4D array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) - * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) - * @param scaled normalization, false -> do not apply normalization, true -> apply normalization - * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount], - * (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + * @param name name May be null. 
Name for the output variable + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] or 4D + * array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] or 4D array + * of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] or 4D + * array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, + * timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply + * normalization + * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or + * [batchSize, numHeads, featureValues, queryCount], (optionally) Attention Weights of shape + * [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC + * type) */ public SDVariable dotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled) { @@ -219,41 +240,44 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("dotProductAttention", "keys", keys); SDValidation.validateNumerical("dotProductAttention", "values", values); SDValidation.validateNumerical("dotProductAttention", "mask", mask); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(sd,queries, keys, values, mask, scaled, false).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(sd, + queries, keys, values, mask, scaled, false).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Dropout operation
      * - * @param input Input array (NUMERIC type) - * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) + * @param input Input array (NUMERIC type) + * @param inputRetainProbability Probability of retaining an input (set to 0 with probability + * 1-p) * @return output Output (NUMERIC type) */ public SDVariable dropout(SDVariable input, double inputRetainProbability) { SDValidation.validateNumerical("dropout", "input", input); - return new org.nd4j.linalg.api.ops.random.impl.DropOut(sd,input, inputRetainProbability).outputVariable(); + return new org.nd4j.linalg.api.ops.random.impl.DropOut(sd, input, + inputRetainProbability).outputVariable(); } /** * Dropout operation
      * - * @param name name May be null. Name for the output variable - * @param input Input array (NUMERIC type) - * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) + * @param name name May be null. Name for the output variable + * @param input Input array (NUMERIC type) + * @param inputRetainProbability Probability of retaining an input (set to 0 with probability + * 1-p) * @return output Output (NUMERIC type) */ public SDVariable dropout(String name, SDVariable input, double inputRetainProbability) { SDValidation.validateNumerical("dropout", "input", input); - SDVariable out = new org.nd4j.linalg.api.ops.random.impl.DropOut(sd,input, inputRetainProbability).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.DropOut(sd, input, + inputRetainProbability).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Element-wise exponential linear unit (ELU) function:
      - * out = x if x > 0
      - * out = a * (exp(x) - 1) if x <= 0
      - * with constant a = 1.0
      + * {@code out = x if x > 0 out = a * (exp(x) - 1) if x <= 0 with constant a = 1.0} *


      * See: https://arxiv.org/abs/1511.07289
      * @@ -262,112 +286,107 @@ public class SDNN extends SDOps { */ public SDVariable elu(SDVariable x) { SDValidation.validateNumerical("elu", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(sd, x).outputVariable(); } /** - * Element-wise exponential linear unit (ELU) function:
      - * out = x if x > 0
      - * out = a * (exp(x) - 1) if x <= 0
      - * with constant a = 1.0
      + * Element-wise exponential linear unit (ELU) function:
      out = x if x > 0
      out = a * (exp(x) + * - 1) if x <= 0
      with constant a = 1.0
      *


      * See: https://arxiv.org/abs/1511.07289
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable elu(String name, SDVariable x) { SDValidation.validateNumerical("elu", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * GELU activation function - Gaussian Error Linear Units
      - * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      - * This method uses the sigmoid approximation
      + * GELU activation function - Gaussian Error Linear Units
      For more details, see Gaussian + * Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      This method + * uses the sigmoid approximation
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable gelu(SDVariable x) { SDValidation.validateNumerical("gelu", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(sd, x).outputVariable(); } /** - * GELU activation function - Gaussian Error Linear Units
      - * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      - * This method uses the sigmoid approximation
      + * GELU activation function - Gaussian Error Linear Units
      For more details, see Gaussian + * Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      This method + * uses the sigmoid approximation
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable gelu(String name, SDVariable x) { SDValidation.validateNumerical("gelu", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise hard sigmoid function:
      - * out[i] = 0 if in[i] <= -2.5
      - * out[1] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
      - * out[i] = 1 if in[i] >= 2.5
      + * Element-wise hard sigmoid function:
      out[i] = 0 if in[i] <= -2.5
out[i] = 0.2*in[i]+0.5 + * if -2.5 < in[i] < 2.5<br>
      out[i] = 1 if in[i] >= 2.5
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable hardSigmoid(SDVariable x) { SDValidation.validateNumerical("hardSigmoid", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(sd, x).outputVariable(); } /** - * Element-wise hard sigmoid function:
      - * out[i] = 0 if in[i] <= -2.5
      - * out[1] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
      - * out[i] = 1 if in[i] >= 2.5
      + * Element-wise hard sigmoid function:
      out[i] = 0 if in[i] <= -2.5
out[i] = 0.2*in[i]+0.5 + * if -2.5 < in[i] < 2.5<br>
      out[i] = 1 if in[i] >= 2.5
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable hardSigmoid(String name, SDVariable x) { SDValidation.validateNumerical("hardSigmoid", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise hard tanh function:
      - * out[i] = -1 if in[i] <= -1
      - * out[1] = in[i] if -1 < in[i] < 1
      - * out[i] = 1 if in[i] >= 1
      + * Element-wise hard tanh function:
      out[i] = -1 if in[i] <= -1
out[i] = in[i] if -1 < + * in[i] < 1<br>
      out[i] = 1 if in[i] >= 1
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable hardTanh(SDVariable x) { SDValidation.validateNumerical("hardTanh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(sd, x).outputVariable(); } /** - * Element-wise hard tanh function:
      - * out[i] = -1 if in[i] <= -1
      - * out[1] = in[i] if -1 < in[i] < 1
      - * out[i] = 1 if in[i] >= 1
      + * Element-wise hard tanh function:
      out[i] = -1 if in[i] <= -1
out[i] = in[i] if -1 < + * in[i] < 1<br>
      out[i] = 1 if in[i] >= 1
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable hardTanh(String name, SDVariable x) { SDValidation.validateNumerical("hardTanh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -379,19 +398,21 @@ public class SDNN extends SDOps { */ public SDVariable hardTanhDerivative(SDVariable x) { SDValidation.validateNumerical("hardTanhDerivative", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(sd, + x).outputVariable(); } /** * Derivative (dOut/dIn) of the element-wise hard Tanh function - hardTanh(INDArray)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable hardTanhDerivative(String name, SDVariable x) { SDValidation.validateNumerical("hardTanhDerivative", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -400,11 +421,13 @@ public class SDNN extends SDOps { *
      * y = gain * standardize(x) + bias
      * - * @param input Input variable (NUMERIC type) - * @param gain Gain (NUMERIC type) - * @param bias Bias (NUMERIC type) - * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data - * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param bias Bias (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), + * false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, + * dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, @@ -412,8 +435,11 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("layerNorm", "input", input); SDValidation.validateNumerical("layerNorm", "gain", gain); SDValidation.validateNumerical("layerNorm", "bias", bias); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, bias, channelsFirst, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd, input, gain, bias, + channelsFirst, dimensions).outputVariable(); } /** @@ -421,12 +447,14 @@ public class SDNN extends SDOps { *
      * y = gain * standardize(x) + bias
      * - * @param name name May be null. Name for the output variable - * @param input Input variable (NUMERIC type) - * @param gain Gain (NUMERIC type) - * @param bias Bias (NUMERIC type) - * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data - * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param bias Bias (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), + * false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, + * dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, @@ -434,8 +462,11 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("layerNorm", "input", input); SDValidation.validateNumerical("layerNorm", "gain", gain); SDValidation.validateNumerical("layerNorm", "bias", bias); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, bias, channelsFirst, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. 
Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd, input, gain, + bias, channelsFirst, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -444,18 +475,23 @@ public class SDNN extends SDOps { *
      * y = gain * standardize(x) + bias
      * - * @param input Input variable (NUMERIC type) - * @param gain Gain (NUMERIC type) - * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data - * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), + * false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, + * dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { SDValidation.validateNumerical("layerNorm", "input", input); SDValidation.validateNumerical("layerNorm", "gain", gain); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, null, channelsFirst, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd, input, gain, null, + channelsFirst, dimensions).outputVariable(); } /** @@ -463,111 +499,115 @@ public class SDNN extends SDOps { *
      * y = gain * standardize(x) + bias
      * - * @param name name May be null. Name for the output variable - * @param input Input variable (NUMERIC type) - * @param gain Gain (NUMERIC type) - * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data - * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), + * false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, + * dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { SDValidation.validateNumerical("layerNorm", "input", input); SDValidation.validateNumerical("layerNorm", "gain", gain); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, null, channelsFirst, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd, input, gain, + null, channelsFirst, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise leaky ReLU function:
      - * out = x if x >= 0.0
      - * out = alpha * x if x < cutoff
      + * Element-wise leaky ReLU function:
      out = x if x >= 0.0
      out = alpha * x if x < cutoff
      * Alpha value is most commonly set to 0.01
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param alpha Cutoff - commonly 0.01 * @return output Output variable (NUMERIC type) */ public SDVariable leakyRelu(SDVariable x, double alpha) { SDValidation.validateNumerical("leakyRelu", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(sd,x, alpha).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(sd, x, alpha).outputVariable(); } /** - * Element-wise leaky ReLU function:
      - * out = x if x >= 0.0
      - * out = alpha * x if x < cutoff
      + * Element-wise leaky ReLU function:
      out = x if x >= 0.0
      out = alpha * x if x < cutoff
      * Alpha value is most commonly set to 0.01
      * - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param alpha Cutoff - commonly 0.01 * @return output Output variable (NUMERIC type) */ public SDVariable leakyRelu(String name, SDVariable x, double alpha) { SDValidation.validateNumerical("leakyRelu", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(sd,x, alpha).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(sd, x, + alpha).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Leaky ReLU derivative: dOut/dIn given input.
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param alpha Cutoff - commonly 0.01 * @return output Output variable (NUMERIC type) */ public SDVariable leakyReluDerivative(SDVariable x, double alpha) { SDValidation.validateNumerical("leakyReluDerivative", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(sd,x, alpha).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(sd, x, + alpha).outputVariable(); } /** * Leaky ReLU derivative: dOut/dIn given input.
      * - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param alpha Cutoff - commonly 0.01 * @return output Output variable (NUMERIC type) */ public SDVariable leakyReluDerivative(String name, SDVariable x, double alpha) { SDValidation.validateNumerical("leakyReluDerivative", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(sd,x, alpha).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(sd, x, + alpha).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Linear layer operation: out = mmul(in,w) + bias
      - * Note that bias array is optional
      + * Linear layer operation: out = mmul(in,w) + bias
      Note that bias array is optional
      * - * @param input Input data (NUMERIC type) + * @param input Input data (NUMERIC type) * @param weights Weights variable, shape [nIn, nOut] (NUMERIC type) - * @param bias Optional bias variable (may be null) (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable linear(SDVariable input, SDVariable weights, SDVariable bias) { SDValidation.validateNumerical("linear", "input", input); SDValidation.validateNumerical("linear", "weights", weights); SDValidation.validateNumerical("linear", "bias", bias); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(sd,input, weights, bias).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(sd, input, weights, + bias).outputVariable(); } /** - * Linear layer operation: out = mmul(in,w) + bias
      - * Note that bias array is optional
      + * Linear layer operation: out = mmul(in,w) + bias
      Note that bias array is optional
      * - * @param name name May be null. Name for the output variable - * @param input Input data (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param input Input data (NUMERIC type) * @param weights Weights variable, shape [nIn, nOut] (NUMERIC type) - * @param bias Optional bias variable (may be null) (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable linear(String name, SDVariable input, SDVariable weights, SDVariable bias) { SDValidation.validateNumerical("linear", "input", input); SDValidation.validateNumerical("linear", "weights", weights); SDValidation.validateNumerical("linear", "bias", bias); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(sd,input, weights, bias).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(sd, input, weights, + bias).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -579,95 +619,108 @@ public class SDNN extends SDOps { */ public SDVariable logSigmoid(SDVariable x) { SDValidation.validateNumerical("logSigmoid", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(sd, x).outputVariable(); } /** * Element-wise sigmoid function: out[i] = log(sigmoid(in[i]))
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable logSigmoid(String name, SDVariable x) { SDValidation.validateNumerical("logSigmoid", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Log softmax activation
      * - * @param x (NUMERIC type) + * @param x (NUMERIC type) * @return output (NUMERIC type) */ public SDVariable logSoftmax(SDVariable x) { SDValidation.validateNumerical("logSoftmax", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd, x).outputVariable(); } /** * Log softmax activation
      * * @param name name May be null. Name for the output variable - * @param x (NUMERIC type) + * @param x (NUMERIC type) * @return output (NUMERIC type) */ public SDVariable logSoftmax(String name, SDVariable x) { SDValidation.validateNumerical("logSoftmax", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Log softmax activation
      * - * @param x Input (NUMERIC type) + * @param x Input (NUMERIC type) * @param dimension Dimension along which to apply log softmax * @return output Output - log(softmax(input)) (NUMERIC type) */ public SDVariable logSoftmax(SDVariable x, int dimension) { SDValidation.validateNumerical("logSoftmax", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x, dimension).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd, x, + dimension).outputVariable(); } /** * Log softmax activation
      * - * @param name name May be null. Name for the output variable - * @param x Input (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) * @param dimension Dimension along which to apply log softmax * @return output Output - log(softmax(input)) (NUMERIC type) */ public SDVariable logSoftmax(String name, SDVariable x, int dimension) { SDValidation.validateNumerical("logSoftmax", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd, x, + dimension).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * This performs multi-headed dot product attention on the given timeseries input
      - * out = concat(head_1, head_2, ..., head_n) * Wo
      - * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v)
      + * This performs multi-headed dot product attention on the given timeseries input
      out = + * concat(head_1, head_2, ..., head_n) * Wo
      head_i = dot_product_attention(Wq_i*q, Wk_i*k, + * Wv_i*v)
      *
      * Optionally with normalization when calculating the attention for each head.
      *
      - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention")
      + * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 + * Multi-Head Attention")
      *
      - * This makes use of dot_product_attention OP support for rank 4 inputs.
      - * see dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
      + * This makes use of dot_product_attention OP support for rank 4 inputs.
      see + * dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
      * - * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC type) - * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC type) - * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC type) - * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) - * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) - * @param Wv input value projection weights of shape [numHeads, projectedValues, featureValues] (NUMERIC type) - * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] (NUMERIC type) - * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) - * @param scaled normalization, false -> do not apply normalization, true -> apply normalization - * @return output Attention result arrays of shape [batchSize, outSize, queryCount] - * (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC + * type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC + * type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC + * type) + * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] + * (NUMERIC type) + * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] + * (NUMERIC type) + * @param Wv input value projection weights of shape [numHeads, projectedValues, + * featureValues] (NUMERIC type) + * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] + * (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of 
shape [batchSize, + * timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, outSize, queryCount] (optionally) + * Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) */ public SDVariable multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, @@ -680,33 +733,43 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv); SDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo); SDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(sd,queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(sd, + queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); } /** - * This performs multi-headed dot product attention on the given timeseries input
      - * out = concat(head_1, head_2, ..., head_n) * Wo
      - * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v)
      + * This performs multi-headed dot product attention on the given timeseries input
      out = + * concat(head_1, head_2, ..., head_n) * Wo
      head_i = dot_product_attention(Wq_i*q, Wk_i*k, + * Wv_i*v)
      *
      * Optionally with normalization when calculating the attention for each head.
      *
      - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention")
      + * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 + * Multi-Head Attention")
      *
      - * This makes use of dot_product_attention OP support for rank 4 inputs.
      - * see dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
      + * This makes use of dot_product_attention OP support for rank 4 inputs.
      see + * dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
      * - * @param name name May be null. Name for the output variable - * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC type) - * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC type) - * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC type) - * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) - * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) - * @param Wv input value projection weights of shape [numHeads, projectedValues, featureValues] (NUMERIC type) - * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] (NUMERIC type) - * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) - * @param scaled normalization, false -> do not apply normalization, true -> apply normalization - * @return output Attention result arrays of shape [batchSize, outSize, queryCount] - * (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + * @param name name May be null. 
Name for the output variable + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC + * type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC + * type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC + * type) + * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] + * (NUMERIC type) + * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] + * (NUMERIC type) + * @param Wv input value projection weights of shape [numHeads, projectedValues, + * featureValues] (NUMERIC type) + * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] + * (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, + * timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, outSize, queryCount] (optionally) + * Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) */ public SDVariable multiHeadDotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, @@ -719,32 +782,34 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv); SDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo); SDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(sd,queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention( + sd, queries, keys, values, Wq, Wk, Wv, Wo, mask, 
scaled, false).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Padding operation
      * - * @param input Input tensor (NUMERIC type) - * @param padding Padding value (NUMERIC type) - * @param PadMode Padding format + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format * @param constant Padding constant * @return output Padded input (NUMERIC type) */ public SDVariable pad(SDVariable input, SDVariable padding, PadMode PadMode, double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode, constant).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd, input, padding, PadMode, + constant).outputVariable(); } /** * Padding operation
      * - * @param name name May be null. Name for the output variable - * @param input Input tensor (NUMERIC type) - * @param padding Padding value (NUMERIC type) - * @param PadMode Padding format + * @param name name May be null. Name for the output variable + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format * @param constant Padding constant * @return output Padded input (NUMERIC type) */ @@ -752,233 +817,247 @@ public class SDNN extends SDOps { double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode, constant).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd, input, padding, PadMode, + constant).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Padding operation
      * - * @param input Input tensor (NUMERIC type) - * @param padding Padding value (NUMERIC type) + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) * @param constant Padding constant * @return output Padded input (NUMERIC type) */ public SDVariable pad(SDVariable input, SDVariable padding, double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode.CONSTANT, constant).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd, input, padding, PadMode.CONSTANT, + constant).outputVariable(); } /** * Padding operation
      * - * @param name name May be null. Name for the output variable - * @param input Input tensor (NUMERIC type) - * @param padding Padding value (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) * @param constant Padding constant * @return output Padded input (NUMERIC type) */ public SDVariable pad(String name, SDVariable input, SDVariable padding, double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode.CONSTANT, constant).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd, input, padding, + PadMode.CONSTANT, constant).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * GELU activation function - Gaussian Error Linear Units
      - * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      - * This method uses the precise method
      + * GELU activation function - Gaussian Error Linear Units
      For more details, see Gaussian + * Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      This method + * uses the precise method
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable preciseGelu(SDVariable x) { SDValidation.validateNumerical("preciseGelu", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd, x).outputVariable(); } /** - * GELU activation function - Gaussian Error Linear Units
      - * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      - * This method uses the precise method
      + * GELU activation function - Gaussian Error Linear Units
      For more details, see Gaussian + * Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      This method + * uses the precise method
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable preciseGelu(String name, SDVariable x) { SDValidation.validateNumerical("preciseGelu", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
      - * out[i] = in[i] if in[i] >= 0
      - * out[i] = in[i] * alpha[i] otherwise
      + * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable + * alpha:
      out[i] = in[i] if in[i] >= 0
      out[i] = in[i] * alpha[i] otherwise
      *
      - * sharedAxes allows you to share learnable parameters along axes.
      - * For example, if the input has shape [batchSize, channels, height, width]
      - * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
      - * alpha with shape [channels].
      + * sharedAxes allows you to share learnable parameters along axes.
      For example, if the input + * has shape [batchSize, channels, height, width]
      and you want each channel to have its own + * cutoff, use sharedAxes = [2, 3] and an
      alpha with shape [channels].
      * - * @param input Input data (NUMERIC type) - * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. (NUMERIC type) + * @param input Input data (NUMERIC type) + * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is + * batch or not) should not be part of alpha. (NUMERIC type) * @param sharedAxes Which axes to share cutoff parameters along. (Size: AtLeast(min=1)) * @return output Output (NUMERIC type) */ public SDVariable prelu(SDVariable input, SDVariable alpha, int... sharedAxes) { SDValidation.validateNumerical("prelu", "input", input); SDValidation.validateNumerical("prelu", "alpha", alpha); - Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length); - return new org.nd4j.linalg.api.ops.impl.scalar.PRelu(sd,input, alpha, sharedAxes).outputVariable(); + Preconditions.checkArgument(sharedAxes.length >= 1, + "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", + sharedAxes.length); + return new org.nd4j.linalg.api.ops.impl.scalar.PRelu(sd, input, alpha, + sharedAxes).outputVariable(); } /** - * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
      - * out[i] = in[i] if in[i] >= 0
      - * out[i] = in[i] * alpha[i] otherwise
      + * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable + * alpha:
      out[i] = in[i] if in[i] >= 0
      out[i] = in[i] * alpha[i] otherwise
      *
      - * sharedAxes allows you to share learnable parameters along axes.
      - * For example, if the input has shape [batchSize, channels, height, width]
      - * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
      - * alpha with shape [channels].
      + * sharedAxes allows you to share learnable parameters along axes.
      For example, if the input + * has shape [batchSize, channels, height, width]
      and you want each channel to have its own + * cutoff, use sharedAxes = [2, 3] and an
      alpha with shape [channels].
      * - * @param name name May be null. Name for the output variable - * @param input Input data (NUMERIC type) - * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param input Input data (NUMERIC type) + * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is + * batch or not) should not be part of alpha. (NUMERIC type) * @param sharedAxes Which axes to share cutoff parameters along. (Size: AtLeast(min=1)) * @return output Output (NUMERIC type) */ public SDVariable prelu(String name, SDVariable input, SDVariable alpha, int... sharedAxes) { SDValidation.validateNumerical("prelu", "input", input); SDValidation.validateNumerical("prelu", "alpha", alpha); - Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.PRelu(sd,input, alpha, sharedAxes).outputVariable(); + Preconditions.checkArgument(sharedAxes.length >= 1, + "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", + sharedAxes.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.PRelu(sd, input, alpha, + sharedAxes).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Element-wise rectified linear function with specified cutoff:
      - * out[i] = in[i] if in[i] >= cutoff
      - * out[i] = 0 otherwise
      + * {@code out[i] = in[i] if in[i] >= cutoff out[i] = 0 otherwise} * - * @param x Input (NUMERIC type) - * @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0 + * @param x Input (NUMERIC type) + * @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0 * @return output Output (NUMERIC type) */ public SDVariable relu(SDVariable x, double cutoff) { SDValidation.validateNumerical("relu", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(sd,x, cutoff).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(sd, x, cutoff).outputVariable(); } /** - * Element-wise rectified linear function with specified cutoff:
      - * out[i] = in[i] if in[i] >= cutoff
      - * out[i] = 0 otherwise
      + * Element-wise rectified linear function with specified cutoff:
      out[i] = in[i] if in[i] >= + * cutoff
      out[i] = 0 otherwise
      * - * @param name name May be null. Name for the output variable - * @param x Input (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) * @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0 * @return output Output (NUMERIC type) */ public SDVariable relu(String name, SDVariable x, double cutoff) { SDValidation.validateNumerical("relu", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(sd,x, cutoff).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(sd, x, + cutoff).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise "rectified linear 6" function with specified cutoff:
      - * out[i] = min(max(in, cutoff), 6)
      + * Element-wise "rectified linear 6" function with specified cutoff:
      out[i] = min(max(in, + * cutoff), 6)
      * - * @param x Input (NUMERIC type) + * @param x Input (NUMERIC type) * @param cutoff Cutoff value for ReLU operation. Usually 0 * @return output Output (NUMERIC type) */ public SDVariable relu6(SDVariable x, double cutoff) { SDValidation.validateNumerical("relu6", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.Relu6(sd,x, cutoff).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.Relu6(sd, x, cutoff).outputVariable(); } /** - * Element-wise "rectified linear 6" function with specified cutoff:
      - * out[i] = min(max(in, cutoff), 6)
      + * Element-wise "rectified linear 6" function with specified cutoff:
      out[i] = min(max(in, + * cutoff), 6)
      * - * @param name name May be null. Name for the output variable - * @param x Input (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) * @param cutoff Cutoff value for ReLU operation. Usually 0 * @return output Output (NUMERIC type) */ public SDVariable relu6(String name, SDVariable x, double cutoff) { SDValidation.validateNumerical("relu6", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Relu6(sd,x, cutoff).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Relu6(sd, x, cutoff).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
      - * Note that bias array is optional
      + * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
      Note that bias + * array is optional
      * - * @param input Input data (NUMERIC type) + * @param input Input data (NUMERIC type) * @param weights Weights variable (NUMERIC type) - * @param bias Optional bias variable (may be null) (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) { SDValidation.validateNumerical("reluLayer", "input", input); SDValidation.validateNumerical("reluLayer", "weights", weights); SDValidation.validateNumerical("reluLayer", "bias", bias); - return new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(sd,input, weights, bias).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(sd, input, weights, + bias).outputVariable(); } /** - * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
      - * Note that bias array is optional
      + * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
      Note that bias + * array is optional
      * - * @param name name May be null. Name for the output variable - * @param input Input data (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param input Input data (NUMERIC type) * @param weights Weights variable (NUMERIC type) - * @param bias Optional bias variable (may be null) (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) { SDValidation.validateNumerical("reluLayer", "input", input); SDValidation.validateNumerical("reluLayer", "weights", weights); SDValidation.validateNumerical("reluLayer", "bias", bias); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(sd,input, weights, bias).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(sd, input, weights, + bias).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
      + * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
      *
      - * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
      - * Uses default scale and alpha values.
      + * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
      Uses default scale + * and alpha values.
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable selu(SDVariable x) { SDValidation.validateNumerical("selu", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(sd, x).outputVariable(); } /** - * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
      + * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
      *
      - * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
      - * Uses default scale and alpha values.
      + * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
      Uses default scale + * and alpha values.
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable selu(String name, SDVariable x) { SDValidation.validateNumerical("selu", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -990,73 +1069,78 @@ public class SDNN extends SDOps { */ public SDVariable sigmoid(SDVariable x) { SDValidation.validateNumerical("sigmoid", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(sd, x).outputVariable(); } /** * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i]))
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sigmoid(String name, SDVariable x) { SDValidation.validateNumerical("sigmoid", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut
      * - * @param x Input Variable (NUMERIC type) + * @param x Input Variable (NUMERIC type) * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type) * @return output Output (gradient at input of sigmoid) (NUMERIC type) */ public SDVariable sigmoidDerivative(SDVariable x, SDVariable wrt) { SDValidation.validateNumerical("sigmoidDerivative", "x", x); SDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt); - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(sd,x, wrt).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(sd, x, + wrt).outputVariable(); } /** * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut
      * * @param name name May be null. Name for the output variable - * @param x Input Variable (NUMERIC type) - * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type) + * @param x Input Variable (NUMERIC type) + * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type) * @return output Output (gradient at input of sigmoid) (NUMERIC type) */ public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) { SDValidation.validateNumerical("sigmoidDerivative", "x", x); SDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(sd,x, wrt).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(sd, x, + wrt).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Softmax activation, along the specified dimension
      * - * @param x Input (NUMERIC type) + * @param x Input (NUMERIC type) * @param dimension Dimension along which to apply softmax * @return output Output variable (NUMERIC type) */ public SDVariable softmax(SDVariable x, int dimension) { SDValidation.validateNumerical("softmax", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, dimension).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd, x, + dimension).outputVariable(); } /** * Softmax activation, along the specified dimension
      * - * @param name name May be null. Name for the output variable - * @param x Input (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) * @param dimension Dimension along which to apply softmax * @return output Output variable (NUMERIC type) */ public SDVariable softmax(String name, SDVariable x, int dimension) { SDValidation.validateNumerical("softmax", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd, x, + dimension).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1068,49 +1152,52 @@ public class SDNN extends SDOps { */ public SDVariable softmax(SDVariable x) { SDValidation.validateNumerical("softmax", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, -1).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd, x, -1).outputVariable(); } /** * Softmax activation, along the specified dimension
      * * @param name name May be null. Name for the output variable - * @param x Input (NUMERIC type) + * @param x Input (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable softmax(String name, SDVariable x) { SDValidation.validateNumerical("softmax", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, -1).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd, x, + -1).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Softmax derivative function
      * - * @param x Softmax input (NUMERIC type) - * @param wrt Gradient at output, dL/dx (NUMERIC type) + * @param x Softmax input (NUMERIC type) + * @param wrt Gradient at output, dL/dx (NUMERIC type) * @param dimension Softmax dimension * @return output (NUMERIC type) */ public SDVariable softmaxDerivative(SDVariable x, SDVariable wrt, int dimension) { SDValidation.validateNumerical("softmaxDerivative", "x", x); SDValidation.validateNumerical("softmaxDerivative", "wrt", wrt); - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(sd,x, wrt, dimension).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(sd, x, wrt, + dimension).outputVariable(); } /** * Softmax derivative function
      * - * @param name name May be null. Name for the output variable - * @param x Softmax input (NUMERIC type) - * @param wrt Gradient at output, dL/dx (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Softmax input (NUMERIC type) + * @param wrt Gradient at output, dL/dx (NUMERIC type) * @param dimension Softmax dimension * @return output (NUMERIC type) */ public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt, int dimension) { SDValidation.validateNumerical("softmaxDerivative", "x", x); SDValidation.validateNumerical("softmaxDerivative", "wrt", wrt); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(sd,x, wrt, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(sd, x, wrt, + dimension).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1122,19 +1209,20 @@ public class SDNN extends SDOps { */ public SDVariable softplus(SDVariable x) { SDValidation.validateNumerical("softplus", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(sd, x).outputVariable(); } /** * Element-wise softplus function: out = log(exp(x) + 1)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable softplus(String name, SDVariable x) { SDValidation.validateNumerical("softplus", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1146,19 +1234,20 @@ public class SDNN extends SDOps { */ public SDVariable softsign(SDVariable x) { SDValidation.validateNumerical("softsign", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(sd, x).outputVariable(); } /** * Element-wise softsign function: out = x / (abs(x) + 1)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable softsign(String name, SDVariable x) { SDValidation.validateNumerical("softsign", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1170,45 +1259,48 @@ public class SDNN extends SDOps { */ public SDVariable softsignDerivative(SDVariable x) { SDValidation.validateNumerical("softsignDerivative", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(sd, + x).outputVariable(); } /** * Element-wise derivative (dOut/dIn) of the softsign function softsign(INDArray)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output (NUMERIC type) */ public SDVariable softsignDerivative(String name, SDVariable x) { SDValidation.validateNumerical("softsignDerivative", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
      - * See: https://arxiv.org/abs/1710.05941
      + * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
      See: https://arxiv.org/abs/1710.05941
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable swish(SDVariable x) { SDValidation.validateNumerical("swish", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(sd, x).outputVariable(); } /** - * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
      - * See: https://arxiv.org/abs/1710.05941
      + * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
      See: https://arxiv.org/abs/1710.05941
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable swish(String name, SDVariable x) { SDValidation.validateNumerical("swish", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1220,19 +1312,20 @@ public class SDNN extends SDOps { */ public SDVariable tanh(SDVariable x) { SDValidation.validateNumerical("tanh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd, x).outputVariable(); } /** * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable tanh(String name, SDVariable x) { SDValidation.validateNumerical("tanh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java index d57afe876..9c53bda9f 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java @@ -100,7 +100,7 @@ public class SDRandom extends SDOps { * P(x) = lambda * exp(-lambda * x)
      * * Inputs must satisfy the following constraints:
      - * Must be positive: lambda > 0
      + * Must be positive: lambda > 0
      * * @param lambda lambda parameter * @param datatype Data type of the output variable @@ -118,7 +118,7 @@ public class SDRandom extends SDOps { * P(x) = lambda * exp(-lambda * x)
      * * Inputs must satisfy the following constraints:
      - * Must be positive: lambda > 0
      + * Must be positive: lambda > 0
      * * @param name name May be null. Name for the output variable * @param lambda lambda parameter diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java index 9766a5b7c..a5fe6b751 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java @@ -829,9 +829,9 @@ public class Evaluation extends BaseEvaluation { * Precision based on guesses so far.
      * Note: value returned will differ depending on number of classes and settings.
      * 1. For binary classification, if the positive class is set (via default value of 1, via constructor, - * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class + * or via {@link Evaluation#setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
      - * 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged + * 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged precision, equivalent to {@code precision(EvaluationAveraging.Macro)}
      * * @return the total precision based on guesses so far @@ -977,7 +977,7 @@ public class Evaluation extends BaseEvaluation { * 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
      - * 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged + * 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged recall, equivalent to {@code recall(EvaluationAveraging.Macro)}
      * * @return the recall for the outcomes @@ -1173,12 +1173,12 @@ public class Evaluation extends BaseEvaluation { /** * False Alarm Rate (FAR) reflects rate of misclassified to classified records - * {@link }http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw}
      + * {@see http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw}
      * Note: value returned will differ depending on number of classes and settings.
      * 1. For binary classification, if the positive class is set (via default value of 1, via constructor, - * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class + * or via {@link Evaluation#setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
      - * 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged + * 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged false alarm rate) * * @return the fpr for the outcomes @@ -1243,9 +1243,9 @@ public class Evaluation extends BaseEvaluation { *
      * Note: value returned will differ depending on number of classes and settings.
      * 1. For binary classification, if the positive class is set (via default value of 1, via constructor, - * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class + * or via {@link Evaluation#setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
      - * 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged + * 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged f1, equivalent to {@code f1(EvaluationAveraging.Macro)}
      * * @return the f1 score or harmonic mean of precision and recall based on current guesses diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java index c3a66aa93..49a2f8e3d 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java @@ -584,7 +584,7 @@ public class EvaluationBinary extends BaseEvaluation { /** * False Alarm Rate (FAR) reflects rate of misclassified to classified records - * http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw
      + * http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw
      * * @param outputNum Class index to calculate False Alarm Rate (FAR) * @return The FAR for the outcomes @@ -611,7 +611,7 @@ public class EvaluationBinary extends BaseEvaluation { StringBuilder sb = new StringBuilder(); - //Report: Accuracy, precision, recall, F1. Then: confusion matrix + //Report: Accuracy, precision, recall, F1. Then: confusion matrix] int maxLabelsLength = 15; if (labels != null) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLapack.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLapack.java index 3fc641421..5503fdcf0 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLapack.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLapack.java @@ -202,7 +202,7 @@ public abstract class BaseLapack implements Lapack { * * @param jobz 'N' - no eigen vectors, 'V' - return eigenvectors * @param uplo upper or lower part of symmetric matrix to use - * @param N the number of rows & cols in the matrix A + * @param N the number of rows & cols in the matrix A * @param A the matrix to calculate eigenvectors * @param R an output array for eigenvalues ( may be null ) */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java index e21b2cd51..8789de3ff 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java @@ -74,14 +74,14 @@ public interface DataBufferFactory { DataBuffer createDouble(long offset, int length); /** - * This method will create new DataBuffer of the same dataType & same length + * This method will create new DataBuffer of the same dataType & same length * @param buffer * @return */ DataBuffer 
createSame(DataBuffer buffer, boolean init); /** - * This method will create new DataBuffer of the same dataType & same length + * This method will create new DataBuffer of the same dataType & same length * @param buffer * @return */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryManager.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryManager.java index bdf18ed35..32f5503ea 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryManager.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryManager.java @@ -132,7 +132,6 @@ public interface MemoryManager { * * @param pointer * @param kind - * @return */ void release(Pointer pointer, MemoryKind kind); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspaceManager.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspaceManager.java index bade5db59..b9a2a0622 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspaceManager.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspaceManager.java @@ -137,7 +137,7 @@ public interface MemoryWorkspaceManager { void destroyWorkspace(MemoryWorkspace workspace); /** - * This method destroys & deallocates all Workspaces for a calling Thread + * This method destroys & deallocates all Workspaces for a calling Thread * * PLEASE NOTE: This method is NOT safe */ @@ -149,21 +149,21 @@ public interface MemoryWorkspaceManager { void destroyWorkspace(); /** - * This method gets & activates default workspace + * This method gets and activates default workspace * * @return */ MemoryWorkspace getAndActivateWorkspace(); /** - * This method gets & activates workspace with a given Id + * This method gets and activates workspace with a given Id * * @return */ MemoryWorkspace getAndActivateWorkspace(String id); /** - * This method gets & 
activates default with a given configuration and Id + * This method gets and activates default with a given configuration and Id * * @return */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 9887ddadb..53ce1beba 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -94,5537 +94,5899 @@ import static org.nd4j.linalg.factory.Nd4j.*; @Slf4j public abstract class BaseNDArray implements INDArray, Iterable { - private static final long serialVersionUID = 3285982317165542614L; + private static final long serialVersionUID = 3285982317165542614L; - protected transient volatile DataBuffer shapeInformation; - protected transient volatile DataBuffer data; - //protected transient DataBuffer shape; - //protected transient DataBuffer stride; - protected transient boolean compressed = false; + protected transient volatile DataBuffer shapeInformation; + protected transient volatile DataBuffer data; + //protected transient DataBuffer shape; + //protected transient DataBuffer stride; + protected transient boolean compressed = false; - protected transient boolean released = false; + protected transient boolean released = false; - // this field holds jvm copy of shapeInfo - protected transient JvmShapeInfo jvmShapeInfo; + // this field holds jvm copy of shapeInfo + protected transient JvmShapeInfo jvmShapeInfo; - private static final AtomicLong arrayCounter = new AtomicLong(0); - protected transient final long arrayId = arrayCounter.getAndIncrement(); + private static final AtomicLong arrayCounter = new AtomicLong(0); + protected transient final long arrayId = arrayCounter.getAndIncrement(); - //Precalculate these arrays (like [3,2,1,0], [2,1,0], [1,0], [0] etc) for use in TAD, to avoid creating same int[]s over and 
over - private static final int[][] tadFinalPermuteDimensions; - static { - tadFinalPermuteDimensions = new int[32][0]; - tadFinalPermuteDimensions[1] = new int[] {1, 0}; //Edge case for 1d tensors: selectively apply to column vectors - for (int i = 2; i < 32; i++) { - tadFinalPermuteDimensions[i] = new int[i]; - for (int k = i - 1, j = 0; k >= 0; k--, j++) - tadFinalPermuteDimensions[i][j] = k; - } - val t =1; - } + //Precalculate these arrays (like [3,2,1,0], [2,1,0], [1,0], [0] etc) for use in TAD, to avoid creating same int[]s over and over + private static final int[][] tadFinalPermuteDimensions; - public BaseNDArray() { - } - - @Override - public boolean isCompressed() { - return compressed; - } - - @Override - public void markAsCompressed(boolean reallyCompressed) { - this.compressed = reallyCompressed; - } - - /** - * - * @param buffer - */ - public BaseNDArray(DataBuffer buffer) { - this.data = buffer; - if (buffer.length() >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE"); - long[] shape = {1, (int) buffer.length()}; - long[] stride = Nd4j.getStrides(shape); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 1, Nd4j.order(), buffer.dataType(), false)); - init(shape, stride); - } - - /** - * - * @param buffer - * @param shape - * @param stride - * @param offset - * @param ordering - */ - public BaseNDArray(DataBuffer buffer, int[] shape, int[] stride, long offset, char ordering) { - Shape.assertValidOrder(ordering); - this.data = offset > 0 ? 
Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) : buffer; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, buffer.dataType(), false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - - } - - public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, char ordering) { - this(buffer, shape, stride, offset, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering); - } - - public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, char ordering) { - Shape.assertValidOrder(ordering); - this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) : buffer; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, ews, ordering, buffer.dataType(), false )); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - } - - public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, char ordering, DataType dataType) { - this(buffer, shape, stride, offset, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, dataType); - } - - public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, char ordering, DataType dataType) { - this.data = offset > 0 ? 
Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) : buffer; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, ews, ordering, dataType, false)); - init(shape, stride); - } - - public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType type) { - this.data = buffer; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, type, false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - } - - public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType type, MemoryWorkspace workspace) { - this.data = buffer; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, type, false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - } - - public BaseNDArray(DataBuffer buffer, DataType dataType, long[] shape, long[] stride, long offset, char ordering) { - this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) : buffer; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, dataType, false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - } - - /** - * Initialize the ndarray as a matrix - * with the given data (indices preserved) - * @param data - */ - public BaseNDArray(double[][] data) { - this(data, Nd4j.order()); - } - - /** - * - * @param data - * @param ordering - */ - public BaseNDArray(double[][] data, char ordering) { - this(internalCreateBuffer(ordering == 'c' ? 
ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)), - new int[] {data.length, data[0].length}, - Nd4j.getStrides(new int[] {data.length, data[0].length}, ordering), 0, ordering); - - int c = columns(); - for (int r = 0; r < rows(); r++) { - Preconditions.checkState(data[r].length == c, "data[%s].length=%s must be equal to number of columns %s", r, data[r].length, c ); + static { + tadFinalPermuteDimensions = new int[32][0]; + tadFinalPermuteDimensions[1] = new int[]{1, + 0}; //Edge case for 1d tensors: selectively apply to column vectors + for (int i = 2; i < 32; i++) { + tadFinalPermuteDimensions[i] = new int[i]; + for (int k = i - 1, j = 0; k >= 0; k--, j++) { + tadFinalPermuteDimensions[i][j] = k; } } + val t = 1; + } + public BaseNDArray() { + } - /** - * Create with the specified shape and buffer - * - * @param shape the shape - * @param buffer the buffer - */ - public BaseNDArray(int[] shape, DataBuffer buffer) { - this.data = buffer; - init(shape, Nd4j.getStrides(shape)); - } - - /** - * Create this ndarray with the given data and shape and 0 offset - * - * @param data the data to use - * @param shape the shape of the ndarray - */ - public BaseNDArray(float[] data, int[] shape, char ordering) { - this(data, shape, 0, ordering); - } - - /** - * @param data the data to use - * @param shape the shape of the ndarray - * @param offset the desired offset - * @param ordering the ordering of the ndarray - */ - public BaseNDArray(float[] data, int[] shape, long offset, char ordering) { - this(data, shape, Nd4j.getStrides(shape, ordering), offset); - } - - public BaseNDArray(double[] data, long[] shape, long offset, char ordering) { - this(data, shape, Nd4j.getStrides(shape, ordering), offset); - } - - public BaseNDArray(float[] data, long[] shape, long offset, char ordering) { - this(data, shape, Nd4j.getStrides(shape, ordering), offset); + @Override + public boolean isCompressed() { + return compressed; + } + + @Override + public void markAsCompressed(boolean 
reallyCompressed) { + this.compressed = reallyCompressed; + } + + /** + * @param buffer + */ + public BaseNDArray(DataBuffer buffer) { + this.data = buffer; + if (buffer.length() >= Integer.MAX_VALUE) { + throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE"); + } + long[] shape = {1, (int) buffer.length()}; + long[] stride = Nd4j.getStrides(shape); + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(shape, stride, 1, Nd4j.order(), buffer.dataType(), false)); + init(shape, stride); + } + + /** + * @param buffer + * @param shape + * @param stride + * @param offset + * @param ordering + */ + public BaseNDArray(DataBuffer buffer, int[] shape, int[] stride, long offset, char ordering) { + Shape.assertValidOrder(ordering); + this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) + : buffer; + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, buffer.dataType(), + false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + + } + + public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, char ordering) { + this(buffer, shape, stride, offset, Shape.elementWiseStride(shape, stride, ordering == 'f'), + ordering); + } + + public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, + char ordering) { + Shape.assertValidOrder(ordering); + this.data = offset > 0 ? 
Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) + : buffer; + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(shape, stride, ews, ordering, buffer.dataType(), false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + } + + public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, char ordering, + DataType dataType) { + this(buffer, shape, stride, offset, Shape.elementWiseStride(shape, stride, ordering == 'f'), + ordering, dataType); + } + + public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, + char ordering, DataType dataType) { + this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) + : buffer; + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(shape, stride, ews, ordering, dataType, false)); + init(shape, stride); + } + + public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType type) { + this.data = buffer; + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, type, false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + } + + public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType type, + MemoryWorkspace workspace) { + this.data = buffer; + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, type, false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + } + + public BaseNDArray(DataBuffer buffer, DataType dataType, long[] shape, long[] stride, long offset, + char ordering) 
{ + this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) + : buffer; + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, dataType, false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + } + + /** + * Initialize the ndarray as a matrix with the given data (indices preserved) + * + * @param data + */ + public BaseNDArray(double[][] data) { + this(data, Nd4j.order()); + } + + /** + * @param data + * @param ordering + */ + public BaseNDArray(double[][] data, char ordering) { + this(internalCreateBuffer(ordering == 'c' ? ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)), + new int[]{data.length, data[0].length}, + Nd4j.getStrides(new int[]{data.length, data[0].length}, ordering), 0, ordering); + + int c = columns(); + for (int r = 0; r < rows(); r++) { + Preconditions.checkState(data[r].length == c, + "data[%s].length=%s must be equal to number of columns %s", r, data[r].length, c); } + } - /** - * Construct an ndarray of the specified shape - * with an empty data array - * - * @param shape the shape of the ndarray - * @param stride the stride of the ndarray - * @param offset the desired offset - * @param ordering the ordering of the ndarray - */ - public BaseNDArray(int[] shape, int[] stride, long offset, char ordering) { - this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, offset, ordering); - } + /** + * Create with the specified shape and buffer + * + * @param shape the shape + * @param buffer the buffer + */ + public BaseNDArray(int[] shape, DataBuffer buffer) { + this.data = buffer; + init(shape, Nd4j.getStrides(shape)); + } - public BaseNDArray(long[] shape, long[] stride, long offset, char ordering) { - this(Nd4j.createBuffer(shape.length == 0 ? 
1 : ArrayUtil.prodLong(shape)), shape, stride, offset, ordering); - } + /** + * Create this ndarray with the given data and shape and 0 offset + * + * @param data the data to use + * @param shape the shape of the ndarray + */ + public BaseNDArray(float[] data, int[] shape, char ordering) { + this(data, shape, 0, ordering); + } - /** - * Construct an ndarray of the specified shape. - * - * @param shape the shape of the ndarray - * @param stride the stride of the ndarray - * @param offset the desired offset - * @param ordering the ordering of the ndarray - * @param initialize Whether to initialize the INDArray. If true: initialize. If false: don't. - */ - public BaseNDArray(int[] shape, int[] stride, long offset, char ordering, boolean initialize) { - this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering); - } + /** + * @param data the data to use + * @param shape the shape of the ndarray + * @param offset the desired offset + * @param ordering the ordering of the ndarray + */ + public BaseNDArray(float[] data, int[] shape, long offset, char ordering) { + this(data, shape, Nd4j.getStrides(shape, ordering), offset); + } - public BaseNDArray(long[] shape, long[] stride, long offset, char ordering, boolean initialize) { - this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering); - } + public BaseNDArray(double[] data, long[] shape, long offset, char ordering) { + this(data, shape, Nd4j.getStrides(shape, ordering), offset); + } - public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, boolean initialize) { - this(Nd4j.createBuffer(type, shape.length == 0 ? 
1 : ArrayUtil.prodLong(shape), initialize), type, shape, stride, offset, ordering); - } + public BaseNDArray(float[] data, long[] shape, long offset, char ordering) { + this(data, shape, Nd4j.getStrides(shape, ordering), offset); + } - public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, boolean initialize, MemoryWorkspace workspace) { - this(Nd4j.createBuffer(type, shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize, workspace), type, shape, stride, offset, ordering); - } - public BaseNDArray(DataType type, long[] shape, long[] paddings, long[] paddingOffsets, char ordering, MemoryWorkspace workspace) { + /** + * Construct an ndarray of the specified shape with an empty data array + * + * @param shape the shape of the ndarray + * @param stride the stride of the ndarray + * @param offset the desired offset + * @param ordering the ordering of the ndarray + */ + public BaseNDArray(int[] shape, int[] stride, long offset, char ordering) { + this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, + offset, ordering); + } - //calculate strides with paddings - int rank = shape.length; - if(paddings == null || paddings.length != rank ) throw new IllegalArgumentException("The length of Padding should be equal to the length of Shape"); - long [] paddedShape = new long[rank]; - boolean empty = false; - boolean zeroOffset = paddingOffsets == null || paddingOffsets.length == 0; - boolean paddingOffsetsInvalid = paddingOffsets != null && paddingOffsets.length != rank ; - long ews = 1; - if(!paddingOffsetsInvalid){ - for(int i=0; ipaddings[i]){ - paddingOffsetsInvalid = true; - break; - } - } + public BaseNDArray(long[] shape, long[] stride, long offset, char ordering) { + this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, + offset, ordering); + } + + /** + * Construct an ndarray of the specified shape. 
+ * + * @param shape the shape of the ndarray + * @param stride the stride of the ndarray + * @param offset the desired offset + * @param ordering the ordering of the ndarray + * @param initialize Whether to initialize the INDArray. If true: initialize. If false: don't. + */ + public BaseNDArray(int[] shape, int[] stride, long offset, char ordering, boolean initialize) { + this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, + stride, offset, ordering); + } + + public BaseNDArray(long[] shape, long[] stride, long offset, char ordering, boolean initialize) { + this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, + stride, offset, ordering); + } + + public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, + boolean initialize) { + this(Nd4j.createBuffer(type, shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), + type, shape, stride, offset, ordering); + } + + public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, + boolean initialize, MemoryWorkspace workspace) { + this(Nd4j.createBuffer(type, shape.length == 0 ? 
1 : ArrayUtil.prodLong(shape), initialize, + workspace), type, shape, stride, offset, ordering); + } + + public BaseNDArray(DataType type, long[] shape, long[] paddings, long[] paddingOffsets, + char ordering, MemoryWorkspace workspace) { + + //calculate strides with paddings + int rank = shape.length; + if (paddings == null || paddings.length != rank) { + throw new IllegalArgumentException( + "The length of Padding should be equal to the length of Shape"); + } + long[] paddedShape = new long[rank]; + boolean empty = false; + boolean zeroOffset = paddingOffsets == null || paddingOffsets.length == 0; + boolean paddingOffsetsInvalid = paddingOffsets != null && paddingOffsets.length != rank; + long ews = 1; + if (!paddingOffsetsInvalid) { + for (int i = 0; i < rank; i++) { + paddedShape[i] = shape[i] + paddings[i]; + if (paddings[i] != 0) { + ews = 0; + } + if (shape[i] == 0) { + empty = true; + } + if (paddingOffsets[i] > paddings[i]) { + paddingOffsetsInvalid = true; + break; } - if(!zeroOffset && paddingOffsetsInvalid) throw new IllegalArgumentException("If PaddingOffsets is not empty or zero length then its length should match the length of Paddings and also its elements should not be greater"); - - long[] paddedStride = ordering == 'c' ? ArrayUtil.calcStrides(paddedShape,1): ArrayUtil.calcStridesFortran(paddedShape,1); - long paddedAllocSize = ordering == 'c' ? paddedShape[0] * paddedStride[0] : paddedShape[rank-1] * paddedStride[rank-1]; - - long offset = (empty || ews == 1 || zeroOffset) ? 0 : ArrayUtil.calcOffset(paddedShape, paddingOffsets, paddedStride); - DataBuffer buffer = Nd4j.createBuffer(type, paddedAllocSize, false, workspace); - this.data = offset > 0 ? 
Nd4j.createBuffer(buffer, offset, paddedAllocSize - offset) : buffer ; - long extras = ArrayOptionsHelper.setOptionBit(0, type); - if(empty) extras = ArrayOptionsHelper.setOptionBit(extras, ArrayOptionsHelper.ATYPE_EMPTY_BIT); - else if(ews!=1) extras = ArrayOptionsHelper.setOptionBit(extras, ArrayOptionsHelper.HAS_PADDED_BUFFER); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, paddedStride, ews, ordering, extras)); + } } + if (!zeroOffset && paddingOffsetsInvalid) { + throw new IllegalArgumentException( + "If PaddingOffsets is not empty or zero length then its length should match the length of Paddings and also its elements should not be greater"); + } - /** - * Create the ndarray with - * the specified shape and stride and an offset of 0 - * - * @param shape the shape of the ndarray - * @param stride the stride of the ndarray - * @param ordering the ordering of the ndarray - */ - public BaseNDArray(int[] shape, int[] stride, char ordering) { - this(shape, stride, 0, ordering); + long[] paddedStride = ordering == 'c' ? ArrayUtil.calcStrides(paddedShape, 1) + : ArrayUtil.calcStridesFortran(paddedShape, 1); + long paddedAllocSize = ordering == 'c' ? paddedShape[0] * paddedStride[0] + : paddedShape[rank - 1] * paddedStride[rank - 1]; + + long offset = (empty || ews == 1 || zeroOffset) ? 0 + : ArrayUtil.calcOffset(paddedShape, paddingOffsets, paddedStride); + DataBuffer buffer = Nd4j.createBuffer(type, paddedAllocSize, false, workspace); + this.data = offset > 0 ? 
Nd4j.createBuffer(buffer, offset, paddedAllocSize - offset) : buffer; + long extras = ArrayOptionsHelper.setOptionBit(0, type); + if (empty) { + extras = ArrayOptionsHelper.setOptionBit(extras, ArrayOptionsHelper.ATYPE_EMPTY_BIT); + } else if (ews != 1) { + extras = ArrayOptionsHelper.setOptionBit(extras, ArrayOptionsHelper.HAS_PADDED_BUFFER); + } + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(shape, paddedStride, ews, ordering, extras)); + } + + /** + * Create the ndarray with the specified shape and stride and an offset of 0 + * + * @param shape the shape of the ndarray + * @param stride the stride of the ndarray + * @param ordering the ordering of the ndarray + */ + public BaseNDArray(int[] shape, int[] stride, char ordering) { + this(shape, stride, 0, ordering); + } + + + /** + * @param shape + * @param offset + * @param ordering + */ + public BaseNDArray(int[] shape, long offset, char ordering) { + this(shape, Nd4j.getStrides(shape, ordering), offset, ordering); + } + + public BaseNDArray(long[] shape, long offset, char ordering) { + this(shape, Nd4j.getStrides(shape, ordering), offset, ordering); + } + + + /** + * Create an ndarray with the given shape + * + * @param shape + */ + public BaseNDArray(int[] shape) { + this(shape, 0, Nd4j.order()); + } + + public BaseNDArray(long[] shape) { + this(shape, 0, Nd4j.order()); + } + + + /** + * Creates a new n times m DoubleMatrix. + * + * @param newRows the number of rows (n) of the new matrix. + * @param newColumns the number of columns (m) of the new matrix. 
+ */ + public BaseNDArray(int newRows, int newColumns, char ordering) { + Shape.assertValidOrder(ordering); + this.data = Nd4j.createBuffer((long) newRows * newColumns); + val shape = new long[]{newRows, newColumns}; + val stride = Nd4j.getStrides(shape, ordering); + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, Nd4j.dataType(), false)); + init(shape, stride); + } + + public BaseNDArray(long newRows, long newColumns, char ordering) { + Shape.assertValidOrder(ordering); + this.data = Nd4j.createBuffer(newRows * newColumns); + long[] shape = new long[]{newRows, newColumns}; + long[] stride = Nd4j.getStrides(shape, ordering); + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, Nd4j.dataType(), false)); + init(shape, stride); + } + + + /** + * Create an ndarray from the specified slices. This will go through and merge all of the data + * from each slice in to one ndarray which will then take the specified shape + * + * @param slices the slices to merge + * @param shape the shape of the ndarray + */ + public BaseNDArray(List slices, int[] shape, char ordering) { + this(slices, shape, Nd4j.getStrides(shape, ordering), ordering); + } + + public BaseNDArray(List slices, long[] shape, char ordering) { + this(slices, shape, Nd4j.getStrides(shape, ordering), ordering); + } + + + /** + * Create an ndarray from the specified slices. This will go through and merge all of the data + * from each slice in to one ndarray which will then take the specified shape + * + * @param slices the slices to merge + * @param shape the shape of the ndarray + */ + public BaseNDArray(List slices, int[] shape, int[] stride, char ordering) { + Shape.assertValidOrder(ordering); + DataBuffer ret = slices.get(0).data().dataType() == (DataType.FLOAT) + ? 
Nd4j.createBuffer(new float[ArrayUtil.prod(shape)]) + : Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]); + this.data = ret; + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, + slices.get(0).dataType(), false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + + if (slices.get(0).isScalar()) { + for (int i = 0; i < length(); i++) { + putScalar(i, slices.get(i).getDouble(0)); + } + } else { + for (int i = 0; i < slices(); i++) { + putSlice(i, slices.get(i)); + } } + } - /** - * - * @param shape - * @param offset - * @param ordering - */ - public BaseNDArray(int[] shape, long offset, char ordering) { - this(shape, Nd4j.getStrides(shape, ordering), offset, ordering); - } - - public BaseNDArray(long[] shape, long offset, char ordering) { - this(shape, Nd4j.getStrides(shape, ordering), offset, ordering); - } - - - /** - * Create an ndarray - * with the given shape - * @param shape - */ - public BaseNDArray(int[] shape) { - this(shape, 0, Nd4j.order()); - } - - public BaseNDArray(long[] shape) { - this(shape, 0, Nd4j.order()); - } - - - /** - * Creates a new n times m DoubleMatrix. - * - * @param newRows the number of rows (n) of the new matrix. - * @param newColumns the number of columns (m) of the new matrix. 
- */ - public BaseNDArray(int newRows, int newColumns, char ordering) { - Shape.assertValidOrder(ordering); - this.data = Nd4j.createBuffer((long) newRows * newColumns); - val shape = new long[] {newRows, newColumns}; - val stride = Nd4j.getStrides(shape, ordering); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, Nd4j.dataType(), false)); - init(shape, stride); - } - - public BaseNDArray(long newRows, long newColumns, char ordering) { - Shape.assertValidOrder(ordering); - this.data = Nd4j.createBuffer(newRows * newColumns); - long[] shape = new long[] {newRows, newColumns}; - long[] stride = Nd4j.getStrides(shape, ordering); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, Nd4j.dataType(), false)); - init(shape, stride); - } - - - /** - * Create an ndarray from the specified slices. - * This will go through and merge all of the - * data from each slice in to one ndarray - * which will then take the specified shape - * - * @param slices the slices to merge - * @param shape the shape of the ndarray - */ - public BaseNDArray(List slices, int[] shape, char ordering) { - this(slices, shape, Nd4j.getStrides(shape, ordering), ordering); - } - - public BaseNDArray(List slices, long[] shape, char ordering) { - this(slices, shape, Nd4j.getStrides(shape, ordering), ordering); - } - - - /** - * Create an ndarray from the specified slices. - * This will go through and merge all of the - * data from each slice in to one ndarray - * which will then take the specified shape - * - * @param slices the slices to merge - * @param shape the shape of the ndarray - */ - public BaseNDArray(List slices, int[] shape, int[] stride, char ordering) { - Shape.assertValidOrder(ordering); - DataBuffer ret = slices.get(0).data().dataType() == (DataType.FLOAT) - ? 
Nd4j.createBuffer(new float[ArrayUtil.prod(shape)]) - : Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]); - this.data = ret; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, slices.get(0).dataType(), false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - - if (slices.get(0).isScalar()) { - for (int i = 0; i < length(); i++) { - putScalar(i, slices.get(i).getDouble(0)); - } - } else { - for (int i = 0; i < slices(); i++) { - putSlice(i, slices.get(i)); - } - } - } - - - public BaseNDArray(List slices, long[] shape, long[] stride, char ordering) { - DataBuffer ret = Nd4j.createBuffer(slices.get(0).dataType(), Shape.lengthOf(shape), false); /*slices.get(0).data().dataType() == (DataType.FLOAT) + public BaseNDArray(List slices, long[] shape, long[] stride, char ordering) { + DataBuffer ret = Nd4j.createBuffer(slices.get(0).dataType(), Shape.lengthOf(shape), false); /*slices.get(0).data().dataType() == (DataType.FLOAT) ? 
Nd4j.createBuffer(new float[ArrayUtil.prod(shape)]) : Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]); */ - this.data = ret; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, slices.get(0).dataType(), false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + this.data = ret; + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, slices.get(0).dataType(), + false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - if (slices.get(0).isScalar()) { - for (int i = 0; i < length(); i++) { - putScalar(i, slices.get(i).getDouble(0)); - } - } else { - for (int i = 0; i < slices(); i++) { - putSlice(i, slices.get(i)); - } + if (slices.get(0).isScalar()) { + for (int i = 0; i < length(); i++) { + putScalar(i, slices.get(i).getDouble(0)); + } + } else { + for (int i = 0; i < slices(); i++) { + putSlice(i, slices.get(i)); + } + } + } + + /** + * @param data + * @param shape + * @param stride + * @param ordering + */ + public BaseNDArray(float[] data, int[] shape, int[] stride, char ordering) { + this(data, shape, stride, 0, ordering); + } + + /** + * @param data + * @param shape + * @param stride + * @param offset + * @param ordering + */ + public BaseNDArray(float[] data, int[] shape, int[] stride, long offset, char ordering) { + Shape.assertValidOrder(ordering); + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, + data == null || data.length <= 0)); + if (data != null && data.length > 0) { + + val perfD = 
PerformanceTracker.getInstance().helperStartTransaction(); + + this.data = internalCreateBuffer(data, offset); + + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, + (long) data.length * Nd4j.sizeOfDataType(DataType.FLOAT), MemcpyDirection.HOST_TO_HOST); + + if (offset >= data.length) { + throw new IllegalArgumentException("invalid offset: must be < data.length"); } } - /** - * - * @param data - * @param shape - * @param stride - * @param ordering - */ - public BaseNDArray(float[] data, int[] shape, int[] stride, char ordering) { - this(data, shape, stride, 0, ordering); - } - - /** - * - * @param data - * @param shape - * @param stride - * @param offset - * @param ordering - */ - public BaseNDArray(float[] data, int[] shape, int[] stride, long offset, char ordering) { - Shape.assertValidOrder(ordering); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, data == null || data.length <= 0)); - if (data != null && data.length > 0) { - - val perfD = PerformanceTracker.getInstance().helperStartTransaction(); - - this.data = internalCreateBuffer(data, offset); - - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, (long) data.length * Nd4j.sizeOfDataType(DataType.FLOAT), MemcpyDirection.HOST_TO_HOST); - - if (offset >= data.length) - throw new IllegalArgumentException("invalid offset: must be < data.length"); - } - - init(shape, stride); - } - - public BaseNDArray(float[] data, long[] shape, long[] stride, long offset, char ordering) { - Shape.assertValidOrder(ordering); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, data == null || data.length <= 0)); - if (data != null && data.length > 0) { - this.data = Nd4j.createTypedBuffer(data, 
DataType.FLOAT); - if (offset >= data.length) - throw new IllegalArgumentException("invalid offset: must be < data.length"); - } - - init(shape, stride); - } - - public BaseNDArray(double[] data, long[] shape, long[] stride, long offset, char ordering) { - Shape.assertValidOrder(ordering); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.DOUBLE, data == null || data.length <= 0)); - if (data != null && data.length > 0) { - this.data = Nd4j.createBuffer(data, offset); - if (offset >= data.length) - throw new IllegalArgumentException("invalid offset: must be < data.length"); - } - - init(shape, stride); - } - - /** - * - * @param data - * @param shape - * @param stride - * @param offset - */ - public BaseNDArray(DataBuffer data, int[] shape, int[] stride, long offset) { - this.data = Nd4j.createBuffer(data, offset, ArrayUtil.prodLong(shape)); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), - Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f'), Nd4j.order(), data.dataType(), false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f')); - - - } - - /** - * - * @param data - * @param shape - * @param strides - */ - public BaseNDArray(int[] data, int[] shape, int[] strides) { - this(internalCreateBuffer(data), shape, strides); - } - - /** - * - * @param data - * @param shape - */ - public BaseNDArray(DataBuffer data, int[] shape) { - this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order()); - } - - public BaseNDArray(DataBuffer data, long[] shape) { - this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order()); - } - - - /** - * - * @param buffer - * @param shape - * @param offset - */ - public BaseNDArray(DataBuffer buffer, int[] shape, long offset) { - 
this(Nd4j.createBuffer(buffer, offset, ArrayUtil.prodLong(shape)), shape, Nd4j.getStrides(shape), offset, - Nd4j.order()); - } - - /** - * - * @param buffer - * @param shape - * @param ordering - */ - public BaseNDArray(DataBuffer buffer, int[] shape, char ordering) { - this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering); - } - - public BaseNDArray(DataBuffer buffer, long[] shape, char ordering) { - this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering); - } - - /** - * - * @param data - * @param shape - * @param ordering - */ - public BaseNDArray(double[] data, int[] shape, char ordering) { - this(Nd4j.createBuffer(data), shape, ordering); - } - - public BaseNDArray(double[] data, long[] shape, char ordering) { - this(Nd4j.createBuffer(data), shape, ordering); - } - - public BaseNDArray(float[] data, long[] shape, char ordering) { - this(Nd4j.createBuffer(data), shape, ordering); - } - - /** - * - * @param data - * @param shape - * @param stride - * @param offset - * @param ordering - */ - public BaseNDArray(double[] data, int[] shape, int[] stride, long offset, char ordering) { - this(internalCreateBuffer(data, offset), shape, stride, offset, ordering); - } - - /** - * - * @param data - * @param order - */ - public BaseNDArray(float[] data, char order) { - this(internalCreateBuffer(data), order); - } - - protected static DataBuffer internalCreateBuffer(float[] data) { - val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - - val buffer = Nd4j.createBuffer(data); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); - - return buffer; - } - - protected static DataBuffer internalCreateBuffer(double[] data) { - val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - - val buffer = Nd4j.createBuffer(data); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * 
Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); - - return buffer; - } - - protected static DataBuffer internalCreateBuffer(int[] data) { - val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - - val buffer = Nd4j.createBuffer(data); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); - - return buffer; - } - - protected static DataBuffer internalCreateBuffer(float[] data, long offset) { - val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - - val buffer = Nd4j.createBuffer(data, offset); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); - - return buffer; - } - - protected static DataBuffer internalCreateBuffer(double[] data, long offset) { - val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - - val buffer = Nd4j.createBuffer(data, offset); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); - - return buffer; - } - - /** - * - * @param floatBuffer - * @param order - */ - public BaseNDArray(DataBuffer floatBuffer, char order) { - this(floatBuffer, new int[] {(int) floatBuffer.length()}, - Nd4j.getStrides(new int[] {(int) floatBuffer.length()}, order), 0, order); - Shape.assertValidOrder(order); - if (floatBuffer.length() >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE"); - } - - /** - * - * @param buffer - * @param shape - * @param strides - */ - public BaseNDArray(DataBuffer buffer, int[] shape, int[] strides) { - this(buffer, shape, strides, 0, Nd4j.order()); - } - - - /** - * Create this ndarray with the given data and shape and 0 offset - * - * @param data the data to use - * @param shape 
the shape of the ndarray - */ - public BaseNDArray(float[] data, int[] shape) { - this(data, shape, 0); - } - - - /** - * - * @param data - * @param shape - * @param offset - */ - public BaseNDArray(float[] data, int[] shape, long offset) { - this(data, shape, offset, Nd4j.order()); - - } - - /** - * Construct an ndarray of the specified shape - * with an empty data array - * - * @param shape the shape of the ndarray - * @param stride the stride of the ndarray - * @param offset the desired offset - */ - public BaseNDArray(int[] shape, int[] stride, long offset) { - this(new float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order()); - } - - public BaseNDArray(long[] shape, long[] stride, long offset) { - this(new float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order()); - } - - /** - * Create the ndarray with - * the specified shape and stride and an offset of 0 - * - * @param shape the shape of the ndarray - * @param stride the stride of the ndarray - */ - public BaseNDArray(int[] shape, int[] stride) { - this(shape, stride, 0); - } - - /** - * - * @param shape - * @param offset - */ - public BaseNDArray(int[] shape, long offset) { - this(shape, Nd4j.getStrides(shape), offset); - } - - /** - * - * @param shape - * @param ordering - */ - public BaseNDArray(int[] shape, char ordering) { - this(shape, 0, ordering); - } - - - /** - * Creates a new n times m DoubleMatrix. - * - * @param newRows the number of rows (n) of the new matrix. - * @param newColumns the number of columns (m) of the new matrix. - */ - public BaseNDArray(int newRows, int newColumns) { - this(newRows, newColumns, Nd4j.order()); - } - - public BaseNDArray(long newRows, long newColumns) { - this(newRows, newColumns, Nd4j.order()); - } - - - /** - * Create an ndarray from the specified slices. 
- * This will go through and merge all of the - * data from each slice in to one ndarray - * which will then take the specified shape - * - * @param slices the slices to merge - * @param shape the shape of the ndarray - */ - public BaseNDArray(List slices, int[] shape) { - this(slices, shape, Nd4j.order()); - } - - public BaseNDArray(List slices, long[] shape) { - this(slices, shape, Nd4j.order()); - } - - /** - * Create an ndarray from the specified slices. - * This will go through and merge all of the - * data from each slice in to one ndarray - * which will then take the specified shape - * - * @param slices the slices to merge - * @param shape the shape of the ndarray - */ - public BaseNDArray(List slices, int[] shape, int[] stride) { - this(slices, shape, stride, Nd4j.order()); - } - - public BaseNDArray(List slices, long[] shape, long[] stride) { - this(slices, shape, stride, Nd4j.order()); - } - - /** - * - * @param data - * @param shape - * @param stride - */ - public BaseNDArray(float[] data, int[] shape, int[] stride) { - this(data, shape, stride, Nd4j.order()); - } - - - /** - * - * @param data - * @param shape - * @param stride - * @param offset - */ - public BaseNDArray(float[] data, int[] shape, int[] stride, long offset) { - this(data, shape, stride, offset, Nd4j.order()); - } - - public BaseNDArray(double[] data, long[] shape, long[] stride, long offset) { - this(data, shape, stride, offset, Nd4j.order()); - } - - public BaseNDArray(float[] data, long[] shape, long[] stride, long offset) { - this(data, shape, stride, offset, Nd4j.order()); - } - - /** - * - * @param data - */ - public BaseNDArray(float[] data) { - this(Nd4j.createBuffer(data)); - } - - - /** - * Initialize the ndarray - * with the given data - * @param data - */ - public BaseNDArray(float[][] data) { - this(data, Nd4j.order()); - } - - /** - * - * @param data - * @param ordering - */ - public BaseNDArray(float[][] data, char ordering) { - this(internalCreateBuffer(ordering == 'c' ? 
ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)), - new int[] {data.length, data[0].length}, - Nd4j.getStrides(new int[] {data.length, data[0].length}, ordering), 0, ordering); - - int c = columns(); - for (int r = 0; r < rows(); r++) { - Preconditions.checkState(data[r].length == c, "data[%s].length=%s must be equal to number of columns %s", r, data[r].length, c ); + init(shape, stride); + } + + public BaseNDArray(float[] data, long[] shape, long[] stride, long offset, char ordering) { + Shape.assertValidOrder(ordering); + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, + data == null || data.length <= 0)); + if (data != null && data.length > 0) { + this.data = Nd4j.createTypedBuffer(data, DataType.FLOAT); + if (offset >= data.length) { + throw new IllegalArgumentException("invalid offset: must be < data.length"); } } + init(shape, stride); + } - - /** - * Constructor for stride and offset - * - * @param buffer - * @param shape - * @param offset - * @param ordering - */ - public BaseNDArray(DataBuffer buffer, int[] shape, long offset, char ordering) { - this(buffer, shape, Nd4j.getStrides(shape, ordering), offset, ordering); + public BaseNDArray(double[] data, long[] shape, long[] stride, long offset, char ordering) { + Shape.assertValidOrder(ordering); + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.DOUBLE, + data == null || data.length <= 0)); + if (data != null && data.length > 0) { + this.data = Nd4j.createBuffer(data, offset); + if (offset >= data.length) { + throw new IllegalArgumentException("invalid offset: must be < data.length"); + } } - public BaseNDArray(double[] data, int[] shape, int[] stride, long offset) { - this(data, shape, stride, offset, Nd4j.order()); + init(shape, stride); + } + + /** + * @param data + 
* @param shape + * @param stride + * @param offset + */ + public BaseNDArray(DataBuffer data, int[] shape, int[] stride, long offset) { + this.data = Nd4j.createBuffer(data, offset, ArrayUtil.prodLong(shape)); + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), + Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f'), Nd4j.order(), + data.dataType(), false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f')); + + } + + /** + * @param data + * @param shape + * @param strides + */ + public BaseNDArray(int[] data, int[] shape, int[] strides) { + this(internalCreateBuffer(data), shape, strides); + } + + /** + * @param data + * @param shape + */ + public BaseNDArray(DataBuffer data, int[] shape) { + this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order()); + } + + public BaseNDArray(DataBuffer data, long[] shape) { + this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order()); + } + + + /** + * @param buffer + * @param shape + * @param offset + */ + public BaseNDArray(DataBuffer buffer, int[] shape, long offset) { + this(Nd4j.createBuffer(buffer, offset, ArrayUtil.prodLong(shape)), shape, + Nd4j.getStrides(shape), offset, + Nd4j.order()); + } + + /** + * @param buffer + * @param shape + * @param ordering + */ + public BaseNDArray(DataBuffer buffer, int[] shape, char ordering) { + this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering); + } + + public BaseNDArray(DataBuffer buffer, long[] shape, char ordering) { + this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering); + } + + /** + * @param data + * @param shape + * @param ordering + */ + public BaseNDArray(double[] data, int[] shape, char ordering) { + this(Nd4j.createBuffer(data), shape, ordering); + } + + public BaseNDArray(double[] data, long[] shape, char ordering) { + 
this(Nd4j.createBuffer(data), shape, ordering); + } + + public BaseNDArray(float[] data, long[] shape, char ordering) { + this(Nd4j.createBuffer(data), shape, ordering); + } + + /** + * @param data + * @param shape + * @param stride + * @param offset + * @param ordering + */ + public BaseNDArray(double[] data, int[] shape, int[] stride, long offset, char ordering) { + this(internalCreateBuffer(data, offset), shape, stride, offset, ordering); + } + + /** + * @param data + * @param order + */ + public BaseNDArray(float[] data, char order) { + this(internalCreateBuffer(data), order); + } + + protected static DataBuffer internalCreateBuffer(float[] data) { + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + val buffer = Nd4j.createBuffer(data); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, + (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + + return buffer; + } + + protected static DataBuffer internalCreateBuffer(double[] data) { + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + val buffer = Nd4j.createBuffer(data); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, + (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + + return buffer; + } + + protected static DataBuffer internalCreateBuffer(int[] data) { + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + val buffer = Nd4j.createBuffer(data); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, + (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + + return buffer; + } + + protected static DataBuffer internalCreateBuffer(float[] data, long offset) { + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + val buffer = Nd4j.createBuffer(data, offset); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, + (long) 
data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + + return buffer; + } + + protected static DataBuffer internalCreateBuffer(double[] data, long offset) { + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + val buffer = Nd4j.createBuffer(data, offset); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, + (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + + return buffer; + } + + /** + * @param floatBuffer + * @param order + */ + public BaseNDArray(DataBuffer floatBuffer, char order) { + this(floatBuffer, new int[]{(int) floatBuffer.length()}, + Nd4j.getStrides(new int[]{(int) floatBuffer.length()}, order), 0, order); + Shape.assertValidOrder(order); + if (floatBuffer.length() >= Integer.MAX_VALUE) { + throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE"); + } + } + + /** + * @param buffer + * @param shape + * @param strides + */ + public BaseNDArray(DataBuffer buffer, int[] shape, int[] strides) { + this(buffer, shape, strides, 0, Nd4j.order()); + } + + + /** + * Create this ndarray with the given data and shape and 0 offset + * + * @param data the data to use + * @param shape the shape of the ndarray + */ + public BaseNDArray(float[] data, int[] shape) { + this(data, shape, 0); + } + + + /** + * @param data + * @param shape + * @param offset + */ + public BaseNDArray(float[] data, int[] shape, long offset) { + this(data, shape, offset, Nd4j.order()); + + } + + /** + * Construct an ndarray of the specified shape with an empty data array + * + * @param shape the shape of the ndarray + * @param stride the stride of the ndarray + * @param offset the desired offset + */ + public BaseNDArray(int[] shape, int[] stride, long offset) { + this(new float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order()); + } + + public BaseNDArray(long[] shape, long[] stride, long offset) { + this(new 
float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order()); + } + + /** + * Create the ndarray with the specified shape and stride and an offset of 0 + * + * @param shape the shape of the ndarray + * @param stride the stride of the ndarray + */ + public BaseNDArray(int[] shape, int[] stride) { + this(shape, stride, 0); + } + + /** + * @param shape + * @param offset + */ + public BaseNDArray(int[] shape, long offset) { + this(shape, Nd4j.getStrides(shape), offset); + } + + /** + * @param shape + * @param ordering + */ + public BaseNDArray(int[] shape, char ordering) { + this(shape, 0, ordering); + } + + + /** + * Creates a new n times m DoubleMatrix. + * + * @param newRows the number of rows (n) of the new matrix. + * @param newColumns the number of columns (m) of the new matrix. + */ + public BaseNDArray(int newRows, int newColumns) { + this(newRows, newColumns, Nd4j.order()); + } + + public BaseNDArray(long newRows, long newColumns) { + this(newRows, newColumns, Nd4j.order()); + } + + + /** + * Create an ndarray from the specified slices. This will go through and merge all of the data + * from each slice in to one ndarray which will then take the specified shape + * + * @param slices the slices to merge + * @param shape the shape of the ndarray + */ + public BaseNDArray(List slices, int[] shape) { + this(slices, shape, Nd4j.order()); + } + + public BaseNDArray(List slices, long[] shape) { + this(slices, shape, Nd4j.order()); + } + + /** + * Create an ndarray from the specified slices. 
This will go through and merge all of the data + * from each slice in to one ndarray which will then take the specified shape + * + * @param slices the slices to merge + * @param shape the shape of the ndarray + */ + public BaseNDArray(List slices, int[] shape, int[] stride) { + this(slices, shape, stride, Nd4j.order()); + } + + public BaseNDArray(List slices, long[] shape, long[] stride) { + this(slices, shape, stride, Nd4j.order()); + } + + /** + * @param data + * @param shape + * @param stride + */ + public BaseNDArray(float[] data, int[] shape, int[] stride) { + this(data, shape, stride, Nd4j.order()); + } + + + /** + * @param data + * @param shape + * @param stride + * @param offset + */ + public BaseNDArray(float[] data, int[] shape, int[] stride, long offset) { + this(data, shape, stride, offset, Nd4j.order()); + } + + public BaseNDArray(double[] data, long[] shape, long[] stride, long offset) { + this(data, shape, stride, offset, Nd4j.order()); + } + + public BaseNDArray(float[] data, long[] shape, long[] stride, long offset) { + this(data, shape, stride, offset, Nd4j.order()); + } + + /** + * @param data + */ + public BaseNDArray(float[] data) { + this(Nd4j.createBuffer(data)); + } + + + /** + * Initialize the ndarray with the given data + * + * @param data + */ + public BaseNDArray(float[][] data) { + this(data, Nd4j.order()); + } + + /** + * @param data + * @param ordering + */ + public BaseNDArray(float[][] data, char ordering) { + this(internalCreateBuffer(ordering == 'c' ? 
ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)), + new int[]{data.length, data[0].length}, + Nd4j.getStrides(new int[]{data.length, data[0].length}, ordering), 0, ordering); + + int c = columns(); + for (int r = 0; r < rows(); r++) { + Preconditions.checkState(data[r].length == c, + "data[%s].length=%s must be equal to number of columns %s", r, data[r].length, c); + } + } + + + /** + * Constructor for stride and offset + * + * @param buffer + * @param shape + * @param offset + * @param ordering + */ + public BaseNDArray(DataBuffer buffer, int[] shape, long offset, char ordering) { + this(buffer, shape, Nd4j.getStrides(shape, ordering), offset, ordering); + } + + public BaseNDArray(double[] data, int[] shape, int[] stride, long offset) { + this(data, shape, stride, offset, Nd4j.order()); + } + + + /** + * Returns whether the ndarray is valid or not + * + * @return true if the ndarray is valid false otherwise + */ + @Deprecated + public boolean isValid() { + try { + linearIndex(length() - 1); + } catch (Exception e) { + return false; + } + return true; + } + + protected INDArray create(DataBuffer data, int[] shape, long offset) { + return Nd4j.create(data, shape, offset); + } + + @Override + public int elementWiseStride() { + return Shape.elementWiseStride(shapeInfoDataBuffer()); + } + + @Override + public long tensorsAlongDimension(int... 
dimension) { + if (dimension == null || dimension.length == 0) { + throw new IllegalArgumentException( + "Invalid input: dimensions not specified (null or length 0)"); + } + if (dimension.length >= rank() + || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) { + return 1; + } + for (int i = 0; i < dimension.length; i++) { + if (dimension[i] < 0) { + dimension[i] += rank(); + } + } + long[] tensorShape = ArrayUtil.keep(shape(), dimension); + long len = ArrayUtil.prodLong(tensorShape); + if (len == 0) { + throw new IllegalStateException("Illegal length found after removing index"); + } + long length = length(); + if (length / len >= Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "Tensors along dimension can not be >= Integer.MAX_VALUE"); + } + return length / len; + } + + @Override + public INDArray tensorAlongDimension(long index, int... dimension) { + if (dimension == null || dimension.length == 0) { + throw new IllegalArgumentException( + "Invalid input: dimensions not specified (null or length 0)"); + } + + Preconditions.checkArgument(!this.isEmpty(), + "tensorAlongDimension(...) 
can't be used on empty tensors"); + + if (dimension.length >= rank() + || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) { + return this; + } + for (int i = 0; i < dimension.length; i++) { + if (dimension[i] < 0) { + dimension[i] += rank(); + } + } + + //dedup + if (dimension.length > 1) { + dimension = Ints.toArray(new ArrayList<>(new TreeSet<>(Ints.asList(dimension)))); + } + + if (dimension.length > 1) { + Arrays.sort(dimension); } + long tads = tensorsAlongDimension(dimension); + if (index >= tads) { + throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads); + } - /** - * Returns whether the ndarray is valid or not - * @return true if the ndarray is valid - * false otherwise - */ - @Deprecated - public boolean isValid() { - try { - linearIndex(length() - 1); - } catch (Exception e) { - return false; - } - return true; + if (dimension.length == 1) { + if (dimension[0] == 0 && isColumnVector()) { + return this.transpose(); + } else if (dimension[0] == 1 && isRowVector()) { + return this; + } } - protected INDArray create(DataBuffer data, int[] shape, long offset) { - return Nd4j.create(data, shape, offset); + Pair tadInfo = Nd4j.getExecutioner().getTADManager() + .getTADOnlyShapeInfo(this, dimension); + DataBuffer shapeInfo = tadInfo.getFirst(); + val jShapeInfo = shapeInfo.asLong(); + val shape = Shape.shape(jShapeInfo); + val stride = Shape.stride(jShapeInfo); + long offset = offset() + tadInfo.getSecond().getLong(index); + val ews = shapeInfo.getLong(jShapeInfo[0] * 2 + 2); + char tadOrder = (char) shapeInfo.getInt(jShapeInfo[0] * 2 + 3); + val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder); + return toTad; + } + + private void setShapeInformation(Pair shapeInfo) { + this.shapeInformation = shapeInfo.getFirst(); + this.jvmShapeInfo = new JvmShapeInfo(shapeInfo.getSecond()); + } + + + private INDArray doTad(int index, int... 
dimension) { + if (dimension == null || dimension.length == 0) { + throw new IllegalArgumentException( + "Invalid input: dimensions not specified (null or length 0)"); + } + + if (dimension.length >= rank()) { + return this; + } + for (int i = 0; i < dimension.length; i++) { + if (dimension[i] < 0) { + dimension[i] += rank(); + } + } + + if (dimension.length > 1) { + Arrays.sort(dimension); + } + + long tads = tensorsAlongDimension(dimension); + if (index >= tads) { + throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads); + } + + if (dimension.length == 1) { + if (dimension[0] == 0 && isColumnVector()) { + return this.transpose(); + } else if (dimension[0] == 1 && isRowVector()) { + return this; + } } - @Override - public int elementWiseStride() { - return Shape.elementWiseStride(shapeInfoDataBuffer()); - } + long[] tensorShape = ArrayUtil.keep(shape(), dimension); + int[] reverseDimensions = ArrayUtil.reverseCopy(dimension); + int[] remove = ArrayUtil.removeIndex(ArrayUtil.range(0, rank()), dimension); + int[] newPermuteDims = Ints.concat(remove, reverseDimensions); + int[] finalPermuteDims = tadFinalPermuteDimensions[dimension.length]; - @Override - public long tensorsAlongDimension(int... 
dimension) { - if (dimension == null || dimension.length == 0) - throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)"); - if (dimension.length >= rank() || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) - return 1; - for (int i = 0; i < dimension.length; i++) - if (dimension[i] < 0) - dimension[i] += rank(); - long[] tensorShape = ArrayUtil.keep(shape(), dimension); - long len = ArrayUtil.prodLong(tensorShape); - if (len == 0) - throw new IllegalStateException("Illegal length found after removing index"); - long length = length(); - if (length / len >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Tensors along dimension can not be >= Integer.MAX_VALUE"); - return length / len; - } + INDArray permuted = permute(newPermuteDims); + long sliceIdx = NDArrayMath.sliceOffsetForTensor(index, permuted, tensorShape); - @Override - public INDArray tensorAlongDimension(long index, int... dimension) { - if (dimension == null || dimension.length == 0) - throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)"); - - Preconditions.checkArgument(!this.isEmpty(), "tensorAlongDimension(...) 
can't be used on empty tensors"); - - if (dimension.length >= rank() || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) - return this; - for (int i = 0; i < dimension.length; i++) - if (dimension[i] < 0) - dimension[i] += rank(); - - //dedup - if (dimension.length > 1) - dimension = Ints.toArray(new ArrayList<>(new TreeSet<>(Ints.asList(dimension)))); - - if (dimension.length > 1) { - Arrays.sort(dimension); - } - - long tads = tensorsAlongDimension(dimension); - if (index >= tads) - throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads); - - - if (dimension.length == 1) { - if (dimension[0] == 0 && isColumnVector()) { - return this.transpose(); - } else if (dimension[0] == 1 && isRowVector()) { - return this; - } - } - - - Pair tadInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension); - DataBuffer shapeInfo = tadInfo.getFirst(); - val jShapeInfo = shapeInfo.asLong(); - val shape = Shape.shape(jShapeInfo); - val stride = Shape.stride(jShapeInfo); - long offset = offset() + tadInfo.getSecond().getLong(index); - val ews = shapeInfo.getLong(jShapeInfo[0] * 2 + 2); - char tadOrder = (char) shapeInfo.getInt(jShapeInfo[0] * 2 + 3); - val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder); - return toTad; - } - - private void setShapeInformation(Pair shapeInfo) { - this.shapeInformation = shapeInfo.getFirst(); - this.jvmShapeInfo = new JvmShapeInfo(shapeInfo.getSecond()); - } - - - private INDArray doTad(int index, int... 
dimension) { - if (dimension == null || dimension.length == 0) - throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)"); - - if (dimension.length >= rank()) - return this; - for (int i = 0; i < dimension.length; i++) - if (dimension[i] < 0) - dimension[i] += rank(); - - if (dimension.length > 1) - Arrays.sort(dimension); - - long tads = tensorsAlongDimension(dimension); - if (index >= tads) - throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads); - - - if (dimension.length == 1) { - if (dimension[0] == 0 && isColumnVector()) { - return this.transpose(); - } else if (dimension[0] == 1 && isRowVector()) { - return this; - } - } - - - long[] tensorShape = ArrayUtil.keep(shape(), dimension); - int[] reverseDimensions = ArrayUtil.reverseCopy(dimension); - int[] remove = ArrayUtil.removeIndex(ArrayUtil.range(0, rank()), dimension); - int[] newPermuteDims = Ints.concat(remove, reverseDimensions); - int[] finalPermuteDims = tadFinalPermuteDimensions[dimension.length]; - - INDArray permuted = permute(newPermuteDims); - long sliceIdx = NDArrayMath.sliceOffsetForTensor(index, permuted, tensorShape); - - INDArray ret2 = permuted.slice(sliceIdx); - if (dimension.length == tensorShape.length && ArrayUtil.prodLong(tensorShape) == ret2.length()) { - if (dimension.length == 1 && ret2.isRowVector()) - return ret2; - if (finalPermuteDims.length != ret2.rank()) { - finalPermuteDims = new int[ret2.rank()]; - int count = 0; - for (int i = finalPermuteDims.length - 1; i >= 0; i--) - finalPermuteDims[count++] = i; - } - return ret2.permutei(finalPermuteDims); - } - - - int length = ArrayUtil.prod(tensorShape); - int tensorLength = ArrayUtil.prod(tensorShape); - long offset = (long) index * tensorLength / NDArrayMath.lengthPerSlice(ret2); - - if (sliceIdx == 0 && length == NDArrayMath.lengthPerSlice(ret2)) { - if (offset > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - ret2 = ret2.slice((int) offset); - 
if (dimension.length == 1 && ret2.isRowVectorOrScalar()) - return ret2; - return ret2.permutei(finalPermuteDims); - } - - else if (length == NDArrayMath.lengthPerSlice(ret2)) { - offset -= ret2.slices() * (offset / ret2.slices()); - - if (offset > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - ret2 = ret2.slice((int) offset); - if (dimension.length == 1 && ret2.isRowVectorOrScalar()) - return ret2; - return ret2.permutei(finalPermuteDims); - } - - while (ret2.length() > length) { - sliceIdx = NDArrayMath.sliceOffsetForTensor(index, ret2, tensorShape); - sliceIdx -= ret2.slices() * (sliceIdx / ret2.slices()); - ret2 = ret2.slice(sliceIdx); - } - - if (dimension.length == 1 && ret2.isRowVectorOrScalar()) + INDArray ret2 = permuted.slice(sliceIdx); + if (dimension.length == tensorShape.length + && ArrayUtil.prodLong(tensorShape) == ret2.length()) { + if (dimension.length == 1 && ret2.isRowVector()) { return ret2; - - return ret2.permutei(finalPermuteDims); + } + if (finalPermuteDims.length != ret2.rank()) { + finalPermuteDims = new int[ret2.rank()]; + int count = 0; + for (int i = finalPermuteDims.length - 1; i >= 0; i--) { + finalPermuteDims[count++] = i; + } + } + return ret2.permutei(finalPermuteDims); } - @Override - public long vectorsAlongDimension(int dimension) { - if (dimension == 0 && isVector() || isRowVectorOrScalar()) - return 1; - if (size(dimension) == 1 && !isVector()) { - for (int i = dimension; i < rank(); i++) { - if (size(i) != 1) - return vectorsAlongDimension(i); - } + int length = ArrayUtil.prod(tensorShape); + int tensorLength = ArrayUtil.prod(tensorShape); + long offset = (long) index * tensorLength / NDArrayMath.lengthPerSlice(ret2); - return length(); + if (sliceIdx == 0 && length == NDArrayMath.lengthPerSlice(ret2)) { + if (offset > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + ret2 = ret2.slice((int) offset); + if (dimension.length == 1 && ret2.isRowVectorOrScalar()) { + return ret2; + } + return 
ret2.permutei(finalPermuteDims); + } else if (length == NDArrayMath.lengthPerSlice(ret2)) { + offset -= ret2.slices() * (offset / ret2.slices()); - } else if (size(0) == 1 && !isVectorOrScalar()) { - int realDimension = rank() - getLeadingOnes(); - long length = length(); - if (length / size(realDimension) >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE"); - return length / size(realDimension); + if (offset > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + ret2 = ret2.slice((int) offset); + if (dimension.length == 1 && ret2.isRowVectorOrScalar()) { + return ret2; + } + return ret2.permutei(finalPermuteDims); + } + + while (ret2.length() > length) { + sliceIdx = NDArrayMath.sliceOffsetForTensor(index, ret2, tensorShape); + sliceIdx -= ret2.slices() * (sliceIdx / ret2.slices()); + ret2 = ret2.slice(sliceIdx); + } + + if (dimension.length == 1 && ret2.isRowVectorOrScalar()) { + return ret2; + } + + return ret2.permutei(finalPermuteDims); + } + + @Override + public long vectorsAlongDimension(int dimension) { + if (dimension == 0 && isVector() || isRowVectorOrScalar()) { + return 1; + } + if (size(dimension) == 1 && !isVector()) { + for (int i = dimension; i < rank(); i++) { + if (size(i) != 1) { + return vectorsAlongDimension(i); + } + } + + return length(); + + } else if (size(0) == 1 && !isVectorOrScalar()) { + int realDimension = rank() - getLeadingOnes(); + long length = length(); + if (length / size(realDimension) >= Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "Vectors along dimension can not be >= Integer.MAX_VALUE"); + } + return length / size(realDimension); + } + + long length = length(); + + if (dimension >= jvmShapeInfo.rank) { + if (length / size(jvmShapeInfo.rank - 1) >= Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "Vectors along dimension can not be >= Integer.MAX_VALUE"); + } + return (int) (length / size(jvmShapeInfo.rank - 1)); + } + if 
(length / size(dimension) >= Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "Vectors along dimension can not be >= Integer.MAX_VALUE"); + } + return length / size(dimension); + } + + @Override + public INDArray vectorAlongDimension(int index, int dimension) { + if (dimension < 0) { + dimension = jvmShapeInfo.getRank() + dimension; + } + + //return the whole thing + if (dimension == jvmShapeInfo.getRank() - 1 && size(dimension) == 1 && rank() > 2 + || rank() > 2 && dimension == 0 && size(dimension) == 1) { + return this; + } + + return tensorAlongDimension(index, dimension); + } + + @Override + public void setOrder(char order) { + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), + isEmpty())); + } + + @Override + public void setShapeAndStride(int[] shape, int[] stride) { + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, + ordering(), this.dataType(), false)); + } + + @Override + public INDArray cumsumi(int dimension) { + validateNumericalArray("cumsumi", true); + + if (isScalar() || isEmpty()) { + return this; + } + + if (isVector()) { + double s = 0.0; + for (int i = 0; i < length(); i++) { + s += getDouble(i); + putScalar(i, s); + } + } else if (dimension == Integer.MAX_VALUE) { + INDArray flattened = ravel(); + double prevVal = flattened.getDouble(0); + for (int i = 1; i < flattened.length(); i++) { + double d = prevVal + flattened.getDouble(i); + flattened.putScalar(i, d); + prevVal = d; + } + + return flattened; + } else { + for (int i = 0; i < vectorsAlongDimension(dimension); i++) { + INDArray vec = vectorAlongDimension(i, dimension); + vec.cumsumi(0); + + } + } + + return this; + } + + @Override + public Number normmaxNumber() { + return normmax(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number norm2Number() { + return 
norm2(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number norm1Number() { + return norm1(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number stdNumber() { + return std(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number prodNumber() { + if (isScalar()) { + return getNumber(0); + } + return prod(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number meanNumber() { + validateNumericalArray("meanNumber", false); + if (isScalar()) { + return getNumber(0); + } + return mean(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number ameanNumber() { + return amean(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number varNumber() { + return var(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number maxNumber() { + if (isScalar()) { + return getNumber(0); + } + return max(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number amaxNumber() { + return amax(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number minNumber() { + if (isScalar()) { + return getNumber(0); + } + return min(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number aminNumber() { + return amin(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number scan(Condition condition) { + MatchCondition op = new MatchCondition(this, condition); + return Nd4j.getExecutioner().exec(op).getDouble(0); + } + + @Override + public Number sumNumber() { + validateNumericalArray("sum", false); + if (isScalar()) { + return getNumber(0); + } + val scalar = sum(Integer.MAX_VALUE); + Nd4j.getExecutioner().commit(); + return scalar.getDouble(0); + } + + @Override + public Number entropyNumber() { + return entropy(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number shannonEntropyNumber() { + return shannonEntropy(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number logEntropyNumber() { + return logEntropy(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public 
INDArray cumsum(int dimension) { + validateNumericalArray("cumsum", true); + return dup().cumsumi(dimension); + } + + @Override + public INDArray assign(final INDArray arr) { + Preconditions.checkState( + (this.isScalar() && arr.isScalar()) || (this.isVector() && arr.isVector()) + || Shape.shapeEqualWithSqueeze(this.shape(), arr.shape()), + "Cannot assign arrays: arrays must both be scalars, both vectors, or shapes must be equal other than size 1 dimensions. Attempting to do x.assign(y)" + + + " with x.shape=%ndShape and y.shape=%ndShape", this, arr); + + Preconditions.checkArgument(this.length() == arr.length(), + "Length of both arrays must be equal"); + + Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.any.Assign(arr, this)); + return this; + } + + @Override + public INDArray putScalar(long i, double value) { + Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, + "Cannot put value %s into boolean array" + + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); + if (i < 0) { + i += rank(); + } + + // TODO: i'm not sure that rank == 1 has fair shortcut here + if (isScalar()) { + autoProcessScalarCall(); + data.put(i, value); + return this; + } else if (rank() == 1) { + data.put(i * stride(0), value); + return this; + } + + // we cant raise rank here, if original rank is 1 + if (isRowVector() && rank() == 2) { + return putScalar(0, i, value); + } else if (isColumnVector() && rank() == 2) { + return putScalar(i, 0, value); + } + long[] indexes = ordering() == 'c' ? 
Shape.ind2subC(this, i) : Shape.ind2sub(this, i); + return putScalar(indexes, value); + } + + @Override + public INDArray putScalar(long i, float value) { + return putScalar(i, (double) value); + } + + @Override + public INDArray putScalar(long i, int value) { + return putScalar(i, (double) value); + } + + @Override + public INDArray putScalar(int[] indexes, double value) { + Nd4j.getCompressor().autoDecompress(this); + + Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, + "Cannot put value %s into boolean array" + + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); + + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] < 0) { + indexes[i] += this.size(i); + } + } + + if (indexes.length == 1) { + return putScalar(indexes[0], value); + } else if (indexes.length == 2) { + return putScalar(indexes[0], indexes[1], value); + } else if (indexes.length == 3) { + return putScalar(indexes[0], indexes[1], indexes[2], value); + } else if (indexes.length == 4) { + return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value); + } else { + autoProcessScalarCall(); + long offset = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); + data.put(offset, value); + } + return this; + } + + @Override + public INDArray putScalar(long[] indexes, double value) { + Nd4j.getCompressor().autoDecompress(this); + + Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, + "Cannot put value %s into boolean array" + + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); + + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] < 0) { + indexes[i] += size(i); + } + } + + if (indexes.length == 1) { + return putScalar(indexes[0], value); + } else if (indexes.length == 2) { + return putScalar(indexes[0], indexes[1], value); + } else if (indexes.length == 3) { + return putScalar(indexes[0], indexes[1], indexes[2], value); + } else if 
(indexes.length == 4) { + return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value); + } else { + autoProcessScalarCall(); + long offset = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); + data.put(offset, value); + } + return this; + } + + @Override + public INDArray putScalar(long[] indexes, float value) { + return putScalar(indexes, (double) value); + } + + @Override + public INDArray putScalar(long row, long col, double value) { + Nd4j.getCompressor().autoDecompress(this); + autoProcessScalarCall(); + + Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, + "Cannot put value %s into boolean array" + + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); + + if (rank() > 2) { + throw new IllegalStateException( + "Cannot use putScalar(int,int,double) on a rank " + rank() + " INDArray"); + } + long offset = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, row, col); + data.put(offset, value); + return this; + } + + @Override + public INDArray putScalar(long dim0, long dim1, long dim2, double value) { + Nd4j.getCompressor().autoDecompress(this); + autoProcessScalarCall(); + + Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, + "Cannot put value %s into boolean array" + + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); + + if (rank() != 3) { + throw new IllegalStateException( + "Cannot use putScalar(int,int,int,double) on a rank " + rank() + " INDArray"); + } + long offset = 0; // Shape.getOffsetUnsafe(javaShapeInformation, dim0, dim1, dim2); + long size_0 = jvmShapeInfo.javaShapeInformation[1]; + long size_1 = jvmShapeInfo.javaShapeInformation[1 + 1]; + long size_2 = jvmShapeInfo.javaShapeInformation[1 + 2]; + + if (size_0 != 1) { + offset += dim0 * jvmShapeInfo.javaShapeInformation[1 + 3]; + } + if (size_1 != 1) { + offset += dim1 * jvmShapeInfo.javaShapeInformation[1 + 1 + 3]; + } + if (size_2 != 1) { + 
offset += dim2 * jvmShapeInfo.javaShapeInformation[1 + 2 + 3]; + } + + data.put(offset, value); + return this; + } + + @Override + public INDArray putScalar(long dim0, long dim1, long dim2, long dim3, double value) { + Nd4j.getCompressor().autoDecompress(this); + autoProcessScalarCall(); + Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, + "Cannot put value %s into boolean array" + + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); + + if (rank() != 4) { + throw new IllegalStateException( + "Cannot use putScalar(int,int,int,int,double) on a rank " + rank() + " INDArray"); + } + long offset = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, dim0, dim1, dim2, dim3); + data.put(offset, value); + return this; + } + + @Override + public INDArray putScalar(int[] indexes, float value) { + return putScalar(indexes, (double) value); + } + + @Override + public INDArray putScalar(int[] indexes, int value) { + return putScalar(indexes, (double) value); + } + + @Override + public INDArray putScalar(long[] indexes, int value) { + return putScalar(indexes, (double) value); + } + + @Override + public INDArray eps(Number other) { + validateNumericalArray("eps", true); + return Nd4j.getExecutioner().exec( + new ScalarEps(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), + other)); + } + + @Override + public INDArray eps(INDArray other) { + validateNumericalArray("eps", true); + return Nd4j.getExecutioner().exec(new Eps(this, other, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()))); + } + + @Override + public INDArray lt(Number other) { + validateNumericalArray("less than (lt)", false); + return Nd4j.getExecutioner().exec(new ScalarLessThan(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); + } + + @Override + public INDArray lte(Number other) { + validateNumericalArray("less than or equals (lte)", false); + return 
Nd4j.getExecutioner().exec(new ScalarLessThanOrEqual(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); + } + + @Override + public INDArray eq(Number other) { + Preconditions.checkArgument( + dataType() != DataType.BOOL || other.doubleValue() == 0.0 || other.doubleValue() == 1.0, + "Scalar equality on boolean arrays can only be applied with values 0 or 1: got value %s", + other); + return Nd4j.getExecutioner().exec(new ScalarEquals(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); + } + + @Override + public INDArray gt(Number other) { + validateNumericalArray("greater than (gt)", false); + return Nd4j.getExecutioner().exec(new ScalarGreaterThan(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); + } + + @Override + public INDArray gte(Number other) { + validateNumericalArray("greater than or equals (gte)", false); + return Nd4j.getExecutioner().exec(new ScalarGreaterThanOrEqual(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); + } + + @Override + public INDArray lt(INDArray other) { + validateNumericalArray("less than (lt)", false); + if (Shape.shapeEquals(this.shape(), other.shape())) { + return Nd4j.getExecutioner().exec(new LessThan(this, other, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; + } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return Nd4j.exec(new LessThan(new INDArray[]{this, other}, new INDArray[]{ + Nd4j.createUninitialized(DataType.BOOL, + Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; + } else { + throw new IllegalArgumentException("Shapes must be broadcastable"); + } + } + + @Override + public INDArray neq(Number other) { + Preconditions.checkArgument( + dataType() != DataType.BOOL || other.doubleValue() == 0.0 || other.doubleValue() == 1.0, + "Scalar non-equality on boolean arrays can only be applied with values 0 or 
1: got value %s", + other); + Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array"); + return Nd4j.getExecutioner().exec(new ScalarNotEquals(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); + } + + @Override + public INDArray neq(INDArray other) { + Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array"); + return Nd4j.getExecutioner().exec(new NotEqualTo(this, other, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; + } + + @Override + public INDArray eq(INDArray other) { + if (Shape.shapeEquals(this.shape(), other.shape())) { + return Nd4j.getExecutioner().exec(new EqualTo(this, other, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; + } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return Nd4j.exec(new EqualTo(new INDArray[]{this, other}, new INDArray[]{ + Nd4j.createUninitialized(DataType.BOOL, + Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; + } else { + throw new IllegalArgumentException("Shapes must be broadcastable"); + } + } + + @Override + public INDArray gt(INDArray other) { + validateNumericalArray("greater than (gt)", false); + if (Shape.shapeEquals(this.shape(), other.shape())) { + return Nd4j.getExecutioner().exec(new GreaterThan(this, other, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; + } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return Nd4j.exec(new GreaterThan(new INDArray[]{this, other}, new INDArray[]{ + Nd4j.createUninitialized(DataType.BOOL, + Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; + } else { + throw new IllegalArgumentException("Shapes must be broadcastable"); + } + } + + @Override + public INDArray isInfinite() { + validateNumericalArray("isInfinite", true); + if (isEmpty()) { + return Nd4j.empty(DataType.BOOL); + } + return 
Nd4j.getExecutioner().exec(new MatchConditionTransform(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), + Conditions.isInfinite())); + } + + @Override + public INDArray isNaN() { + validateNumericalArray("isNaN", true); + if (isEmpty()) { + return Nd4j.empty(DataType.BOOL); + } + return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), + Conditions.isNan())); + } + + @Override + public INDArray neg() { + validateNumericalArray("negative (neg)", true); + if (isEmpty()) { + return this; + } + return Nd4j.getExecutioner().exec(new Negative(this, + Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()))); + } + + @Override + public INDArray negi() { + validateNumericalArray("negative (negi)", true); + if (isEmpty()) { + return this; + } + Nd4j.getExecutioner().exec(new Negative(this)); + return this; + } + + @Override + public INDArray rdiv(Number n, INDArray result) { + return rdivi(n, result); + } + + @Override + public INDArray rdivi(Number n, INDArray result) { + validateNumericalArray("rdivi", false); + if (Double.isNaN(n.doubleValue())) { + n = Nd4j.EPS_THRESHOLD; + } + Nd4j.getExecutioner().exec(new ScalarReverseDivision(this, null, result, n)); + return result; + } + + @Override + public INDArray rsub(Number n, INDArray result) { + return rsubi(n, result); + } + + @Override + public INDArray rsubi(Number n, INDArray result) { + validateNumericalArray("rsubi", false); + + if (Double.isNaN(n.doubleValue())) { + n = Nd4j.EPS_THRESHOLD; + } + + Nd4j.getExecutioner().exec(new ScalarReverseSubtraction(this, result, n)); + return result; + } + + @Override + public INDArray div(Number n, INDArray result) { + return divi(n, result); + } + + @Override + public INDArray divi(Number n, INDArray result) { + validateNumericalArray("divi", false); + + if (Double.isNaN(n.doubleValue())) { + n = Nd4j.EPS_THRESHOLD; + } + 
Nd4j.getExecutioner().exec(new ScalarDivision(this, null, result, n)); + return result; + } + + @Override + public INDArray mul(Number n, INDArray result) { + return muli(n, result); + } + + @Override + public INDArray muli(Number n, INDArray result) { + validateNumericalArray("muli", false); + if (Double.isNaN(n.doubleValue())) { + n = Nd4j.EPS_THRESHOLD; + } + Nd4j.getExecutioner().exec(new ScalarMultiplication(this, null, result, n)); + return result; + } + + @Override + public INDArray sub(Number n, INDArray result) { + return subi(n, result); + } + + @Override + public INDArray subi(Number n, INDArray result) { + validateNumericalArray("subi", false); + + if (Double.isNaN(n.doubleValue())) { + n = Nd4j.EPS_THRESHOLD; + } + + Nd4j.getExecutioner().exec(new ScalarSubtraction(this, null, result, n)); + return result; + } + + @Override + public INDArray add(Number n, INDArray result) { + return addi(n, result); + } + + @Override + public INDArray addi(Number n, INDArray result) { + validateNumericalArray("addi", false); + if (Double.isNaN(n.doubleValue())) { + n = Nd4j.EPS_THRESHOLD; + } + + Nd4j.getExecutioner().exec(new ScalarAdd(this, null, result, n)); + return result; + } + + @Override + public INDArray getScalar(long row, long column) { + return getScalar(new long[]{row, column}); + } + + @Override + public INDArray dup() { + return dup(Nd4j.order()); + } + + @Override + public INDArray dup(char order) { + WorkspaceUtils.assertValidArray(this, "Cannot duplicate INDArray"); + if (this.isCompressed() && this.ordering() == order) { + INDArray ret = Nd4j.createArrayFromShapeBuffer(data().dup(), this.shapeInfoDataBuffer()); + ret.markAsCompressed(true); + return ret; + } + if (isEmpty()) { + return this; + } + + Nd4j.getCompressor().autoDecompress(this); + + // fixme: eventually it would be nice to have this in native code + if (isS()) { + val list = new ArrayList(); + for (int e = 0; e < this.length(); e++) { + list.add(this.getString(e)); } - long length = 
length(); + return Nd4j.create(list, this.shape(), this.ordering()); + } - if (dimension >= jvmShapeInfo.rank) { - if (length / size(jvmShapeInfo.rank - 1) >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE"); - return (int) (length / size(jvmShapeInfo.rank - 1)); + val z = Nd4j.createUninitialized(this.dataType(), this.shape(), order); + z.assign(this); + return z; + } + + @Override + public int getInt(int... indices) { + return (int) getDouble(indices); + } + + @Override + public long getLong(long index) { + Nd4j.getCompressor().autoDecompress(this); + Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); + + if (index >= length()) { + throw new IllegalArgumentException( + "Unable to get linear index " + index + ": values is greater than length (" + length() + + ")"); + } + + autoProcessScalarCall(); + + if (index == 0) { + return data().getLong(index); + } + + long[] dimensions = + ordering() == 'c' ? Shape.ind2subC(this, index) : Shape.ind2sub(this, index); + Shape.assertShapeLessThan(dimensions, shape()); + return getLong(dimensions); + } + + @Override + public long getLong(long... indices) { + if (isScalar()) { + return data().getLong(0); + } + return Shape.getLong(this, indices); + } + + @Override + public double getDouble(int... 
indices) { + autoProcessScalarCall(); + Nd4j.getCompressor().autoDecompress(this); + Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); + + for (int i = 0; i < indices.length; i++) { + if (indices[i] < 0) { + indices[i] += rank(); } - if (length / size(dimension) >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE"); - return length / size(dimension); } - - @Override - public INDArray vectorAlongDimension(int index, int dimension) { - if (dimension < 0) { - dimension = jvmShapeInfo.getRank() + dimension; + if (indices.length == 1) { + if (rank() == 1) { + return Shape.getDouble(this, indices[0]); + } else if (isRowVector()) { + return Shape.getDouble(this, 0, indices[0]); + } else if (isColumnVector()) { + return Shape.getDouble(this, indices[0], 0); + } else if ((isScalar() || length() == 1) && indices[0] == 0) { + return data().getDouble(0); } + } + return Shape.getDouble(this, indices); + } - //return the whole thing - if (dimension == jvmShapeInfo.getRank() - 1 && size(dimension) == 1 && rank() > 2 - || rank() > 2 && dimension == 0 && size(dimension) == 1) { - return this; + @Override + public double getDouble(long... 
indices) { + autoProcessScalarCall(); + Nd4j.getCompressor().autoDecompress(this); + Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); + + for (int i = 0; i < indices.length; i++) { + if (indices[i] < 0) { + indices[i] += rank(); } - - return tensorAlongDimension(index, dimension); } - - @Override - public void setOrder(char order) { - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), isEmpty())); - } - - @Override - public void setShapeAndStride(int[] shape, int[] stride) { - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false)); - } - - @Override - public INDArray cumsumi(int dimension) { - validateNumericalArray("cumsumi", true); - - if(isScalar() || isEmpty()) - return this; - - if (isVector()) { - double s = 0.0; - for (int i = 0; i < length(); i++) { - s += getDouble(i); - putScalar(i, s); - } - } else if (dimension == Integer.MAX_VALUE) { - INDArray flattened = ravel(); - double prevVal = flattened.getDouble(0); - for (int i = 1; i < flattened.length(); i++) { - double d = prevVal + flattened.getDouble(i); - flattened.putScalar(i, d); - prevVal = d; - } - - return flattened; + if (indices.length == 1) { + if (rank() == 1) { + return Shape.getDouble(this, indices[0]); + } else if (isRowVector()) { + return Shape.getDouble(this, 0, indices[0]); + } else if (isColumnVector()) { + return Shape.getDouble(this, indices[0], 0); + } else if (isScalar() && indices[0] == 0) { + return data().getDouble(0); } else { - for (int i = 0; i < vectorsAlongDimension(dimension); i++) { - INDArray vec = vectorAlongDimension(i, dimension); - vec.cumsumi(0); - - } - } - - return this; - } - - @Override - public Number normmaxNumber() { - return normmax(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number norm2Number() { - return 
norm2(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number norm1Number() { - return norm1(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number stdNumber() { - return std(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number prodNumber() { - if(isScalar()) - return getNumber(0); - return prod(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number meanNumber() { - validateNumericalArray("meanNumber", false); - if(isScalar()) - return getNumber(0); - return mean(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number ameanNumber() { - return amean(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number varNumber() { - return var(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number maxNumber() { - if(isScalar()) - return getNumber(0); - return max(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number amaxNumber() { - return amax(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number minNumber() { - if(isScalar()) - return getNumber(0); - return min(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number aminNumber() { - return amin(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number scan(Condition condition) { - MatchCondition op = new MatchCondition(this, condition); - return Nd4j.getExecutioner().exec(op).getDouble(0); - } - - @Override - public Number sumNumber() { - validateNumericalArray("sum", false); - if(isScalar()) - return getNumber(0); - val scalar = sum(Integer.MAX_VALUE); - Nd4j.getExecutioner().commit(); - return scalar.getDouble(0); - } - - @Override - public Number entropyNumber() { - return entropy(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number shannonEntropyNumber() { - return shannonEntropy(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number logEntropyNumber() { - return logEntropy(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public INDArray cumsum(int dimension) { - 
validateNumericalArray("cumsum", true); - return dup().cumsumi(dimension); - } - - @Override - public INDArray assign(final INDArray arr) { - Preconditions.checkState((this.isScalar() && arr.isScalar()) || (this.isVector() && arr.isVector()) || Shape.shapeEqualWithSqueeze(this.shape(), arr.shape()), - "Cannot assign arrays: arrays must both be scalars, both vectors, or shapes must be equal other than size 1 dimensions. Attempting to do x.assign(y)" + - " with x.shape=%ndShape and y.shape=%ndShape", this, arr ); - - Preconditions.checkArgument(this.length() == arr.length(), "Length of both arrays must be equal"); - - Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.any.Assign(arr, this)); - return this; - } - - @Override - public INDArray putScalar(long i, double value) { - Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + - " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); - if (i < 0) - i += rank(); - - // TODO: i'm not sure that rank == 1 has fair shortcut here - if (isScalar()) { - autoProcessScalarCall(); - data.put(i, value); - return this; - } else if (rank() == 1) { - data.put(i * stride(0), value); - return this; - } - - // we cant raise rank here, if original rank is 1 - if (isRowVector() && rank() == 2) { - return putScalar(0, i, value); - } else if (isColumnVector() && rank() == 2) { - return putScalar(i, 0, value); - } - long[] indexes = ordering() == 'c' ? 
Shape.ind2subC(this, i) : Shape.ind2sub(this, i); - return putScalar(indexes, value); - } - - @Override - public INDArray putScalar(long i, float value) { - return putScalar(i, (double) value); - } - - @Override - public INDArray putScalar(long i, int value) { - return putScalar(i, (double) value); - } - - @Override - public INDArray putScalar(int[] indexes, double value) { - Nd4j.getCompressor().autoDecompress(this); - - Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + - " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); - - for (int i = 0; i < indexes.length; i++) { - if (indexes[i] < 0) - indexes[i] += this.size(i); - } - - if (indexes.length == 1) { - return putScalar(indexes[0], value); - } else if (indexes.length == 2) { - return putScalar(indexes[0], indexes[1], value); - } else if (indexes.length == 3) { - return putScalar(indexes[0], indexes[1], indexes[2], value); - } else if (indexes.length == 4) { - return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value); - } else { - autoProcessScalarCall(); - long offset = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); - data.put(offset, value); - } - return this; - } - - @Override - public INDArray putScalar(long[] indexes, double value) { - Nd4j.getCompressor().autoDecompress(this); - - Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + - " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); - - - for (int i = 0; i < indexes.length; i++) { - if (indexes[i] < 0) - indexes[i] += size(i); - } - - if (indexes.length == 1) { - return putScalar(indexes[0], value); - } else if (indexes.length == 2) { - return putScalar(indexes[0], indexes[1], value); - } else if (indexes.length == 3) { - return putScalar(indexes[0], indexes[1], indexes[2], value); - } else if (indexes.length == 4) { - 
return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value); - } else { - autoProcessScalarCall(); - long offset = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); - data.put(offset, value); - } - return this; - } - - @Override - public INDArray putScalar(long[] indexes, float value) { - return putScalar(indexes, (double) value); - } - - @Override - public INDArray putScalar(long row, long col, double value) { - Nd4j.getCompressor().autoDecompress(this); - autoProcessScalarCall(); - - Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + - " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); - - if (rank() > 2) - throw new IllegalStateException("Cannot use putScalar(int,int,double) on a rank " + rank() + " INDArray"); - long offset = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, row, col); - data.put(offset, value); - return this; - } - - @Override - public INDArray putScalar(long dim0, long dim1, long dim2, double value) { - Nd4j.getCompressor().autoDecompress(this); - autoProcessScalarCall(); - - Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + - " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); - - if (rank() != 3) throw new IllegalStateException( - "Cannot use putScalar(int,int,int,double) on a rank " + rank() + " INDArray"); - long offset = 0; // Shape.getOffsetUnsafe(javaShapeInformation, dim0, dim1, dim2); - long size_0 = jvmShapeInfo.javaShapeInformation[1]; - long size_1 = jvmShapeInfo.javaShapeInformation[1 + 1]; - long size_2 = jvmShapeInfo.javaShapeInformation[1 + 2]; - - if (size_0 != 1) - offset += dim0 * jvmShapeInfo.javaShapeInformation[1 + 3]; - if (size_1 != 1) - offset += dim1 * jvmShapeInfo.javaShapeInformation[1 + 1 + 3]; - if (size_2 != 1) - offset += dim2 * jvmShapeInfo.javaShapeInformation[1 + 2 + 
3]; - - data.put(offset, value); - return this; - } - - @Override - public INDArray putScalar(long dim0, long dim1, long dim2, long dim3, double value) { - Nd4j.getCompressor().autoDecompress(this); - autoProcessScalarCall(); - Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + - " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); - - if (rank() != 4) - throw new IllegalStateException( - "Cannot use putScalar(int,int,int,int,double) on a rank " + rank() + " INDArray"); - long offset = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, dim0, dim1, dim2, dim3); - data.put(offset, value); - return this; - } - - @Override - public INDArray putScalar(int[] indexes, float value) { - return putScalar(indexes, (double) value); - } - - @Override - public INDArray putScalar(int[] indexes, int value) { - return putScalar(indexes, (double) value); - } - - @Override - public INDArray putScalar(long[] indexes, int value) { - return putScalar(indexes, (double) value); - } - - @Override - public INDArray eps(Number other) { - validateNumericalArray("eps", true); - return Nd4j.getExecutioner().exec(new ScalarEps(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray eps(INDArray other) { - validateNumericalArray("eps", true); - return Nd4j.getExecutioner().exec(new Eps(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()))); - } - - @Override - public INDArray lt(Number other) { - validateNumericalArray("less than (lt)", false); - return Nd4j.getExecutioner().exec(new ScalarLessThan(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray lte(Number other) { - validateNumericalArray("less than or equals (lte)", false); - return Nd4j.getExecutioner().exec(new ScalarLessThanOrEqual(this, 
Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray eq(Number other) { - Preconditions.checkArgument(dataType() != DataType.BOOL || other.doubleValue() == 0.0 || other.doubleValue() == 1.0, "Scalar equality on boolean arrays can only be applied with values 0 or 1: got value %s",other); - return Nd4j.getExecutioner().exec(new ScalarEquals(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray gt(Number other) { - validateNumericalArray("greater than (gt)", false); - return Nd4j.getExecutioner().exec(new ScalarGreaterThan(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray gte(Number other) { - validateNumericalArray("greater than or equals (gte)", false); - return Nd4j.getExecutioner().exec(new ScalarGreaterThanOrEqual(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray lt(INDArray other) { - validateNumericalArray("less than (lt)", false); - if (Shape.shapeEquals(this.shape(), other.shape())) { - return Nd4j.getExecutioner().exec(new LessThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; - } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return Nd4j.exec(new LessThan(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; - } else - throw new IllegalArgumentException("Shapes must be broadcastable"); - } - - @Override - public INDArray neq(Number other) { - Preconditions.checkArgument(dataType() != DataType.BOOL || other.doubleValue() == 0.0 || other.doubleValue() == 1.0, "Scalar non-equality on boolean arrays can only be applied with values 0 or 1: got value %s",other); - Preconditions.checkState(!isEmpty(), "Cannot perform operation neq 
(not equal) on empty array"); - return Nd4j.getExecutioner().exec(new ScalarNotEquals(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray neq(INDArray other) { - Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array"); - return Nd4j.getExecutioner().exec(new NotEqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; - } - - @Override - public INDArray eq(INDArray other) { - if (Shape.shapeEquals(this.shape(), other.shape())) { - return Nd4j.getExecutioner().exec(new EqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; - } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return Nd4j.exec(new EqualTo(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; - } else - throw new IllegalArgumentException("Shapes must be broadcastable"); - } - - @Override - public INDArray gt(INDArray other) { - validateNumericalArray("greater than (gt)", false); - if (Shape.shapeEquals(this.shape(), other.shape())) { - return Nd4j.getExecutioner().exec(new GreaterThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; - } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return Nd4j.exec(new GreaterThan(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; - } else - throw new IllegalArgumentException("Shapes must be broadcastable"); - } - - @Override - public INDArray isInfinite(){ - validateNumericalArray("isInfinite", true); - if(isEmpty()) - return Nd4j.empty(DataType.BOOL); - return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), 
Conditions.isInfinite())); - } - - @Override - public INDArray isNaN(){ - validateNumericalArray("isNaN", true); - if(isEmpty()) - return Nd4j.empty(DataType.BOOL); - return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), Conditions.isNan())); - } - - @Override - public INDArray neg() { - validateNumericalArray("negative (neg)", true); - if(isEmpty()) - return this; - return Nd4j.getExecutioner().exec(new Negative(this, Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()))); - } - - @Override - public INDArray negi() { - validateNumericalArray("negative (negi)", true); - if(isEmpty()) - return this; - Nd4j.getExecutioner().exec(new Negative(this)); - return this; - } - - @Override - public INDArray rdiv(Number n, INDArray result) { - return rdivi(n, result); - } - - @Override - public INDArray rdivi(Number n, INDArray result) { - validateNumericalArray("rdivi", false); - if (Double.isNaN(n.doubleValue())) - n = Nd4j.EPS_THRESHOLD; - Nd4j.getExecutioner().exec(new ScalarReverseDivision(this, null, result, n)); - return result; - } - - @Override - public INDArray rsub(Number n, INDArray result) { - return rsubi(n, result); - } - - @Override - public INDArray rsubi(Number n, INDArray result) { - validateNumericalArray("rsubi", false); - - if (Double.isNaN(n.doubleValue())) - n = Nd4j.EPS_THRESHOLD; - - Nd4j.getExecutioner().exec(new ScalarReverseSubtraction(this, result, n)); - return result; - } - - @Override - public INDArray div(Number n, INDArray result) { - return divi(n, result); - } - - @Override - public INDArray divi(Number n, INDArray result) { - validateNumericalArray("divi", false); - - if (Double.isNaN(n.doubleValue())) - n = Nd4j.EPS_THRESHOLD; - Nd4j.getExecutioner().exec(new ScalarDivision(this, null, result, n)); - return result; - } - - @Override - public INDArray mul(Number n, INDArray result) { - return muli(n, result); - } - - @Override - 
public INDArray muli(Number n, INDArray result) { - validateNumericalArray("muli", false); - if (Double.isNaN(n.doubleValue())) - n = Nd4j.EPS_THRESHOLD; - Nd4j.getExecutioner().exec(new ScalarMultiplication(this, null, result, n)); - return result; - } - - @Override - public INDArray sub(Number n, INDArray result) { - return subi(n, result); - } - - @Override - public INDArray subi(Number n, INDArray result) { - validateNumericalArray("subi", false); - - if (Double.isNaN(n.doubleValue())) - n = Nd4j.EPS_THRESHOLD; - - Nd4j.getExecutioner().exec(new ScalarSubtraction(this, null, result, n)); - return result; - } - - @Override - public INDArray add(Number n, INDArray result) { - return addi(n, result); - } - - @Override - public INDArray addi(Number n, INDArray result) { - validateNumericalArray("addi", false); - if (Double.isNaN(n.doubleValue())) - n = Nd4j.EPS_THRESHOLD; - - Nd4j.getExecutioner().exec(new ScalarAdd(this, null, result, n)); - return result; - } - - @Override - public INDArray getScalar(long row, long column) { - return getScalar(new long[] {row, column}); - } - - @Override - public INDArray dup() { - return dup(Nd4j.order()); - } - - @Override - public INDArray dup(char order) { - WorkspaceUtils.assertValidArray(this, "Cannot duplicate INDArray"); - if (this.isCompressed() && this.ordering() == order) { - INDArray ret = Nd4j.createArrayFromShapeBuffer(data().dup(), this.shapeInfoDataBuffer()); - ret.markAsCompressed(true); - return ret; + "Indexes length must be > 1 for non vectors and scalars"); } - if(isEmpty()) - return this; + } + return Shape.getDouble(this, indices); + } - Nd4j.getCompressor().autoDecompress(this); + @Override + public float getFloat(int... indices) { + return (float) getDouble(indices); + } - // fixme: eventually it would be nice to have this in native code - if (isS()) { - val list = new ArrayList(); - for (int e = 0; e < this.length(); e++) - list.add(this.getString(e)); + @Override + public float getFloat(long... 
indices) { + return (float) getDouble(indices); + } - return Nd4j.create(list, this.shape(), this.ordering()); + @Override + public boolean isScalar() { + if (isEmpty()) { + return false; + } + + if (jvmShapeInfo.rank == 0) { + return true; + } else if (jvmShapeInfo.rank > 2) { + return false; + } else if (jvmShapeInfo.rank == 1) { + return shape()[0] == 1; + } else if (jvmShapeInfo.rank == 2) { + return shape()[0] == 1 && shape()[1] == 1 || length() == 1; + } else { + return false; + } + + } + + @Override + public INDArray put(int[] indices, INDArray element) { + Nd4j.getCompressor().autoDecompress(this); + if (!element.isScalar()) { + throw new IllegalArgumentException("Unable to insert anything but a scalar"); + } + if (isRowVector() && indices[0] == 0 && indices.length == 2) { + int ix = 0; + for (int i = 1; i < indices.length; i++) { + ix += indices[i] * stride(i); } - - val z = Nd4j.createUninitialized(this.dataType(), this.shape(), order); - z.assign(this); - return z; - } - - @Override - public int getInt(int... indices) { - return (int) getDouble(indices); - } - - @Override - public long getLong(long index) { - Nd4j.getCompressor().autoDecompress(this); - Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); - - if (index >= length()) { - throw new IllegalArgumentException("Unable to get linear index " + index + ": values is greater than length (" + length() + ")"); + if (ix >= data.length()) { + throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); } - - autoProcessScalarCall(); - - if (index == 0) - return data().getLong(index); - - long[] dimensions = ordering() == 'c' ? Shape.ind2subC(this, index) : Shape.ind2sub(this, index); - Shape.assertShapeLessThan(dimensions, shape()); - return getLong(dimensions); - } - - @Override - public long getLong(long... indices) { - if(isScalar()) - return data().getLong(0); - return Shape.getLong(this, indices); - } - - @Override - public double getDouble(int... 
indices) { - autoProcessScalarCall(); - Nd4j.getCompressor().autoDecompress(this); - Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); - + data.put(ix, element.getDouble(0)); + } else { + int ix = 0; for (int i = 0; i < indices.length; i++) { - if (indices[i] < 0) - indices[i] += rank(); - } - if (indices.length == 1) { - if (rank() == 1) - return Shape.getDouble(this, indices[0]); - else if (isRowVector()) - return Shape.getDouble(this, 0, indices[0]); - else if (isColumnVector()) - return Shape.getDouble(this, indices[0], 0); - else if ((isScalar() || length() == 1) && indices[0] == 0) - return data().getDouble(0); - } - return Shape.getDouble(this, indices); - } - - @Override - public double getDouble(long... indices) { - autoProcessScalarCall(); - Nd4j.getCompressor().autoDecompress(this); - Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); - - for (int i = 0; i < indices.length; i++) { - if (indices[i] < 0) - indices[i] += rank(); - } - if (indices.length == 1) { - if (rank() == 1) - return Shape.getDouble(this, indices[0]); - else if (isRowVector()) - return Shape.getDouble(this, 0, indices[0]); - else if (isColumnVector()) - return Shape.getDouble(this, indices[0], 0); - else if (isScalar() && indices[0] == 0) - return data().getDouble(0); - else - throw new IllegalStateException("Indexes length must be > 1 for non vectors and scalars"); - } - return Shape.getDouble(this, indices); - } - - @Override - public float getFloat(int... indices) { - return (float) getDouble(indices); - } - - @Override - public float getFloat(long... 
indices) { - return (float) getDouble(indices); - } - - @Override - public boolean isScalar() { - if (isEmpty()) - return false; - - if (jvmShapeInfo.rank == 0) { - return true; - } else if (jvmShapeInfo.rank > 2) { - return false; - } else if (jvmShapeInfo.rank == 1) { - return shape()[0] == 1; - } else if (jvmShapeInfo.rank == 2) { - return shape()[0] == 1 && shape()[1] == 1 || length() == 1; - } - - else - return false; - - } - - @Override - public INDArray put(int[] indices, INDArray element) { - Nd4j.getCompressor().autoDecompress(this); - if (!element.isScalar()) - throw new IllegalArgumentException("Unable to insert anything but a scalar"); - if (isRowVector() && indices[0] == 0 && indices.length == 2) { - int ix = 0; - for (int i = 1; i < indices.length; i++) + if (size(i) != 1) { ix += indices[i] * stride(i); - if (ix >= data.length()) - throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); - data.put(ix, element.getDouble(0)); - } else { - int ix = 0; - for (int i = 0; i < indices.length; i++) - if (size(i) != 1) - ix += indices[i] * stride(i); - if (ix >= data.length()) - throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); - data.put(ix, element.getDouble(0)); + } } - return this; - } - - @Override - public INDArray match(INDArray comp, Condition condition) { - // TODO: obviously, we can make this broadcastable, eventually. 
But this will require new CustomOp based on MatchCondition - Preconditions.checkArgument(Arrays.equals(this.shape(), comp.shape()), "Shapes must be equal"); - Preconditions.checkArgument(this.dataType() == comp.dataType(), "Data types bmust be equal"); - return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, comp, Nd4j.createUninitialized(DataType.BOOL, this.shape()), condition)); - } - - @Override - public INDArray match(Number comp, Condition condition) { - return Nd4j.getExecutioner().exec(new MatchConditionTransform(this,comp.doubleValue(), condition)); - } - - @Override - public INDArray getWhere(INDArray comp, Condition condition) { - return BooleanIndexing.chooseFrom(new INDArray[]{this,comp},condition); - } - - @Override - public INDArray getWhere(Number comp, Condition condition) { - return BooleanIndexing.chooseFrom(new INDArray[]{this}, Collections.singletonList(comp.doubleValue()),Collections.emptyList(),condition); - } - - @Override - public INDArray putWhere(INDArray comp, INDArray put, Condition condition) { - Nd4j.getCompressor().autoDecompress(this); - MatchConditionTransform matchCondition = new MatchConditionTransform(this,comp,condition); - Nd4j.getExecutioner().exec(matchCondition); - return putWhereWithMask(matchCondition.z(),put); - } - - @Override - public INDArray putWhere(Number comp, INDArray put, Condition condition) { - return putWhere(Nd4j.scalar(comp),put,condition); - } - - @Override - public INDArray putWhere(Number comp, Number put, Condition condition) { - return putWhere(Nd4j.scalar(comp),Nd4j.scalar(put),condition); - } - - - @Override - public INDArray putWhereWithMask(INDArray mask, INDArray put) { - INDArray output = dup(); - Nd4j.getExecutioner().execAndReturn(new Where(new INDArray[]{mask,this,put},new INDArray[]{output})); - return output; - } - - @Override - public INDArray putWhereWithMask(INDArray mask, Number put) { - return putWhereWithMask(mask,Nd4j.scalar(put)); - } - - @Override - public INDArray 
put(int i, int j, INDArray element) { - return put(new int[] {i, j}, element); - } - - @Override - public INDArray put(int i, int j, Number element) { - return putScalar(new int[] {i, j}, element.doubleValue()); - } - - @Override - public INDArray putSlice(int slice, INDArray put) { - Nd4j.getCompressor().autoDecompress(this); - - - if (isScalar()) { - Preconditions.checkState(put.isScalar(), "Invalid dimension. Can only insert a scalar in to another scalar"); - put(0, put.getScalar(0)); - return this; - } else if (isVector()) { - Preconditions.checkState(put.isVectorOrScalar() && put.length() == length(), - "Invalid dimension on insertion. Can only insert scalars/vectors into other scalar/vectors"); - if (put.isScalar()) - putScalar(slice, put.getDouble(0)); - else - for (int i = 0; i < length(); i++) - putScalar(i, put.getDouble(i)); - return this; + if (ix >= data.length()) { + throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); } + data.put(ix, element.getDouble(0)); + } + return this; + } - assertSlice(put, slice); + @Override + public INDArray match(INDArray comp, Condition condition) { + // TODO: obviously, we can make this broadcastable, eventually. 
But this will require new CustomOp based on MatchCondition + Preconditions.checkArgument(Arrays.equals(this.shape(), comp.shape()), "Shapes must be equal"); + Preconditions.checkArgument(this.dataType() == comp.dataType(), "Data types bmust be equal"); + return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, comp, + Nd4j.createUninitialized(DataType.BOOL, this.shape()), condition)); + } + + @Override + public INDArray match(Number comp, Condition condition) { + return Nd4j.getExecutioner() + .exec(new MatchConditionTransform(this, comp.doubleValue(), condition)); + } + + @Override + public INDArray getWhere(INDArray comp, Condition condition) { + return BooleanIndexing.chooseFrom(new INDArray[]{this, comp}, condition); + } + + @Override + public INDArray getWhere(Number comp, Condition condition) { + return BooleanIndexing.chooseFrom(new INDArray[]{this}, + Collections.singletonList(comp.doubleValue()), Collections.emptyList(), condition); + } + + @Override + public INDArray putWhere(INDArray comp, INDArray put, Condition condition) { + Nd4j.getCompressor().autoDecompress(this); + MatchConditionTransform matchCondition = new MatchConditionTransform(this, comp, condition); + Nd4j.getExecutioner().exec(matchCondition); + return putWhereWithMask(matchCondition.z(), put); + } + + @Override + public INDArray putWhere(Number comp, INDArray put, Condition condition) { + return putWhere(Nd4j.scalar(comp), put, condition); + } + + @Override + public INDArray putWhere(Number comp, Number put, Condition condition) { + return putWhere(Nd4j.scalar(comp), Nd4j.scalar(put), condition); + } - INDArray view = slice(slice); + @Override + public INDArray putWhereWithMask(INDArray mask, INDArray put) { + INDArray output = dup(); + Nd4j.getExecutioner() + .execAndReturn(new Where(new INDArray[]{mask, this, put}, new INDArray[]{output})); + return output; + } - if (put.length() == 1) { + @Override + public INDArray putWhereWithMask(INDArray mask, Number put) { + return 
putWhereWithMask(mask, Nd4j.scalar(put)); + } + + @Override + public INDArray put(int i, int j, INDArray element) { + return put(new int[]{i, j}, element); + } + + @Override + public INDArray put(int i, int j, Number element) { + return putScalar(new int[]{i, j}, element.doubleValue()); + } + + @Override + public INDArray putSlice(int slice, INDArray put) { + Nd4j.getCompressor().autoDecompress(this); + + if (isScalar()) { + Preconditions.checkState(put.isScalar(), + "Invalid dimension. Can only insert a scalar in to another scalar"); + put(0, put.getScalar(0)); + return this; + } else if (isVector()) { + Preconditions.checkState(put.isVectorOrScalar() && put.length() == length(), + "Invalid dimension on insertion. Can only insert scalars/vectors into other scalar/vectors"); + if (put.isScalar()) { putScalar(slice, put.getDouble(0)); } else { - if(!(view.isVector() && put.isVector() && view.length() == put.length()) && !view.equalShapes(put)){ - throw new IllegalStateException("Cannot put slice: array to be put (" + Arrays.toString(put.shape()) + - ") and slice array (" + Arrays.toString(view.shape()) + ") have different shapes"); - } - view.assign(put); - } - return this; - } - - protected void assertSlice(INDArray put, long slice) { - Preconditions.checkArgument(slice < slices(), "Invalid slice specified: slice %s must be in range 0 (inclusive) to numSlices=%s (exclusive)", slice, slices()); - long[] sliceShape = put.shape(); - if (Shape.isRowVectorShape(sliceShape)) { - } else { - long[] requiredShape = ArrayUtil.removeIndex(shape(), 0); - - //no need to compare for scalar; primarily due to shapes either being [1] or length 0 - if (put.isScalar()) - return; - - if (isVector() && put.isVector() && put.length() < length()) - return; - //edge case for column vectors - if (Shape.isColumnVectorShape(sliceShape)) - return; - if (!Shape.shapeEquals(sliceShape, requiredShape) && !Shape.isRowVectorShape(requiredShape) - && !Shape.isRowVectorShape(sliceShape)) - throw new 
IllegalStateException(String.format("Invalid shape size of %s . Should have been %s ", - Arrays.toString(sliceShape), Arrays.toString(requiredShape))); - } - } - - public boolean isMatrix() { - return rank() == 2; - } - - protected INDArray newShape(long[] newShape, char ordering) { - - return Nd4j.create(data(), newShape, stride(), 0, ordering); - } - - protected INDArray create(DataBuffer data, int[] newShape, int[] newStrides, long offset, char ordering) { - return Nd4j.create(data, newShape, newStrides, offset, ordering); - } - - protected INDArray create(DataBuffer data, long[] newShape, long[] newStrides, long offset, char ordering) { - return Nd4j.create(data, newShape, newStrides, offset, ordering); - } - - protected INDArray create(DataBuffer data, int[] newShape, int[] newStrides, long offset) { - return Nd4j.create(data, newShape, newStrides, offset); - } - - protected INDArray create(int[] shape) { - return Nd4j.create(shape, getStrides(shape, Nd4j.order()), 0); - } - - protected INDArray create(int[] shape, int[] strides, long offset) { - return Nd4j.create(shape, strides, offset); - } - - protected int[] getStrides(int[] shape, char ordering) { - return Nd4j.getStrides(shape, ordering); - } - - @Override - public double squaredDistance(INDArray other) { - validateNumericalArray("squaredDistance", false); - double d2 = distance2(other); - return d2 * d2; - } - - @Override - public double distance2(INDArray other) { - validateNumericalArray("distance2", false); - Nd4j.getCompressor().autoDecompress(this); - return Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this, other)).getFinalResult().doubleValue(); - } - - @Override - public double distance1(INDArray other) { - validateNumericalArray("distance1", false); - Nd4j.getCompressor().autoDecompress(this); - return Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this, other)).getFinalResult().doubleValue(); - } - - @Override - public INDArray get(INDArray indices) { - 
if(indices.rank() > 2) { - throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); - } - - if (rank() == 1) { - Preconditions.checkArgument(indices.rank() <= 1, "For 1D vector indices must be either scalar or vector as well"); - val ret = Nd4j.createUninitialized(this.dataType(), indices.length()); - for (int e = 0; e < indices.length(); e++) { - val idx = indices.getLong(e); - val value = getDouble(idx); - ret.putScalar(e, value); - } - - return ret; - } else if(indices.rows() == rank()) { - INDArray ret = Nd4j.create(this.dataType(), indices.columns()); - - for(int i = 0; i < indices.columns(); i++) { - int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); - val v = getDouble(specifiedIndex); - ret.putScalar(i, v); - } - - return ret; - } - else { - List arrList = new ArrayList<>(); - - if(indices.isMatrix() || indices.isColumnVector() - || (indices.isScalar() && indices.rank() == 2)) { // we need this for compatibility with legacy code - for(int i = 0; i < indices.rows(); i++) { - if(i == 0) { - INDArray row = indices.getRow(i); - for(int j = 0; j < row.length(); j++) { - arrList.add(slice(row.getInt(j))); - } - } - else { - INDArray row = indices.slice(i); - for(int j = 0; j < row.length(); j++) { - INDArray put = arrList.get(j).slice(row.getInt(j)); - put = put.reshape(Longs.concat(new long[]{1},put.shape())); - arrList.set(j,put); - } - } - - } - } - else if(indices.isRowVector()) { - for(int i = 0; i < indices.length(); i++) { - INDArray add = slice(indices.getInt(i)); - add = add.reshape(Longs.concat(new long[] {1,},add.shape())); - arrList.add(add); - } - } - - return Nd4j.concat(0,arrList.toArray(new INDArray[arrList.size()])); - - } - - - } - - @Override - public INDArray put(INDArray indices, INDArray element) { - if(indices.rank() > 2) { - throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); - } - - if(indices.rows() == rank()) { - NdIndexIterator ndIndexIterator = new 
NdIndexIterator(element.shape()); - for(int i = 0; i < indices.columns(); i++) { - int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); - putScalar(specifiedIndex,element.getDouble(ndIndexIterator.next())); + for (int i = 0; i < length(); i++) { + putScalar(i, put.getDouble(i)); } } - else { - List arrList = new ArrayList<>(); - - if(indices.isMatrix() || indices.isColumnVector()) { - for(int i = 0; i < indices.rows(); i++) { - INDArray row = indices.getRow(i); - for(int j = 0; j < row.length(); j++) { - INDArray slice = slice(row.getInt(j)); - Nd4j.getExecutioner().execAndReturn(new Assign(new INDArray[]{slice,element},new INDArray[]{slice})); - arrList.add(slice(row.getInt(j))); - } - } - } - else if(indices.isRowVector()) { - for(int i = 0; i < indices.length(); i++) { - arrList.add(slice(indices.getInt(i))); - } - } - } - return this; + return this; } - @Override - public INDArray put(INDArrayIndex[] indices, INDArray element) { - Nd4j.getCompressor().autoDecompress(this); - boolean isSpecifiedIndex = false; - for(INDArrayIndex idx : indices){ - if(idx instanceof SpecifiedIndex){ - isSpecifiedIndex = true; - break; - } + assertSlice(put, slice); + + INDArray view = slice(slice); + + if (put.length() == 1) { + putScalar(slice, put.getDouble(0)); + } else { + if (!(view.isVector() && put.isVector() && view.length() == put.length()) + && !view.equalShapes(put)) { + throw new IllegalStateException( + "Cannot put slice: array to be put (" + Arrays.toString(put.shape()) + + ") and slice array (" + Arrays.toString(view.shape()) + ") have different shapes"); + } + view.assign(put); + } + return this; + } + + protected void assertSlice(INDArray put, long slice) { + Preconditions.checkArgument(slice < slices(), + "Invalid slice specified: slice %s must be in range 0 (inclusive) to numSlices=%s (exclusive)", + slice, slices()); + long[] sliceShape = put.shape(); + if (Shape.isRowVectorShape(sliceShape)) { + } else { + long[] requiredShape = 
ArrayUtil.removeIndex(shape(), 0); + + //no need to compare for scalar; primarily due to shapes either being [1] or length 0 + if (put.isScalar()) { + return; } - if(!isSpecifiedIndex){ - return get(indices).assign(element); - } else { - //Can't get a view, so we'll do it in subsets instead - // This is inefficient, but it is correct... - int numSpecified = 0; - List specifiedIdxs = new ArrayList<>(); - List specifiedIdxDims = new ArrayList<>(); - - INDArrayIndex[] destinationIndices = indices.clone(); //Shallow clone - INDArrayIndex[] sourceIndices = indices.clone(); - for( int i=0; i can't use point(1) on [1,x,y] - sourceIndices[i] = NDArrayIndex.point(0); - } - } - int[] counts = new int[specifiedIdxs.size()]; - int[] dims = new int[specifiedIdxDims.size()]; - for( int i=0; i 2) { + throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); } - @Override - public INDArray put(INDArrayIndex[] indices, Number element) { - Nd4j.getCompressor().autoDecompress(this); - INDArray get = get(indices); - for (int i = 0; i < get.length(); i++) - get.putScalar(i, element.doubleValue()); - return this; - } + if (rank() == 1) { + Preconditions.checkArgument(indices.rank() <= 1, + "For 1D vector indices must be either scalar or vector as well"); + val ret = Nd4j.createUninitialized(this.dataType(), indices.length()); + for (int e = 0; e < indices.length(); e++) { + val idx = indices.getLong(e); + val value = getDouble(idx); + ret.putScalar(e, value); + } + + return ret; + } else if (indices.rows() == rank()) { + INDArray ret = Nd4j.create(this.dataType(), indices.columns()); + + for (int i = 0; i < indices.columns(); i++) { + int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); + val v = getDouble(specifiedIndex); + ret.putScalar(i, v); + } + + return ret; + } else { + List arrList = new ArrayList<>(); + + if (indices.isMatrix() || indices.isColumnVector() + || (indices.isScalar() + && indices.rank() == 2)) { // we need this for compatibility 
with legacy code + for (int i = 0; i < indices.rows(); i++) { + if (i == 0) { + INDArray row = indices.getRow(i); + for (int j = 0; j < row.length(); j++) { + arrList.add(slice(row.getInt(j))); + } + } else { + INDArray row = indices.slice(i); + for (int j = 0; j < row.length(); j++) { + INDArray put = arrList.get(j).slice(row.getInt(j)); + put = put.reshape(Longs.concat(new long[]{1}, put.shape())); + arrList.set(j, put); + } + } + + } + } else if (indices.isRowVector()) { + for (int i = 0; i < indices.length(); i++) { + INDArray add = slice(indices.getInt(i)); + add = add.reshape(Longs.concat(new long[]{1,}, add.shape())); + arrList.add(add); + } + } + + return Nd4j.concat(0, arrList.toArray(new INDArray[arrList.size()])); - @Override - public INDArray swapAxes(int dimension, int with) { - int[] shape = ArrayUtil.range(0, shape().length); - shape[dimension] = with; - shape[with] = dimension; - return permute(shape); } - @Override - public boolean isView() { + } + + @Override + public INDArray put(INDArray indices, INDArray element) { + if (indices.rank() > 2) { + throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); + } + + if (indices.rows() == rank()) { + NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape()); + for (int i = 0; i < indices.columns(); i++) { + int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); + putScalar(specifiedIndex, element.getDouble(ndIndexIterator.next())); + } + } else { + List arrList = new ArrayList<>(); + + if (indices.isMatrix() || indices.isColumnVector()) { + for (int i = 0; i < indices.rows(); i++) { + INDArray row = indices.getRow(i); + for (int j = 0; j < row.length(); j++) { + INDArray slice = slice(row.getInt(j)); + Nd4j.getExecutioner() + .execAndReturn(new Assign(new INDArray[]{slice, element}, new INDArray[]{slice})); + arrList.add(slice(row.getInt(j))); + } + } + } else if (indices.isRowVector()) { + for (int i = 0; i < indices.length(); i++) { + 
arrList.add(slice(indices.getInt(i))); + } + } + } + return this; + } + + @Override + public INDArray put(INDArrayIndex[] indices, INDArray element) { + Nd4j.getCompressor().autoDecompress(this); + boolean isSpecifiedIndex = false; + for (INDArrayIndex idx : indices) { + if (idx instanceof SpecifiedIndex) { + isSpecifiedIndex = true; + break; + } + } + + if (!isSpecifiedIndex) { + return get(indices).assign(element); + } else { + //Can't get a view, so we'll do it in subsets instead + // This is inefficient, but it is correct... + int numSpecified = 0; + List specifiedIdxs = new ArrayList<>(); + List specifiedIdxDims = new ArrayList<>(); + + INDArrayIndex[] destinationIndices = indices.clone(); //Shallow clone + INDArrayIndex[] sourceIndices = indices.clone(); + for (int i = 0; i < indices.length; i++) { + INDArrayIndex idx = indices[i]; + if (idx instanceof SpecifiedIndex) { + numSpecified++; + long[] idxs = ((SpecifiedIndex) idx).getIndexes(); + specifiedIdxs.add(idxs); + specifiedIdxDims.add(i); + } else if (idx instanceof PointIndex) { + //Example: [2,3,3].put(point(1), ..., [1,x,y]) -> can't use point(1) on [1,x,y] + sourceIndices[i] = NDArrayIndex.point(0); + } + } + int[] counts = new int[specifiedIdxs.size()]; + int[] dims = new int[specifiedIdxDims.size()]; + for (int i = 0; i < specifiedIdxs.size(); i++) { + counts[i] = specifiedIdxs.get(i).length; + dims[i] = specifiedIdxDims.get(i); + } + + NdIndexIterator iter = new NdIndexIterator(counts); + while (iter.hasNext()) { + long[] iterationIdxs = iter.next(); + for (int i = 0; i < iterationIdxs.length; i++) { + long[] indicesForDim = specifiedIdxs.get(i); + destinationIndices[dims[i]] = NDArrayIndex.point(indicesForDim[(int) iterationIdxs[i]]); + sourceIndices[dims[i]] = NDArrayIndex.point(iterationIdxs[i]); + } + + INDArray sourceView = element.get(sourceIndices); + INDArray destinationView = this.get(destinationIndices); + destinationView.assign(sourceView); + } + } + return this; + } + + @Override + 
public INDArray put(INDArrayIndex[] indices, Number element) { + Nd4j.getCompressor().autoDecompress(this); + INDArray get = get(indices); + for (int i = 0; i < get.length(); i++) { + get.putScalar(i, element.doubleValue()); + } + return this; + } + + @Override + public INDArray swapAxes(int dimension, int with) { + int[] shape = ArrayUtil.range(0, shape().length); + shape[dimension] = with; + shape[with] = dimension; + return permute(shape); + } + + + @Override + public boolean isView() { /* We don't really use Shape offset value anywhere And it's possible to be not a view, and have non-empty originalBuffer */ - // length/data.length can be different in case of Threshold conversion - if(isEmpty() || isS()) - return false; + // length/data.length can be different in case of Threshold conversion + if (isEmpty() || isS()) { + return false; + } - val c2 = (length() < data().length() && data.dataType() != DataType.INT); - val c3 = (data().originalDataBuffer() != null && data != data.originalDataBuffer()); + val c2 = (length() < data().length() && data.dataType() != DataType.INT); + val c3 = (data().originalDataBuffer() != null && data != data.originalDataBuffer()); - return c2 || c3; + return c2 || c3; + } + + @Override + public boolean isSparse() { + return false; + } + + @Override + public DataBuffer data() { + return data; + } + + @Override + public void setData(DataBuffer data) { + this.data = data; + } + + @Override + public long slices() { + return size(0); + } + + protected INDArray create(DataBuffer buffer) { + return Nd4j.create(buffer); + } + + @Override + public INDArray cond(Condition condition) { + if (isEmpty()) { + return Nd4j.empty(DataType.BOOL); + } + INDArray ret = Nd4j.createUninitialized(DataType.BOOL, this.shape()); + Nd4j.getExecutioner().exec(new MatchConditionTransform(this, ret, condition)); + return ret; + } + + protected void init(int[] shape, int[] stride) { + //null character + if (shapeInformation == null || jvmShapeInfo == null || 
ordering() == '\u0000') { + //Shape.setOrder(shapeInfo(), Nd4j.order()); + val si = Nd4j.getShapeInfoProvider() + .createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 1, + Nd4j.order(), this.dataType(), false); + setShapeInformation(si); } - @Override - public boolean isSparse() { - return false; + } + + protected void init(long[] shape, long[] stride) { + //null character + if (shapeInformation == null || jvmShapeInfo == null || ordering() == '\u0000') { + val si = Nd4j.getShapeInfoProvider() + .createShapeInformation(shape, stride, 1, Nd4j.order(), this.dataType(), false); + setShapeInformation(si); } - @Override - public DataBuffer data() { - return data; + } + + @Override + public INDArray getScalar(long i) { + if (i >= this.length()) { + throw new ND4JIllegalStateException("Index can't be greater then array length"); + } + + if (i < 0) { + i += this.length(); + } + + long idx = this.isScalar() ? 0 + : Shape.getOffset(jvmShapeInfo.javaShapeInformation, Shape.ind2subC(this.shape(), i)); + val buffer = Nd4j.createBuffer(this.data(), this.data().originalOffset() + idx, 1); + val shape = Nd4j.getShapeInfoProvider() + .createShapeInformation(new long[0], new long[0], 1, 'c', dataType(), false); + return Nd4j.createArrayFromShapeBuffer(buffer, shape); + } + + /** + * Do a row wise op (a,s,m,d) a : add s : subtract m : multiply d : divide h : reverse subtraction + * t : reverse division + * + * @param columnVector the column vector + * @param operation the operation + * @return + */ + protected INDArray doColumnWise(INDArray columnVector, char operation) { + Nd4j.getCompressor().autoDecompress(this); + if (columnVector.isScalar()) { + switch (operation) { + case 'a': + addi(columnVector.getDouble(0)); + break; + case 'p': + assign(columnVector.getDouble(0)); + break; + case 's': + subi(columnVector.getDouble(0)); + break; + case 'm': + muli(columnVector.getDouble(0)); + break; + case 'd': + divi(columnVector.getDouble(0)); + break; + 
case 'h': + rsubi(columnVector.getDouble(0)); + break; + case 't': + rdivi(columnVector.getDouble(0)); + break; + + } + + return this; + } else if (isScalar()) { + switch (operation) { + case 'a': + return columnVector.addi(getDouble(0)); + case 'p': + return columnVector.assign(getDouble(0)); + case 's': + return columnVector.subi(getDouble(0)); + case 'm': + return columnVector.muli(getDouble(0)); + case 'd': + return columnVector.divi(getDouble(0)); + case 'h': + return columnVector.rsubi(getDouble(0)); + case 't': + return columnVector.rdivi(getDouble(0)); + + } } - @Override - public void setData(DataBuffer data) { - this.data = data; + //Input validation: require (a) columnVector to actually be a column vector, and (b) this.size(0) to match columnVector.size(0) + //Or, simply require it to be a rank 1 vector + if ((!columnVector.isColumnVector() && columnVector.rank() > 1) + || this.size(0) != columnVector.size(0) || columnVector.length() <= 1) { + throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) + + ", column vector shape =" + Arrays.toString(columnVector.shape()) + ")"); } - @Override - public long slices() { - return size(0); + if (columnVector.data().sameUnderlyingData(data())) { + return doColumnWise(columnVector.dup(), operation); + } + if (equalShapes(columnVector)) { + switch (operation) { + case 'a': + addi(columnVector); + break; + case 'p': + assign(columnVector); + break; + case 's': + subi(columnVector); + break; + case 'm': + muli(columnVector); + break; + case 'd': + divi(columnVector); + break; + case 'h': + rsubi(columnVector); + break; + case 't': + rdivi(columnVector); + break; + } + + return this; } - - protected INDArray create(DataBuffer buffer) { - return Nd4j.create(buffer); - } - - @Override - public INDArray cond(Condition condition) { - if(isEmpty()) - return Nd4j.empty(DataType.BOOL); - INDArray ret = Nd4j.createUninitialized(DataType.BOOL, this.shape()); - Nd4j.getExecutioner().exec(new 
MatchConditionTransform(this,ret, condition)); - return ret; - } - - protected void init(int[] shape, int[] stride) { - //null character - if (shapeInformation == null || jvmShapeInfo == null || ordering() == '\u0000') { - //Shape.setOrder(shapeInfo(), Nd4j.order()); - val si = Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 1, Nd4j.order(), this.dataType(), false); - setShapeInformation(si); - } - - } - - protected void init(long[] shape, long[] stride) { - //null character - if (shapeInformation == null || jvmShapeInfo == null || ordering() == '\u0000') { - val si = Nd4j.getShapeInfoProvider().createShapeInformation(shape,stride, 1, Nd4j.order(), this.dataType(), false); - setShapeInformation(si); - } - - } - - @Override - public INDArray getScalar(long i) { - if (i >= this.length()) - throw new ND4JIllegalStateException("Index can't be greater then array length"); - - if (i < 0) - i += this.length(); - - long idx = this.isScalar() ? 
0 : Shape.getOffset(jvmShapeInfo.javaShapeInformation, Shape.ind2subC(this.shape(), i)); - val buffer = Nd4j.createBuffer( this.data(), this.data().originalOffset() + idx, 1); - val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],1,'c', dataType(), false); - return Nd4j.createArrayFromShapeBuffer(buffer, shape); - } - - /** - * Do a row wise op (a,s,m,d) - * a : add - * s : subtract - * m : multiply - * d : divide - * h : reverse subtraction - * t : reverse division - * - * @param columnVector the column vector - * @param operation the operation - * @return - */ - protected INDArray doColumnWise(INDArray columnVector, char operation) { - Nd4j.getCompressor().autoDecompress(this); - if(columnVector.isScalar()) { - switch (operation) { - case 'a': - addi(columnVector.getDouble(0)); - break; - case 'p': - assign(columnVector.getDouble(0)); - break; - case 's': - subi(columnVector.getDouble(0)); - break; - case 'm': - muli(columnVector.getDouble(0)); - break; - case 'd': - divi(columnVector.getDouble(0)); - break; - case 'h': - rsubi(columnVector.getDouble(0)); - break; - case 't': - rdivi(columnVector.getDouble(0)); - break; - - } - - return this; - } - - else if(isScalar()) { - switch (operation) { - case 'a': - return columnVector.addi(getDouble(0)); - case 'p': - return columnVector.assign(getDouble(0)); - case 's': - return columnVector.subi(getDouble(0)); - case 'm': - return columnVector.muli(getDouble(0)); - case 'd': - return columnVector.divi(getDouble(0)); - case 'h': - return columnVector.rsubi(getDouble(0)); - case 't': - return columnVector.rdivi(getDouble(0)); - - } - } - - //Input validation: require (a) columnVector to actually be a column vector, and (b) this.size(0) to match columnVector.size(0) - //Or, simply require it to be a rank 1 vector - if ((!columnVector.isColumnVector() && columnVector.rank() > 1) || this.size(0) != columnVector.size(0) || columnVector.length() <= 1) { - throw new 
IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) - + ", column vector shape =" + Arrays.toString(columnVector.shape()) + ")"); - } - - if (columnVector.data().sameUnderlyingData(data())) - return doColumnWise(columnVector.dup(), operation); - if (equalShapes(columnVector)) { - switch (operation) { - case 'a': - addi(columnVector); - break; - case 'p': - assign(columnVector); - break; - case 's': - subi(columnVector); - break; - case 'm': - muli(columnVector); - break; - case 'd': - divi(columnVector); - break; - case 'h': - rsubi(columnVector); - break; - case 't': - rdivi(columnVector); - break; - } - - return this; - } - if (rows() == 1 && columnVector.isScalar()) { - applyScalarOp(columnVector, operation); - } else { - // special optimization case, broadcast turns into ScalarOp Along Dimension - if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'c' && columnVector.elementWiseStride() == 1) { - switch (operation) { - case 'a': { - ScalarAdd op = new ScalarAdd(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - case 'p': { - ScalarSet op = new ScalarSet(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - case 's': { - ScalarSubtraction op = new ScalarSubtraction(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - case 'm': { - ScalarMultiplication op = - new ScalarMultiplication(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - case 'd': { - ScalarDivision op = new ScalarDivision(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - case 'h': { - ScalarReverseSubtraction op = - new ScalarReverseSubtraction(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - case 't': { - ScalarReverseDivision op = - new 
ScalarReverseDivision(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - } - } else { - applyBroadcastOp(columnVector, operation); - } - - } - - return this; - - } - - /** - * Do a row wise op (a,s,m,d) - * a : add - * s : subtract - * m : multiply - * d : divide - * h : reverse subtraction - * t : reverse division - * - * @param rowVector the row vector - * @param operation the operation - * @return - */ - protected INDArray doRowWise(INDArray rowVector, final char operation) { - Nd4j.getCompressor().autoDecompress(this); - - - if(rowVector.isScalar()) { - switch (operation) { - case 'a': - addi(rowVector.getDouble(0)); - break; - case 'p': - assign(rowVector.getDouble(0)); - break; - case 's': - subi(rowVector.getDouble(0)); - break; - case 'm': - muli(rowVector.getDouble(0)); - break; - case 'd': - divi(rowVector.getDouble(0)); - break; - case 'h': - rsubi(rowVector.getDouble(0)); - break; - case 't': - rdivi(rowVector.getDouble(0)); - break; - - } - - return this; - } - else if(isScalar()) { - switch (operation) { - case 'a': - return rowVector.addi(getDouble(0)); - case 'p': - return rowVector.assign(getDouble(0)); - case 's': - return rowVector.subi(getDouble(0)); - case 'm': - return rowVector.muli(getDouble(0)); - case 'd': - return rowVector.divi(getDouble(0)); - case 'h': - return rowVector.rsubi(getDouble(0)); - case 't': - return rowVector.rdivi(getDouble(0)); - - } - } - - //Input validation: require (a) rowVector to actually be a row vector, and (b) this.size(1) to match rowVector.size(1) - if (!rowVector.isRowVector() || this.rank() > 1 && rowVector.rank() > 1 && this.size(1) != rowVector.size(1) || rowVector.length() <= 1) { - throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) - + ", row vector shape =" + Arrays.toString(rowVector.shape()) + ")"); - } - - if (rowVector.data().sameUnderlyingData(data())) - return doRowWise(rowVector.dup(), operation); - - if 
(isVector()) { - switch (operation) { - case 'a': - addi(rowVector); - break; - case 'p': - assign(rowVector); - break; - case 's': - subi(rowVector); - break; - case 'm': - muli(rowVector); - break; - case 'd': - divi(rowVector); - break; - case 'h': - rsubi(rowVector); - break; - case 't': - rdivi(rowVector); - break; - } - - return this; - } - - if (rank() == 2 && columns() == 1 && rowVector.isScalar()) { - applyScalarOp(rowVector, operation); - } else { - // special optimization case, broadcast turns into ScalarOp Along Dimension - if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'f' && rowVector.elementWiseStride() == 1) { - switch (operation) { - case 'a': { - ScalarAdd op = new ScalarAdd(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - case 'p': { - ScalarSet op = new ScalarSet(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - case 's': { - ScalarSubtraction op = new ScalarSubtraction(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - case 'm': { - ScalarMultiplication op = new ScalarMultiplication(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - case 'd': { - ScalarDivision op = new ScalarDivision(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - case 'h': { - ScalarReverseSubtraction op = - new ScalarReverseSubtraction(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - case 't': { - ScalarReverseDivision op = new ScalarReverseDivision(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - - } - } else { - applyBroadcastOp(rowVector, operation); - } - } - - return this; - } - - - private void applyBroadcastOp(INDArray vector, final char operation) { - Nd4j.getCompressor().autoDecompress(this); - int 
alongDimension = Shape.isRowVectorShape(vector.shape()) ? 1 : 0; - - // FIXME: probably this is wrong, because strict equality is always false in current DataBuffer mechanics - if (this.data() == vector.data()) - vector = vector.dup(); + if (rows() == 1 && columnVector.isScalar()) { + applyScalarOp(columnVector, operation); + } else { + // special optimization case, broadcast turns into ScalarOp Along Dimension + if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'c' + && columnVector.elementWiseStride() == 1) { switch (operation) { - case 'a': - Nd4j.getExecutioner().exec(new BroadcastAddOp(this, vector, this, alongDimension)); - return; - case 's': - Nd4j.getExecutioner().exec(new BroadcastSubOp(this, vector, this, alongDimension)); - return; - case 'm': - Nd4j.getExecutioner().exec(new BroadcastMulOp(this, vector, this, alongDimension)); - return; - case 'd': - Nd4j.getExecutioner().exec(new BroadcastDivOp(this, vector, this, alongDimension)); - return; - case 'h': - Nd4j.getExecutioner().exec(new BroadcastRSubOp(this, vector, this, alongDimension)); - return; - case 't': - Nd4j.getExecutioner().exec(new BroadcastRDivOp(this, vector, this, alongDimension)); - return; - case 'p': - Nd4j.getExecutioner().exec(new BroadcastCopyOp(this, vector, this, alongDimension)); - return; - default: - throw new UnsupportedOperationException("Unknown operation: " + operation); + case 'a': { + ScalarAdd op = new ScalarAdd(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } + case 'p': { + ScalarSet op = new ScalarSet(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } + case 's': { + ScalarSubtraction op = new ScalarSubtraction(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } + case 'm': { + ScalarMultiplication op = + new ScalarMultiplication(this, columnVector, this, 0.0); + op.setDimension(1); + 
Nd4j.getExecutioner().exec(op); + break; + } + case 'd': { + ScalarDivision op = new ScalarDivision(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } + case 'h': { + ScalarReverseSubtraction op = + new ScalarReverseSubtraction(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } + case 't': { + ScalarReverseDivision op = + new ScalarReverseDivision(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } } + } else { + applyBroadcastOp(columnVector, operation); + } + } - private void applyScalarOp(INDArray vector, char operation) { - Nd4j.getCompressor().autoDecompress(this); + return this; + + } + + /** + * Do a row wise op (a,s,m,d) a : add s : subtract m : multiply d : divide h : reverse subtraction + * t : reverse division + * + * @param rowVector the row vector + * @param operation the operation + * @return + */ + protected INDArray doRowWise(INDArray rowVector, final char operation) { + Nd4j.getCompressor().autoDecompress(this); + + if (rowVector.isScalar()) { + switch (operation) { + case 'a': + addi(rowVector.getDouble(0)); + break; + case 'p': + assign(rowVector.getDouble(0)); + break; + case 's': + subi(rowVector.getDouble(0)); + break; + case 'm': + muli(rowVector.getDouble(0)); + break; + case 'd': + divi(rowVector.getDouble(0)); + break; + case 'h': + rsubi(rowVector.getDouble(0)); + break; + case 't': + rdivi(rowVector.getDouble(0)); + break; + + } + + return this; + } else if (isScalar()) { + switch (operation) { + case 'a': + return rowVector.addi(getDouble(0)); + case 'p': + return rowVector.assign(getDouble(0)); + case 's': + return rowVector.subi(getDouble(0)); + case 'm': + return rowVector.muli(getDouble(0)); + case 'd': + return rowVector.divi(getDouble(0)); + case 'h': + return rowVector.rsubi(getDouble(0)); + case 't': + return rowVector.rdivi(getDouble(0)); + + } + } + + //Input validation: require 
(a) rowVector to actually be a row vector, and (b) this.size(1) to match rowVector.size(1) + if (!rowVector.isRowVector() + || this.rank() > 1 && rowVector.rank() > 1 && this.size(1) != rowVector.size(1) + || rowVector.length() <= 1) { + throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) + + ", row vector shape =" + Arrays.toString(rowVector.shape()) + ")"); + } + + if (rowVector.data().sameUnderlyingData(data())) { + return doRowWise(rowVector.dup(), operation); + } + + if (isVector()) { + switch (operation) { + case 'a': + addi(rowVector); + break; + case 'p': + assign(rowVector); + break; + case 's': + subi(rowVector); + break; + case 'm': + muli(rowVector); + break; + case 'd': + divi(rowVector); + break; + case 'h': + rsubi(rowVector); + break; + case 't': + rdivi(rowVector); + break; + } + + return this; + } + + if (rank() == 2 && columns() == 1 && rowVector.isScalar()) { + applyScalarOp(rowVector, operation); + } else { + // special optimization case, broadcast turns into ScalarOp Along Dimension + if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'f' + && rowVector.elementWiseStride() == 1) { switch (operation) { - case 'a': - addi(vector.getDouble(0)); - break; - case 's': - subi(vector.getDouble(0)); - break; - case 'm': - muli(vector.getDouble(0)); - break; - case 'd': - divi(vector.getDouble(0)); - break; - case 'h': - rsubi(vector.getDouble(0)); - break; - case 't': - rdivi(vector.getDouble(0)); - break; + case 'a': { + ScalarAdd op = new ScalarAdd(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + case 'p': { + ScalarSet op = new ScalarSet(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + case 's': { + ScalarSubtraction op = new ScalarSubtraction(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + case 'm': { + ScalarMultiplication op = new 
ScalarMultiplication(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + case 'd': { + ScalarDivision op = new ScalarDivision(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + case 'h': { + ScalarReverseSubtraction op = + new ScalarReverseSubtraction(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + case 't': { + ScalarReverseDivision op = new ScalarReverseDivision(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + + } + } else { + applyBroadcastOp(rowVector, operation); + } + } + + return this; + } + + + private void applyBroadcastOp(INDArray vector, final char operation) { + Nd4j.getCompressor().autoDecompress(this); + int alongDimension = Shape.isRowVectorShape(vector.shape()) ? 1 : 0; + + // FIXME: probably this is wrong, because strict equality is always false in current DataBuffer mechanics + if (this.data() == vector.data()) { + vector = vector.dup(); + } + switch (operation) { + case 'a': + Nd4j.getExecutioner().exec(new BroadcastAddOp(this, vector, this, alongDimension)); + return; + case 's': + Nd4j.getExecutioner().exec(new BroadcastSubOp(this, vector, this, alongDimension)); + return; + case 'm': + Nd4j.getExecutioner().exec(new BroadcastMulOp(this, vector, this, alongDimension)); + return; + case 'd': + Nd4j.getExecutioner().exec(new BroadcastDivOp(this, vector, this, alongDimension)); + return; + case 'h': + Nd4j.getExecutioner().exec(new BroadcastRSubOp(this, vector, this, alongDimension)); + return; + case 't': + Nd4j.getExecutioner().exec(new BroadcastRDivOp(this, vector, this, alongDimension)); + return; + case 'p': + Nd4j.getExecutioner().exec(new BroadcastCopyOp(this, vector, this, alongDimension)); + return; + default: + throw new UnsupportedOperationException("Unknown operation: " + operation); + } + } + + private void applyScalarOp(INDArray vector, 
char operation) { + Nd4j.getCompressor().autoDecompress(this); + switch (operation) { + case 'a': + addi(vector.getDouble(0)); + break; + case 's': + subi(vector.getDouble(0)); + break; + case 'm': + muli(vector.getDouble(0)); + break; + case 'd': + divi(vector.getDouble(0)); + break; + case 'h': + rsubi(vector.getDouble(0)); + break; + case 't': + rdivi(vector.getDouble(0)); + break; + } + } + + protected DataBuffer shapeOf() { + // if (shape == null) + // shape = Shape.shapeOf(shapeInfoDataBuffer()); + // return shape; + + return Shape.shapeOf(shapeInfoDataBuffer()); + } + + protected DataBuffer strideOf() { + // if (stride == null) + // stride = Shape.stride(shapeInfoDataBuffer()); + // return stride; + return Shape.stride(shapeInfoDataBuffer()); + } + + @Override + public int stride(int dimension) { + int rank = jvmShapeInfo.rank; + Preconditions.checkArgument(dimension < rank, + "Cannot get stride for dimension %s from rank %s array: " + + "dimension indices must be in range -rank <= dimension < rank", dimension, rank); + if (dimension < 0) { + return (int) stride()[dimension + rank]; + } + return (int) stride()[dimension]; + } + + @Override + public INDArray rdiviColumnVector(INDArray columnVector) { + validateNumericalArray("rdiviColumnVector", false); + return doColumnWise(columnVector, 't'); + } + + @Override + public INDArray rdivColumnVector(INDArray columnVector) { + validateNumericalArray("rdivColumnVector", false); + return dup().rdiviColumnVector(columnVector); + } + + @Override + public INDArray rdiviRowVector(INDArray rowVector) { + validateNumericalArray("rdiviRowVector", false); + return doRowWise(rowVector, 't'); + } + + @Override + public INDArray rdivRowVector(INDArray rowVector) { + validateNumericalArray("rdivRowVector", false); + return dup().rdiviRowVector(rowVector); + } + + @Override + public INDArray rsubiColumnVector(INDArray columnVector) { + validateNumericalArray("rsubiColumnVector", false); + return doColumnWise(columnVector, 'h'); 
+ } + + @Override + public INDArray rsubColumnVector(INDArray columnVector) { + validateNumericalArray("rsubColumnVector", false); + return dup().rsubiColumnVector(columnVector); + } + + @Override + public INDArray rsubiRowVector(INDArray rowVector) { + validateNumericalArray("rsubiRowVector", false); + return doRowWise(rowVector, 'h'); + } + + @Override + public INDArray rsubRowVector(INDArray rowVector) { + validateNumericalArray("rsubRowVector", false); + return dup().rsubiRowVector(rowVector); + } + + @Override + public INDArray put(int i, INDArray element) { + Preconditions.checkArgument(element.isScalar(), + "Element must be a scalar: element has shape %ndShape", element); + return putScalar(i, element.getDouble(0)); + } + + @Override + public INDArray diviColumnVector(INDArray columnVector) { + validateNumericalArray("diviColumnVector", false); + return doColumnWise(columnVector, 'd'); + } + + @Override + public INDArray divColumnVector(INDArray columnVector) { + validateNumericalArray("divColumnVector", false); + return dup().diviColumnVector(columnVector); + } + + @Override + public INDArray diviRowVector(INDArray rowVector) { + validateNumericalArray("diviRowVector", false); + return doRowWise(rowVector, 'd'); + } + + @Override + public INDArray divRowVector(INDArray rowVector) { + validateNumericalArray("divRowVector", false); + return dup().diviRowVector(rowVector); + } + + @Override + public INDArray muliColumnVector(INDArray columnVector) { + validateNumericalArray("muliColumnVector", false); + return doColumnWise(columnVector, 'm'); + } + + @Override + public INDArray mulColumnVector(INDArray columnVector) { + validateNumericalArray("mulColumnVector", false); + return dup().muliColumnVector(columnVector); + } + + @Override + public INDArray muliRowVector(INDArray rowVector) { + validateNumericalArray("muliRowVector", false); + return doRowWise(rowVector, 'm'); + } + + @Override + public INDArray mulRowVector(INDArray rowVector) { + 
validateNumericalArray("mulRowVector", false); + return dup().muliRowVector(rowVector); + } + + @Override + public INDArray subiColumnVector(INDArray columnVector) { + validateNumericalArray("subiColumnVector", false); + return doColumnWise(columnVector, 's'); + } + + @Override + public INDArray subColumnVector(INDArray columnVector) { + validateNumericalArray("subColumnVector", false); + return dup().subiColumnVector(columnVector); + } + + @Override + public INDArray subiRowVector(INDArray rowVector) { + validateNumericalArray("subiRowVector", false); + return doRowWise(rowVector, 's'); + } + + @Override + public INDArray subRowVector(INDArray rowVector) { + validateNumericalArray("subRowVector", false); + return dup().subiRowVector(rowVector); + } + + @Override + public INDArray addiColumnVector(INDArray columnVector) { + validateNumericalArray("addiColumnVector", false); + return doColumnWise(columnVector, 'a'); + } + + @Override + public INDArray putiColumnVector(INDArray columnVector) { + return doColumnWise(columnVector, 'p'); + } + + @Override + public INDArray addColumnVector(INDArray columnVector) { + validateNumericalArray("addColumnVector", false); + return dup().addiColumnVector(columnVector); + } + + @Override + public INDArray addiRowVector(INDArray rowVector) { + validateNumericalArray("addiRowVector", false); + return doRowWise(rowVector, 'a'); + } + + @Override + public INDArray putiRowVector(INDArray rowVector) { + validateNumericalArray("putiRowVector", false); + return doRowWise(rowVector, 'p'); + } + + @Override + public INDArray addRowVector(INDArray rowVector) { + validateNumericalArray("addRowVector", false); + return dup().addiRowVector(rowVector); + } + + @Override + public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) { + return mMulTranspose.exec(this, other, result); + } + + @Override + public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) { + return mMulTranspose.exec(this, other, null); + } + 
+ @Override + public INDArray mmul(INDArray other, char resultOrder) { + Preconditions.checkArgument(resultOrder == 'c' || resultOrder == 'f', + "Order must be either 'c' or 'f', but [" + resultOrder + "] was given"); + Preconditions.checkState(this.dataType() == other.dataType(), + "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), + other.dataType()); + // FIXME: add support for 3D+ here? + long[] shape = other.rank() == 1 ? new long[]{rows()} : new long[]{rows(), other.columns()}; + INDArray result = createUninitialized(this.dataType(), shape, resultOrder); + if (result.isScalar()) { + return Nd4j.scalar(this.dataType(), Nd4j.getBlasWrapper().dot(this, other)).reshape(1, 1); + } + return mmuli(other, result); + } + + @Override + public INDArray mmul(INDArray other) { + return mmul(other, + (this.ordering() == 'f' && other.ordering() == 'f' && other.rank() != 1) ? 'f' : 'c'); + } + + protected INDArray create(int[] shape, char ordering) { + return Nd4j.create(shape, ordering); + } + + @Override + public double[][] toDoubleMatrix() { + if (!isMatrix()) { + throw new ND4JIllegalStateException( + "Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort( + this)); + } + + if (this.size(0) > Integer.MAX_VALUE || this.size(1) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + + double[][] ret = new double[rows()][columns()]; + for (int i = 0; i < ret.length; i++) { + ret[i] = getRow(i).dup().data().asDouble(); + } + + return ret; + } + + @Override + public double[] toDoubleVector() { + if (!isVectorOrScalar()) { + throw new ND4JIllegalStateException( + "Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort( + this)); + } + return dup().data().asDouble(); + } + + @Override + public float[] toFloatVector() { + if (!isVectorOrScalar()) { + throw new ND4JIllegalStateException( + "Unable to create a 1d array from a non vector! 
Shape: " + Shape.shapeToStringShort( + this)); + } + return dup().data().asFloat(); + } + + @Override + public float[][] toFloatMatrix() { + if (!isMatrix()) { + throw new ND4JIllegalStateException( + "Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort( + this)); + } + + if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + + float[][] ret = new float[(int) rows()][(int) columns()]; + for (int i = 0; i < ret.length; i++) { + ret[i] = getRow(i).dup().data().asFloat(); + } + + return ret; + } + + @Override + public int[] toIntVector() { + if (isEmpty()) { + return new int[0]; + } + + if (!isVectorOrScalar()) { + throw new ND4JIllegalStateException( + "Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort( + this)); + } + if (isView() || elementWiseStride() != 1) { + return dup().data().asInt(); + } + return data().asInt(); + } + + @Override + public long[] toLongVector() { + if (!isVectorOrScalar()) { + throw new ND4JIllegalStateException( + "Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort( + this)); + } + if (isView() || elementWiseStride() != 1) { + return dup().data().asLong(); + } + return data().asLong(); + } + + @Override + public long[][] toLongMatrix() { + if (!isMatrix()) { + throw new ND4JIllegalStateException( + "Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort( + this)); + } + + if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + + long[][] ret = new long[(int) rows()][(int) columns()]; + for (int i = 0; i < ret.length; i++) { + ret[i] = getRow(i).dup().data().asLong(); + } + + return ret; + } + + @Override + public int[][] toIntMatrix() { + if (!isMatrix()) { + throw new ND4JIllegalStateException( + "Unable to create a 2d array from a non matrix! 
Shape: " + Shape.shapeToStringShort( + this)); + } + + if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + + int[][] ret = new int[(int) rows()][(int) columns()]; + for (int i = 0; i < ret.length; i++) { + ret[i] = getRow(i).dup().data().asInt(); + } + + return ret; + } + + /** + * Perform an copy matrix multiplication + * + * @param other the other matrix to perform matrix multiply with + * @param result the result ndarray + * @return the result of the matrix multiplication + */ + @Override + public INDArray mmul(INDArray other, INDArray result) { + return mmuli(other, result); + } + + @Override + public INDArray div(INDArray other) { + if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return divi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + } else { + return divi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + this.shape(), this.ordering())); + } + } + + @Override + public INDArray div(INDArray other, INDArray result) { + validateNumericalArray("div", true); + return divi(other, result); + } + + @Override + public INDArray mul(INDArray other) { + validateNumericalArray("mul", false); + if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return muli(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + } else { + val z = Nd4j.createUninitialized( + Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), + this.ordering()); + return muli(other, z); + } + } + + @Override + public INDArray mul(INDArray other, INDArray result) { + return muli(other, result); + } + + @Override + public INDArray sub(INDArray other) { + validateNumericalArray("sub", 
false); + if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return subi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + } else { + return subi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + this.shape(), this.ordering())); + } + } + + @Override + public INDArray sub(INDArray other, INDArray result) { + return subi(other, result); + } + + @Override + public INDArray add(INDArray other) { + validateNumericalArray("add", false); + if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return addi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + } else { + return addi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + this.shape(), this.ordering())); + } + } + + @Override + public INDArray add(INDArray other, INDArray result) { + validateNumericalArray("add", false); + return addi(other, result); + } + + @Override + public INDArray mmuli(INDArray other, MMulTranspose transpose) { + validateNumericalArray("mmuli", false); + return dup().mmuli(other, this, transpose); + } + + @Override + public INDArray mmuli(INDArray other) { + validateNumericalArray("mmuli", false); + return dup().mmuli(other, this); + } + + @Override + public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) { + return transpose.exec(this, other, result); + } + + @Override + public INDArray mmuli(INDArray other, INDArray result) { + validateNumericalArray("mmuli", false); + LinAlgExceptions.assertMultiplies(this, other); + if (other.rank() == 1) { + //GEMV edge case + Preconditions.checkState(result.length() == this.size(0) && this.size(1) == other.size(0), + "Invalid matrix multiplication: %ndShape 
x %ndShape with result shape %ndShape", this, + other, result); + } else { + //Standard case + Preconditions.checkState( + result.rank() == 2 && result.size(0) == this.size(0) && result.size(1) == other.size(1), + "Invalid result array shape: expected shape [%s,%s], got shape %ndShape result array for %ndShape x %ndShape", + this.size(0), other.size(1), result, + this, other); + } + + if (other.isScalar()) { + return muli(other.getDouble(0), result); + } + if (isScalar()) { + return other.muli(getDouble(0), result); + } + + /* check sizes and resize if necessary */ + + if (result == this || result == other) { + /* actually, blas cannot do multiplications in-place. Therefore, we will fake by + * allocating a temporary object on the side and copy the result later. + */ + INDArray temp = Nd4j.create(result.dataType(), result.shape(), + Nd4j.getStrides(result.shape(), 'f'), 'f'); + + if (other.columns() == 1 || other.rank() == 1) { + Nd4j.getBlasWrapper().level2().gemv(BlasBufferUtil.getCharForTranspose(result), + BlasBufferUtil.getCharForTranspose(this), 1.0, this, other, 0.0, temp); + } else { + Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(result), + BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(temp), 1.0, + this, other, 0.0, temp); + } + + result.assign(temp); + + + } else { + + //We require that the result array is 'f' (fortran) order + // However, user might have called mmuli with a c order array for the result + // In which case, we need to allocate a temporary f order array, and later do an assign to the real result array + + boolean requiresTemp = + result.ordering() != 'f' || result.isView() || !Shape.hasDefaultStridesForShape(result); + INDArray gemmResultArr; + if (requiresTemp) { + //Can use createUninitialized due to beta==0.0 parameter in gemm + gemmResultArr = Nd4j.createUninitialized(result.dataType(), result.shape(), 'f'); + } else { + gemmResultArr = result; + } + + if (other.columns() == 1 || 
other.rank() == 1) { + Nd4j.getBlasWrapper().level2().gemv( + ordering(), + BlasBufferUtil.getCharForTranspose(other), + 1.0, + this, + other, + 0.0, + gemmResultArr); + } else { + //gemm doesn't support strides so vectors and views + //don't work + Nd4j.getBlasWrapper().level3().gemm(ordering(), + BlasBufferUtil.getCharForTranspose(other), + BlasBufferUtil.getCharForTranspose(gemmResultArr), + 1.0, + this, + other, + 0.0, + gemmResultArr); + } + + if (requiresTemp) { + result.assign(gemmResultArr); + } + } + + // 1D edge case: reshape back to vector + if (other.rank() == 1) { + result = result.reshape(result.length()); + } + return result; + } + + private INDArray create(int[] shape, int[] stride) { + return Nd4j.create(shape, stride); + } + + @Override + public INDArray divi(INDArray other) { + return divi(other, this); + } + + @Override + public INDArray divi(INDArray other, INDArray result) { + validateNumericalArray("divi", false); + Shape.assertBroadcastable("divi", this, other, result); + Nd4j.exec(new DivOp(this, other, result)); + return result; + } + + @Override + public INDArray muli(INDArray other) { + return muli(other, this); + } + + @Override + public INDArray muli(INDArray other, INDArray result) { + validateNumericalArray("muli", false); + Shape.assertBroadcastable("muli", this, other, result); + Nd4j.exec(new MulOp(this, other, result)); + return result; + } + + @Override + public INDArray subi(INDArray other) { + return subi(other, this); + } + + /** + * in place subtraction of two matrices + * + * @param other the second ndarray to subtract + * @param result the result ndarray + * @return the result of the subtraction + */ + @Override + public INDArray subi(INDArray other, INDArray result) { + validateNumericalArray("subi", false); + Shape.assertBroadcastable("subi", this, other, result); + Nd4j.exec(new SubOp(this, other, result)); + return result; + } + + @Override + public INDArray addi(INDArray other) { + return addi(other, this); + } + + 
@Override + public INDArray addi(INDArray other, INDArray result) { + validateNumericalArray("addi", false); + Shape.assertBroadcastable("addi", this, other, result); + Nd4j.exec(new AddOp(this, other, result)); + return result; + } + + @Override + public INDArray normmax(boolean keepDims, int... dimension) { + validateNumericalArray("normmax", false); + return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension)); + } + + @Override + public INDArray normmax(int... dimension) { + return normmax(false, dimension); + } + + @Override + public INDArray rdiv(INDArray other) { + validateNumericalArray("rdiv", false); + if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return rdivi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + } else { + return rdivi(other, this.ulike()); + } + } + + @Override + public INDArray rdivi(INDArray other) { + return rdivi(other, this); + } + + @Override + public INDArray rdiv(INDArray other, INDArray result) { + validateNumericalArray("rdiv", false); + return dup().rdivi(other, result); + } + + @Override + public INDArray rdivi(INDArray other, INDArray result) { + validateNumericalArray("rdivi", false); + Shape.assertBroadcastable("rdivi", this, other, result); + Nd4j.exec(new RDivOp(this, other, result)); + return result; + } + + @Override + public INDArray rsub(INDArray other, INDArray result) { + validateNumericalArray("rsub", false); + return rsubi(other, result); + } + + @Override + public INDArray rsub(INDArray other) { + validateNumericalArray("rsub", false); + if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return rsubi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + } else { + return rsubi(other, this.ulike()); + } + } + + @Override 
+ public INDArray rsubi(INDArray other) { + return rsubi(other, this); + } + + @Override + public INDArray rsubi(INDArray other, INDArray result) { + validateNumericalArray("rsubi", false); + Shape.assertBroadcastable("rsubi", this, other, result); + Nd4j.exec(new RSubOp(this, other, result)); + return result; + } + + @Override + public INDArray assign(Number value) { + Preconditions.checkState( + dataType() != DataType.BOOL || value.doubleValue() == 0.0 || value.doubleValue() == 1.0, + "Only values 0 or 1 are allowed for scalar " + + "assign on boolean arrays: got value %s on to assign to boolean array with shape %ndShape", + value, this); + Nd4j.getExecutioner().exec(new ScalarSet(this, value)); + return this; + } + + @Override + public INDArray assign(boolean value) { + return assign(value ? 1 : 0); + } + + @Override + public INDArray assignIf(INDArray arr, Condition condition) { + BooleanIndexing.assignIf(this, arr, condition); + return this; + } + + @Override + public INDArray replaceWhere(INDArray arr, Condition condition) { + Nd4j.getCompressor().autoDecompress(this); + BooleanIndexing.replaceWhere(this, arr, condition); + return this; + } + + @Override + @Deprecated //TODO: Investigate. Not deprecated in the base interface. 
+ public long linearIndex(long i) { + long idx = i; + for (int j = 0; j < jvmShapeInfo.rank - 1; j++) { + if (size((int) i) == 1) { + continue; + } + idx += i * stride(j); + } + return Shape.offset(jvmShapeInfo.javaShapeInformation) + (idx); + } + + @Override + public INDArray slice(long slice) { + Nd4j.getCompressor().autoDecompress(this); + + long slices = slices(); + if (slice >= slices) { + throw new IllegalArgumentException("Illegal slice " + slice); + } + + if (jvmShapeInfo.rank == 0) { + throw new IllegalArgumentException("Can't slice a 0-d NDArray"); + } + + if (slice < 0) { + slice += rank(); + } + INDArrayIndex[] indexes = new INDArrayIndex[rank()]; + indexes[0] = NDArrayIndex.point(slice); + for (int i = 1; i < rank(); i++) { + indexes[i] = NDArrayIndex.all(); + } + return get(indexes); + } + + + protected INDArray createScalarForIndex(long i, boolean applyOffset) { + if (isVector()) { + return getScalar(i); + } + return Nd4j.create(data(), new long[]{1, 1}, new long[]{1, 1}, i); + } + + protected INDArray createScalar(double d) { + return Nd4j.scalar(d); + } + + @Override + public int getTrailingOnes() { + int numLeadingOnes = 0; + for (int i = rank() - 1; i > 0; i--) { + if (size(i) == 1) { + numLeadingOnes++; } } - protected DataBuffer shapeOf() { - // if (shape == null) - // shape = Shape.shapeOf(shapeInfoDataBuffer()); - // return shape; + return numLeadingOnes; + } - return Shape.shapeOf(shapeInfoDataBuffer()); - } - - protected DataBuffer strideOf() { - // if (stride == null) - // stride = Shape.stride(shapeInfoDataBuffer()); - // return stride; - return Shape.stride(shapeInfoDataBuffer()); - } - - @Override - public int stride(int dimension) { - int rank = jvmShapeInfo.rank; - Preconditions.checkArgument(dimension < rank, "Cannot get stride for dimension %s from rank %s array: " + - "dimension indices must be in range -rank <= dimension < rank", dimension, rank); - if (dimension < 0) - return (int) stride()[dimension + rank]; - return (int) 
stride()[dimension]; - } - - @Override - public INDArray rdiviColumnVector(INDArray columnVector) { - validateNumericalArray("rdiviColumnVector", false); - return doColumnWise(columnVector, 't'); - } - - @Override - public INDArray rdivColumnVector(INDArray columnVector) { - validateNumericalArray("rdivColumnVector", false); - return dup().rdiviColumnVector(columnVector); - } - - @Override - public INDArray rdiviRowVector(INDArray rowVector) { - validateNumericalArray("rdiviRowVector", false); - return doRowWise(rowVector, 't'); - } - - @Override - public INDArray rdivRowVector(INDArray rowVector) { - validateNumericalArray("rdivRowVector", false); - return dup().rdiviRowVector(rowVector); - } - - @Override - public INDArray rsubiColumnVector(INDArray columnVector) { - validateNumericalArray("rsubiColumnVector", false); - return doColumnWise(columnVector, 'h'); - } - - @Override - public INDArray rsubColumnVector(INDArray columnVector) { - validateNumericalArray("rsubColumnVector", false); - return dup().rsubiColumnVector(columnVector); - } - - @Override - public INDArray rsubiRowVector(INDArray rowVector) { - validateNumericalArray("rsubiRowVector", false); - return doRowWise(rowVector, 'h'); - } - - @Override - public INDArray rsubRowVector(INDArray rowVector) { - validateNumericalArray("rsubRowVector", false); - return dup().rsubiRowVector(rowVector); - } - - @Override - public INDArray put(int i, INDArray element) { - Preconditions.checkArgument(element.isScalar(), "Element must be a scalar: element has shape %ndShape", element); - return putScalar(i, element.getDouble(0)); - } - - @Override - public INDArray diviColumnVector(INDArray columnVector) { - validateNumericalArray("diviColumnVector", false); - return doColumnWise(columnVector, 'd'); - } - - @Override - public INDArray divColumnVector(INDArray columnVector) { - validateNumericalArray("divColumnVector", false); - return dup().diviColumnVector(columnVector); - } - - @Override - public INDArray 
diviRowVector(INDArray rowVector) { - validateNumericalArray("diviRowVector", false); - return doRowWise(rowVector, 'd'); - } - - @Override - public INDArray divRowVector(INDArray rowVector) { - validateNumericalArray("divRowVector", false); - return dup().diviRowVector(rowVector); - } - - @Override - public INDArray muliColumnVector(INDArray columnVector) { - validateNumericalArray("muliColumnVector", false); - return doColumnWise(columnVector, 'm'); - } - - @Override - public INDArray mulColumnVector(INDArray columnVector) { - validateNumericalArray("mulColumnVector", false); - return dup().muliColumnVector(columnVector); - } - - @Override - public INDArray muliRowVector(INDArray rowVector) { - validateNumericalArray("muliRowVector", false); - return doRowWise(rowVector, 'm'); - } - - @Override - public INDArray mulRowVector(INDArray rowVector) { - validateNumericalArray("mulRowVector", false); - return dup().muliRowVector(rowVector); - } - - @Override - public INDArray subiColumnVector(INDArray columnVector) { - validateNumericalArray("subiColumnVector", false); - return doColumnWise(columnVector, 's'); - } - - @Override - public INDArray subColumnVector(INDArray columnVector) { - validateNumericalArray("subColumnVector", false); - return dup().subiColumnVector(columnVector); - } - - @Override - public INDArray subiRowVector(INDArray rowVector) { - validateNumericalArray("subiRowVector", false); - return doRowWise(rowVector, 's'); - } - - @Override - public INDArray subRowVector(INDArray rowVector) { - validateNumericalArray("subRowVector", false); - return dup().subiRowVector(rowVector); - } - - @Override - public INDArray addiColumnVector(INDArray columnVector) { - validateNumericalArray("addiColumnVector", false); - return doColumnWise(columnVector, 'a'); - } - - @Override - public INDArray putiColumnVector(INDArray columnVector) { - return doColumnWise(columnVector, 'p'); - } - - @Override - public INDArray addColumnVector(INDArray columnVector) { - 
validateNumericalArray("addColumnVector", false); - return dup().addiColumnVector(columnVector); - } - - @Override - public INDArray addiRowVector(INDArray rowVector) { - validateNumericalArray("addiRowVector", false); - return doRowWise(rowVector, 'a'); - } - - @Override - public INDArray putiRowVector(INDArray rowVector) { - validateNumericalArray("putiRowVector", false); - return doRowWise(rowVector, 'p'); - } - - @Override - public INDArray addRowVector(INDArray rowVector) { - validateNumericalArray("addRowVector", false); - return dup().addiRowVector(rowVector); - } - - @Override - public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) { - return mMulTranspose.exec(this, other, result); - } - - @Override - public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) { - return mMulTranspose.exec(this, other, null); - } - - @Override - public INDArray mmul(INDArray other, char resultOrder) { - Preconditions.checkArgument(resultOrder == 'c' || resultOrder == 'f', "Order must be either 'c' or 'f', but [" + resultOrder + "] was given"); - Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType()); - // FIXME: add support for 3D+ here? - long[] shape = other.rank() == 1 ? new long[]{rows()} : new long[]{rows(), other.columns()}; - INDArray result = createUninitialized(this.dataType(), shape, resultOrder); - if (result.isScalar()) - return Nd4j.scalar(this.dataType(), Nd4j.getBlasWrapper().dot(this, other)).reshape(1, 1); - return mmuli(other, result); - } - - @Override - public INDArray mmul(INDArray other) { - return mmul(other, (this.ordering() == 'f' && other.ordering() == 'f' && other.rank() != 1) ? 
'f' : 'c'); - } - - protected INDArray create(int[] shape, char ordering) { - return Nd4j.create(shape, ordering); - } - - @Override - public double[][] toDoubleMatrix() { - if(!isMatrix()) { - throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort(this)); + @Override + public int getLeadingOnes() { + int numLeadingOnes = 0; + for (int i = 0; i < rank(); i++) { + if (size(i) == 1) { + numLeadingOnes++; } - - if (this.size(0) > Integer.MAX_VALUE || this.size(1) > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - - double[][] ret = new double[rows()][columns()]; - for(int i = 0; i < ret.length; i++) { - ret[i] = getRow(i).dup().data().asDouble(); - } - - return ret; } - @Override - public double[] toDoubleVector() { - if(!isVectorOrScalar()) { - throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); - } - return dup().data().asDouble(); - } + return numLeadingOnes; + } - @Override - public float[] toFloatVector() { - if(!isVectorOrScalar()) { - throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); - } - return dup().data().asFloat(); - } + @Override + public INDArray slice(long slice, int dimension) { + Nd4j.getCompressor().autoDecompress(this); - @Override - public float[][] toFloatMatrix() { - if(!isMatrix()) { - throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! 
Shape: " + Shape.shapeToStringShort(this)); - } + long slices = size(dimension); + if (slice >= slices) { + throw new IllegalArgumentException("Illegal slice " + slice); + } - if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - - float[][] ret = new float[(int) rows()][ (int) columns()]; - for(int i = 0; i < ret.length; i++) { - ret[i] = getRow(i).dup().data().asFloat(); - } - - return ret; - } - - @Override - public int[] toIntVector() { - if (isEmpty()) - return new int[0]; - - if(!isVectorOrScalar()) { - throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); - } - if(isView() || elementWiseStride() != 1){ - return dup().data().asInt(); - } - return data().asInt(); - } - - @Override - public long[] toLongVector() { - if(!isVectorOrScalar()) { - throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); - } - if(isView() || elementWiseStride() != 1){ - return dup().data().asLong(); - } - return data().asLong(); - } - - @Override - public long[][] toLongMatrix() { - if(!isMatrix()) { - throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort(this)); - } - - if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - - long[][] ret = new long[(int) rows()][(int) columns()]; - for(int i = 0; i < ret.length; i++) { - ret[i] = getRow(i).dup().data().asLong(); - } - - return ret; - } - - @Override - public int[][] toIntMatrix() { - if(!isMatrix()) { - throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! 
Shape: " + Shape.shapeToStringShort(this)); - } - - if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - - int[][] ret = new int[(int) rows()][(int) columns()]; - for(int i = 0; i < ret.length; i++) { - ret[i] = getRow(i).dup().data().asInt(); - } - - return ret; - } - - /** - * Perform an copy matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @param result the result ndarray - * @return the result of the matrix multiplication - */ - @Override - public INDArray mmul(INDArray other, INDArray result) { - return mmuli(other, result); - } - - @Override - public INDArray div(INDArray other) { - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return divi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + if (jvmShapeInfo.rank == 0) { + if (slice == 0) { + return createScalarForIndex(slice, true); } else { - return divi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering())); - } - } - - @Override - public INDArray div(INDArray other, INDArray result) { - validateNumericalArray("div", true); - return divi(other, result); - } - - @Override - public INDArray mul(INDArray other) { - validateNumericalArray("mul", false); - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return muli(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); - } else { - val z = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering()); - return muli(other, z); - } - } - - @Override - public INDArray mul(INDArray other, INDArray result) { - return muli(other, result); - } - - @Override - 
public INDArray sub(INDArray other) { - validateNumericalArray("sub", false); - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return subi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); - } else { - return subi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering())); - } - } - - @Override - public INDArray sub(INDArray other, INDArray result) { - return subi(other, result); - } - - @Override - public INDArray add(INDArray other) { - validateNumericalArray("add", false); - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return addi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); - } else { - return addi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering())); - } - } - - @Override - public INDArray add(INDArray other, INDArray result) { - validateNumericalArray("add", false); - return addi(other, result); - } - - @Override - public INDArray mmuli(INDArray other, MMulTranspose transpose) { - validateNumericalArray("mmuli", false); - return dup().mmuli(other, this,transpose); - } - - @Override - public INDArray mmuli(INDArray other) { - validateNumericalArray("mmuli", false); - return dup().mmuli(other, this); - } - - @Override - public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) { - return transpose.exec(this, other, result); - } - - @Override - public INDArray mmuli(INDArray other, INDArray result) { - validateNumericalArray("mmuli", false); - LinAlgExceptions.assertMultiplies(this, other); - if(other.rank() == 1){ - //GEMV edge case - Preconditions.checkState(result.length() == this.size(0) && this.size(1) == 
other.size(0), - "Invalid matrix multiplication: %ndShape x %ndShape with result shape %ndShape", this, other, result); - } else { - //Standard case - Preconditions.checkState( - result.rank() == 2 && result.size(0) == this.size(0) && result.size(1) == other.size(1), - "Invalid result array shape: expected shape [%s,%s], got shape %ndShape result array for %ndShape x %ndShape", this.size(0), other.size(1), result, - this, other); - } - - if (other.isScalar()) { - return muli(other.getDouble(0), result); - } - if (isScalar()) { - return other.muli(getDouble(0), result); - } - - /* check sizes and resize if necessary */ - - - if (result == this || result == other) { - /* actually, blas cannot do multiplications in-place. Therefore, we will fake by - * allocating a temporary object on the side and copy the result later. - */ - INDArray temp = Nd4j.create(result.dataType(), result.shape(), Nd4j.getStrides(result.shape(), 'f'), 'f'); - - if (other.columns() == 1 || other.rank() == 1) { - Nd4j.getBlasWrapper().level2().gemv(BlasBufferUtil.getCharForTranspose(result), - BlasBufferUtil.getCharForTranspose(this), 1.0, this, other, 0.0, temp); - } - - else { - Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(result), - BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(temp), 1.0, - this, other, 0.0, temp); - } - - result.assign(temp); - - - } else { - - //We require that the result array is 'f' (fortran) order - // However, user might have called mmuli with a c order array for the result - // In which case, we need to allocate a temporary f order array, and later do an assign to the real result array - - boolean requiresTemp = result.ordering() != 'f' || result.isView() || !Shape.hasDefaultStridesForShape(result); - INDArray gemmResultArr; - if (requiresTemp) { - //Can use createUninitialized due to beta==0.0 parameter in gemm - gemmResultArr = Nd4j.createUninitialized(result.dataType(), result.shape(), 'f'); - } else { - 
gemmResultArr = result; - } - - if (other.columns() == 1 || other.rank() == 1) { - Nd4j.getBlasWrapper().level2().gemv( - ordering(), - BlasBufferUtil.getCharForTranspose(other), - 1.0, - this, - other, - 0.0, - gemmResultArr); - } else { - //gemm doesn't support strides so vectors and views - //don't work - Nd4j.getBlasWrapper().level3().gemm(ordering(), - BlasBufferUtil.getCharForTranspose(other), - BlasBufferUtil.getCharForTranspose(gemmResultArr), - 1.0, - this, - other, - 0.0, - gemmResultArr); - } - - if (requiresTemp) { - result.assign(gemmResultArr); - } - } - - // 1D edge case: reshape back to vector - if (other.rank() == 1) - result = result.reshape(result.length()); - return result; - } - - private INDArray create(int[] shape, int[] stride) { - return Nd4j.create(shape, stride); - } - - @Override - public INDArray divi(INDArray other) { - return divi(other, this); - } - - @Override - public INDArray divi(INDArray other, INDArray result) { - validateNumericalArray("divi", false); - Shape.assertBroadcastable("divi", this, other, result); - Nd4j.exec(new DivOp(this, other, result)); - return result; - } - - @Override - public INDArray muli(INDArray other) { - return muli(other, this); - } - - @Override - public INDArray muli(INDArray other, INDArray result) { - validateNumericalArray("muli", false); - Shape.assertBroadcastable("muli", this, other, result); - Nd4j.exec(new MulOp(this, other, result)); - return result; - } - - @Override - public INDArray subi(INDArray other) { - return subi(other, this); - } - - /** - * in place subtraction of two matrices - * - * @param other the second ndarray to subtract - * @param result the result ndarray - * @return the result of the subtraction - */ - @Override - public INDArray subi(INDArray other, INDArray result) { - validateNumericalArray("subi", false); - Shape.assertBroadcastable("subi", this, other, result); - Nd4j.exec(new SubOp(this, other, result)); - return result; - } - - @Override - public INDArray 
addi(INDArray other) { - return addi(other, this); - } - - @Override - public INDArray addi(INDArray other, INDArray result) { - validateNumericalArray("addi", false); - Shape.assertBroadcastable("addi", this, other, result); - Nd4j.exec(new AddOp(this, other, result)); - return result; - } - - @Override - public INDArray normmax(boolean keepDims, int... dimension) { - validateNumericalArray("normmax", false); - return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension)); - } - - @Override - public INDArray normmax(int... dimension) { - return normmax(false, dimension); - } - - @Override - public INDArray rdiv(INDArray other) { - validateNumericalArray("rdiv", false); - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return rdivi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); - } else { - return rdivi(other, this.ulike()); - } - } - - @Override - public INDArray rdivi(INDArray other) { - return rdivi(other, this); - } - - @Override - public INDArray rdiv(INDArray other, INDArray result) { - validateNumericalArray("rdiv", false); - return dup().rdivi(other, result); - } - - @Override - public INDArray rdivi(INDArray other, INDArray result) { - validateNumericalArray("rdivi", false); - Shape.assertBroadcastable("rdivi", this, other, result); - Nd4j.exec(new RDivOp(this, other, result)); - return result; - } - - @Override - public INDArray rsub(INDArray other, INDArray result) { - validateNumericalArray("rsub", false); - return rsubi(other, result); - } - - @Override - public INDArray rsub(INDArray other) { - validateNumericalArray("rsub", false); - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return rsubi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); - } else { - 
return rsubi(other, this.ulike()); - } - } - - @Override - public INDArray rsubi(INDArray other) { - return rsubi(other, this); - } - - @Override - public INDArray rsubi(INDArray other, INDArray result) { - validateNumericalArray("rsubi", false); - Shape.assertBroadcastable("rsubi", this, other, result); - Nd4j.exec(new RSubOp(this, other, result)); - return result; - } - - @Override - public INDArray assign(Number value) { - Preconditions.checkState(dataType() != DataType.BOOL || value.doubleValue() == 0.0 || value.doubleValue() == 1.0, "Only values 0 or 1 are allowed for scalar " + - "assign on boolean arrays: got value %s on to assign to boolean array with shape %ndShape", value, this); - Nd4j.getExecutioner().exec(new ScalarSet(this, value)); - return this; - } - - @Override - public INDArray assign(boolean value) { - return assign(value ? 1 : 0); - } - - @Override - public INDArray assignIf(INDArray arr, Condition condition) { - BooleanIndexing.assignIf(this, arr, condition); - return this; - } - - @Override - public INDArray replaceWhere(INDArray arr, Condition condition) { - Nd4j.getCompressor().autoDecompress(this); - BooleanIndexing.replaceWhere(this, arr, condition); - return this; - } - - @Override - @Deprecated //TODO: Investigate. Not deprecated in the base interface. 
- public long linearIndex(long i) { - long idx = i; - for (int j = 0; j < jvmShapeInfo.rank - 1; j++) { - if (size((int) i) == 1) - continue; - idx += i * stride(j); - } - return Shape.offset(jvmShapeInfo.javaShapeInformation) + (idx); - } - - @Override - public INDArray slice(long slice) { - Nd4j.getCompressor().autoDecompress(this); - - - long slices = slices(); - if (slice >= slices) - throw new IllegalArgumentException("Illegal slice " + slice); - - if (jvmShapeInfo.rank == 0 ) { throw new IllegalArgumentException("Can't slice a 0-d NDArray"); } + } - if (slice < 0) - slice += rank(); - INDArrayIndex[] indexes = new INDArrayIndex[rank()]; - indexes[0] = NDArrayIndex.point(slice); - for (int i = 1; i < rank(); i++) { + if (slice < 0) { + slice += rank(); + } + INDArrayIndex[] indexes = new INDArrayIndex[rank()]; + indexes[dimension] = NDArrayIndex.point(slice); + for (int i = 0; i < rank(); i++) { + if (i != dimension) { indexes[i] = NDArrayIndex.all(); } - return get(indexes); } + return get(indexes); + } + @Override + public INDArray getScalar(int[] indexes) { + if (indexes.length > rank()) { + throw new ND4JIllegalStateException("Indexes can't be longer then array rank"); + } - protected INDArray createScalarForIndex(long i, boolean applyOffset) { - if(isVector()) - return getScalar(i); - return Nd4j.create(data(), new long[] {1, 1}, new long[] {1, 1}, i); - } - - protected INDArray createScalar(double d) { - return Nd4j.scalar(d); - } - - @Override - public int getTrailingOnes() { - int numLeadingOnes = 0; - for (int i = rank() - 1; i > 0; i--) { - if (size(i) == 1) - numLeadingOnes++; + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] < 0) { + indexes[i] += this.size(i); } - - return numLeadingOnes; } + long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); + val buffer = Nd4j.createBuffer(this.data(), idx, 1); + val shape = Nd4j.getShapeInfoProvider() + .createShapeInformation(new long[0], new long[0], 1, 'c', this.dataType(), 
false); + return Nd4j.createArrayFromShapeBuffer(buffer, shape); + } - @Override - public int getLeadingOnes() { - int numLeadingOnes = 0; - for (int i = 0; i < rank(); i++) { - if (size(i) == 1) - numLeadingOnes++; - } + @Override + public INDArray getScalar(long... indexes) { + if (indexes.length > rank()) { + throw new ND4JIllegalStateException("Indexes can't be longer then array rank"); + } - return numLeadingOnes; - } - - @Override - public INDArray slice(long slice, int dimension) { - Nd4j.getCompressor().autoDecompress(this); - - long slices = size(dimension); - if (slice >= slices) - throw new IllegalArgumentException("Illegal slice " + slice); - - if (jvmShapeInfo.rank == 0) { - if (slice == 0) - return createScalarForIndex(slice, true); - else - throw new IllegalArgumentException("Can't slice a 0-d NDArray"); - - } - - - if (slice < 0) - slice += rank(); - INDArrayIndex[] indexes = new INDArrayIndex[rank()]; - indexes[dimension] = NDArrayIndex.point(slice); - for (int i = 0; i < rank(); i++) { - if (i != dimension) - indexes[i] = NDArrayIndex.all(); - } - return get(indexes); - - } - - @Override - public INDArray getScalar(int[] indexes) { - if (indexes.length > rank()) - throw new ND4JIllegalStateException("Indexes can't be longer then array rank"); - - for (int i = 0; i < indexes.length; i++) { - if (indexes[i] < 0) - indexes[i] += this.size(i); - } - long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); - val buffer = Nd4j.createBuffer(this.data(), idx, 1); - val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],1, 'c', this.dataType(), false); - return Nd4j.createArrayFromShapeBuffer(buffer, shape); - } - - @Override - public INDArray getScalar(long... 
indexes) { - if (indexes.length > rank()) - throw new ND4JIllegalStateException("Indexes can't be longer then array rank"); - - for (int i = 0; i < indexes.length; i++) { - if (indexes[i] < 0) - indexes[i] += this.size(i); - } - - long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); - val buffer = Nd4j.createBuffer(this.data(), idx, 1); - val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],1,'c', this.dataType(), false); - return Nd4j.createArrayFromShapeBuffer(buffer, shape); - } - - @Override - public INDArray rdiv(Number n) { - //return dup().rdivi(n); - return rdivi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), this.ordering())); - } - - @Override - public INDArray rdivi(Number n) { - return rdivi(n, this); - } - - @Override - public INDArray rsub(Number n) { - validateNumericalArray("rsub", false); - return rsubi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n),this.shape(), this.ordering())); - } - - @Override - public INDArray rsubi(Number n) { - return rsubi(n, this); - } - - @Override - public INDArray div(Number n) { - validateNumericalArray("div", false); - return divi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n),this.shape(), this.ordering())); - } - - @Override - public INDArray divi(Number n) { - return divi(n, this); - } - - @Override - public INDArray mul(Number n) { - validateNumericalArray("mul", false); - return muli(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), this.ordering())); - } - - @Override - public INDArray muli(Number n) { - return muli(n, this); - } - - @Override - public INDArray sub(Number n) { - validateNumericalArray("sub", false); - return subi(n, Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering())); - } - - @Override - public INDArray subi(Number n) { - return subi(n, this); - } - - @Override - public INDArray 
add(Number n) { - validateNumericalArray("add", false); - return addi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n),this.shape(), this.ordering())); - } - - @Override - public INDArray addi(Number n) { - return addi(n, this); - } - - @Override - public INDArray repmat(long[] shape) { - Nd4j.getCompressor().autoDecompress(this); - long rows = rows() * shape[0]; - long cols = columns() * shape[1]; - INDArray ret = reshape(1, length()).repeat(0, shape[0]).reshape(rows, columns()).repeat(0, shape[1]); - return ret.reshape(rows, cols); - } - - @Deprecated - @Override - public INDArray repmat(int[] shape) { - long[] longShape = ArrayUtil.toLongArray(shape); - return repmat(longShape); - } - - @Override - public INDArray repeat(int dimension, long... repeats) { - Nd4j.getCompressor().autoDecompress(this); - CustomOp op = DynamicCustomOp.builder("repeat") - .addInputs(this) - .addIntegerArguments(ArrayUtil.toInts(repeats)) //TODO int cast - .build(); - op.addIArgument(dimension); //Native op: last iarg is dimension - - LongShapeDescriptor l = op.calculateOutputShape().get(0); - INDArray out = Nd4j.create(l); - op.addOutputArgument(out); - Nd4j.exec(op); - return out; - } - - @Override - public INDArray putRow(long row, INDArray toPut) { - if (isRowVector() && toPut.isVector()) { - return assign(toPut); - } - return put(new INDArrayIndex[] {NDArrayIndex.point(row), NDArrayIndex.all()}, toPut); - } - - @Override - public INDArray putColumn(int column, INDArray toPut) { - Nd4j.getCompressor().autoDecompress(this); - - if (isColumnVector() && toPut.isVector()) { - return assign(toPut); - } - return put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(column)}, toPut); - } - - @Override - public Number getNumber(long i){ - switch (dataType()){ - case DOUBLE: - case FLOAT: - case HALF: - case BFLOAT16: - return getDouble(i); - case LONG: - case INT: - case SHORT: - case UBYTE: - case BYTE: - case BOOL: - case UINT64: - case UINT32: - case 
UINT16: - return getLong(i); - case UTF8: - case COMPRESSED: - case UNKNOWN: - default: - throw new UnsupportedOperationException("Cannot get number from array of datatype: " + dataType()); + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] < 0) { + indexes[i] += this.size(i); } } - @Override - public Number getNumber(long... idx){ - switch (dataType()){ - case DOUBLE: - case FLOAT: - case HALF: - return getDouble(idx); - case LONG: - case INT: - case SHORT: - case UBYTE: - case BYTE: - case BOOL: - return getLong(idx); - case UTF8: - case COMPRESSED: - case UNKNOWN: - default: - throw new UnsupportedOperationException("Cannot get number from array of datatype: " + dataType()); - } + long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); + val buffer = Nd4j.createBuffer(this.data(), idx, 1); + val shape = Nd4j.getShapeInfoProvider() + .createShapeInformation(new long[0], new long[0], 1, 'c', this.dataType(), false); + return Nd4j.createArrayFromShapeBuffer(buffer, shape); + } + + @Override + public INDArray rdiv(Number n) { + //return dup().rdivi(n); + return rdivi(n, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), + this.ordering())); + } + + @Override + public INDArray rdivi(Number n) { + return rdivi(n, this); + } + + @Override + public INDArray rsub(Number n) { + validateNumericalArray("rsub", false); + return rsubi(n, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), + this.ordering())); + } + + @Override + public INDArray rsubi(Number n) { + return rsubi(n, this); + } + + @Override + public INDArray div(Number n) { + validateNumericalArray("div", false); + return divi(n, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), + this.ordering())); + } + + @Override + public INDArray divi(Number n) { + return divi(n, this); + } + + @Override + public INDArray mul(Number n) { + validateNumericalArray("mul", false); + return 
muli(n, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), + this.ordering())); + } + + @Override + public INDArray muli(Number n) { + return muli(n, this); + } + + @Override + public INDArray sub(Number n) { + validateNumericalArray("sub", false); + return subi(n, Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering())); + } + + @Override + public INDArray subi(Number n) { + return subi(n, this); + } + + @Override + public INDArray add(Number n) { + validateNumericalArray("add", false); + return addi(n, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), + this.ordering())); + } + + @Override + public INDArray addi(Number n) { + return addi(n, this); + } + + @Override + public INDArray repmat(long[] shape) { + Nd4j.getCompressor().autoDecompress(this); + long rows = rows() * shape[0]; + long cols = columns() * shape[1]; + INDArray ret = reshape(1, length()).repeat(0, shape[0]).reshape(rows, columns()) + .repeat(0, shape[1]); + return ret.reshape(rows, cols); + } + + @Deprecated + @Override + public INDArray repmat(int[] shape) { + long[] longShape = ArrayUtil.toLongArray(shape); + return repmat(longShape); + } + + @Override + public INDArray repeat(int dimension, long... 
repeats) { + Nd4j.getCompressor().autoDecompress(this); + CustomOp op = DynamicCustomOp.builder("repeat") + .addInputs(this) + .addIntegerArguments(ArrayUtil.toInts(repeats)) //TODO int cast + .build(); + op.addIArgument(dimension); //Native op: last iarg is dimension + + LongShapeDescriptor l = op.calculateOutputShape().get(0); + INDArray out = Nd4j.create(l); + op.addOutputArgument(out); + Nd4j.exec(op); + return out; + } + + @Override + public INDArray putRow(long row, INDArray toPut) { + if (isRowVector() && toPut.isVector()) { + return assign(toPut); + } + return put(new INDArrayIndex[]{NDArrayIndex.point(row), NDArrayIndex.all()}, toPut); + } + + @Override + public INDArray putColumn(int column, INDArray toPut) { + Nd4j.getCompressor().autoDecompress(this); + + if (isColumnVector() && toPut.isVector()) { + return assign(toPut); + } + return put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point(column)}, toPut); + } + + @Override + public Number getNumber(long i) { + switch (dataType()) { + case DOUBLE: + case FLOAT: + case HALF: + case BFLOAT16: + return getDouble(i); + case LONG: + case INT: + case SHORT: + case UBYTE: + case BYTE: + case BOOL: + case UINT64: + case UINT32: + case UINT16: + return getLong(i); + case UTF8: + case COMPRESSED: + case UNKNOWN: + default: + throw new UnsupportedOperationException( + "Cannot get number from array of datatype: " + dataType()); + } + } + + @Override + public Number getNumber(long... 
idx) { + switch (dataType()) { + case DOUBLE: + case FLOAT: + case HALF: + return getDouble(idx); + case LONG: + case INT: + case SHORT: + case UBYTE: + case BYTE: + case BOOL: + return getLong(idx); + case UTF8: + case COMPRESSED: + case UNKNOWN: + default: + throw new UnsupportedOperationException( + "Cannot get number from array of datatype: " + dataType()); + } + } + + @Override + public double getDouble(long i) { + Nd4j.getCompressor().autoDecompress(this); + Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); + + if (i >= length()) { + throw new IllegalArgumentException( + "Unable to get linear index " + i + ": values is greater than length (" + length() + ")"); } - @Override - public double getDouble(long i) { - Nd4j.getCompressor().autoDecompress(this); - Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); + autoProcessScalarCall(); - if (i >= length()) { - throw new IllegalArgumentException("Unable to get linear index " + i + ": values is greater than length (" + length() + ")"); - } + if (i == 0) { + return data().getDouble(i); + } - autoProcessScalarCall(); + long[] dimensions = ordering() == 'c' ? Shape.ind2subC(this, i) : Shape.ind2sub(this, i); + Shape.assertShapeLessThan(dimensions, shape()); + return getDouble(dimensions); - if (i == 0) - return data().getDouble(i); + } - long[] dimensions = ordering() == 'c' ? 
Shape.ind2subC(this, i) : Shape.ind2sub(this, i); - Shape.assertShapeLessThan(dimensions, shape()); - return getDouble(dimensions); + @Override + public double getDouble(long i, long j) { + return getDouble(new long[]{i, j}); + } + @Override + public float getFloat(long i) { + return (float) getDouble(i); + } + + @Override + public float getFloat(long i, long j) { + return (float) getDouble(i, j); + } + + @Override + public INDArray transpose() { + Preconditions.checkState(rank() >= 2, + "Can't transpose array with rank < 2: array shape %ndShape", this); + + return permute(ArrayUtil.reverseCopy(ArrayUtil.range(0, rank()))); + } + + /** + * Return transposed version of this matrix. + *

      + * PLEASE NOTE: This method is NOT in place, it will return transposed copy instead. + */ + @Override + public INDArray transposei() { + Preconditions.checkState(rank() >= 2, + "Can't transpose array with rank < 2: array shape %ndShape", this); + + return permutei(ArrayUtil.reverseCopy(ArrayUtil.range(0, rank()))); + } + + protected INDArray create(DataBuffer data, int[] shape, int[] strides) { + return Nd4j.create(data, shape, strides, 0, ordering()); + } + + @Deprecated + @Override + public INDArray reshape(char order, int... newShape) { + return reshape(order, ArrayUtil.toLongArray(newShape)); + } + + @Override + public INDArray reshape(char order, long... newShape) { + return reshape(order, false, newShape); + } + + @Override + public INDArray reshape(char order, boolean enforceView, long... newShape) { + Nd4j.getCompressor().autoDecompress(this); + + // special case for empty reshape + if (this.length() == 1 && (newShape == null || newShape.length == 0) + && this.elementWiseStride() == 1) { + return Nd4j.create(this.data(), new int[0], new int[0], 0); } - @Override - public double getDouble(long i, long j) { - return getDouble(new long[] {i, j}); - } - - @Override - public float getFloat(long i) { - return (float) getDouble(i); - } - - @Override - public float getFloat(long i, long j) { - return (float) getDouble(i, j); - } - - @Override - public INDArray transpose() { - Preconditions.checkState(rank() >= 2, "Can't transpose array with rank < 2: array shape %ndShape", this); - - return permute(ArrayUtil.reverseCopy(ArrayUtil.range(0, rank()))); - } - - /** - * - * Return transposed version of this matrix. - * - * PLEASE NOTE: This method is NOT in place, it will return transposed copy instead. 
- */ - @Override - public INDArray transposei() { - Preconditions.checkState(rank() >= 2, "Can't transpose array with rank < 2: array shape %ndShape", this); - - return permutei(ArrayUtil.reverseCopy(ArrayUtil.range(0, rank()))); - } - - protected INDArray create(DataBuffer data, int[] shape, int[] strides) { - return Nd4j.create(data, shape, strides, 0, ordering()); - } - - @Deprecated - @Override - public INDArray reshape(char order, int... newShape) { - return reshape(order, ArrayUtil.toLongArray(newShape)); - } - - @Override - public INDArray reshape(char order, long... newShape) { - return reshape(order, false, newShape); - } - - @Override - public INDArray reshape(char order, boolean enforceView, long... newShape){ - Nd4j.getCompressor().autoDecompress(this); - - // special case for empty reshape - if (this.length() == 1 && (newShape == null || newShape.length == 0) && this.elementWiseStride() == 1) { - return Nd4j.create(this.data(), new int[0], new int[0], 0); - } - - if (newShape == null || newShape.length < 1) - throw new ND4JIllegalStateException( - "Can't reshape(long...) without shape arguments. Got empty shape instead."); - - // TODO: maybe toFlatten() makes more sense here? - // reshape(-1) special case - if (newShape.length == 1 && newShape[0] == -1) - newShape[0] = this.length(); - - int numberNegativesOnes = 0; - long[] shape = ArrayUtil.copy(newShape); - - - for (int i = 0; i < shape.length; i++) { - if (shape[i] < 0) { - if (numberNegativesOnes >= 1) - throw new IllegalArgumentException("Only one dimension can be negative ones. 
Got shape " - + Arrays.toString(newShape)); - - numberNegativesOnes++; - - int shapeLength = 1; - for (int j = 0; j < shape.length; j++) - if (shape[j] >= 1) - shapeLength *= shape[j]; - long realShape = Math.abs(length() / shapeLength); - long[] thisNewShape = new long[shape.length]; - for (int j = 0; j < shape.length; j++) { - if (i != j) { - thisNewShape[j] = shape[j]; - } else - thisNewShape[j] = realShape; - } - - shape = thisNewShape; - break; - - } - } - - long prod = ArrayUtil.prodLong(shape); - - if (prod != this.length()) { - throw new ND4JIllegalStateException("New shape length doesn't match original length: [" + prod + "] vs [" + this.length() + "]. Original shape: "+Arrays.toString(this.shape())+" New Shape: "+Arrays.toString(newShape)); - } - - - - - - INDArray reshapeAttempt = Shape.newShapeNoCopy(this, shape, order == 'f'); - if (reshapeAttempt != null) { - // kinda strange get/set usage - // reshapeAttempt.setOrder(Shape.getOrder(reshapeAttempt)); - return reshapeAttempt; - } - - if(enforceView){ - throw new ND4JIllegalStateException("Unable to reshape array as view, called with enforceView=true. " + - "Use enforceView=false to return a copy instead, or call reshape on a non-strided array. 
Array shape info: " + this.shapeInfoToString().replaceAll("\n","")); - } - - - if (order != ordering()) { - INDArray ret = Nd4j.createUninitialized(this.dataType(), shape, order); - ret.setData(dup(order).data()); - return ret; - } else if (this.isEmpty()) { - return Nd4j.create(this.dataType(), shape); - } else { - INDArray ret = this.dup(order); - return Nd4j.create(ret.data(), shape); - } - } - - @Override - public double getDoubleUnsafe(long offset) { - return data().getDouble(offset); - } - - @Override - public INDArray putScalarUnsafe(long offset, double value) { - autoProcessScalarCall(); - data().put(offset, value); - return this; - } - - @Override - public INDArray reshape(char order, int rows, int columns) { - return reshape(order, new long[] {rows, columns}); - } - - /** - * Reshape the ndarray in to the specified dimensions, - * possible errors being thrown for invalid shapes - * - * Note here that one dimension can be -1. - * The dimension that is -1 will be inferred from the shape and - * the length of the ndarray - * - * @param shape the shape of the ndarray. - * @return the new reshaped nd array - */ - - @Override - public INDArray reshape(int[] shape) { - return reshape(Nd4j.order(), shape); - } - - @Override - public INDArray reshape(long... shape) { - return reshape(Nd4j.order(), shape); - } - - @Override - public INDArray prod(boolean keepDims, int... dimension) { - validateNumericalArray("prod", false); - return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension)); - } - - @Override - public INDArray prod(int... dimension) { - return prod(false, dimension); - } - - @Override - public INDArray mean(boolean keepDims, int... dimension) { - validateNumericalArray("mean", false); - return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension)); - } - - @Override - public INDArray mean(int... dimension) { - return mean(false, dimension); - } - - @Override - public INDArray amean(int... 
dimension) { - validateNumericalArray("amean", false); - return Nd4j.getExecutioner().exec(new AMean(this, dimension)); - } - - @Override - public INDArray mean(@NonNull INDArray result, boolean keepDims, int... dimension) { - validateNumericalArray("mean", false); - return Nd4j.getExecutioner().exec(new Mean(this, result, keepDims, dimension)); - } - - @Override - public INDArray mean(@NonNull INDArray result, int... dimension) { - return mean(result, false, dimension); - } - - @Override - public INDArray var(int... dimension) { - validateNumericalArray("var", false); - return Nd4j.getExecutioner().exec(new Variance(this, dimension)); - } - - @Override - public INDArray var(boolean biasCorrected, int... dimension) { - validateNumericalArray("var", false); - return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension)); - } - - @Override - public INDArray max(boolean keepDims, int... dimension) { - validateNumericalArray("max", false); - return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension)); - } - - @Override - public INDArray max(int... dimension) { - return max(false, dimension); - } - - @Override - public INDArray amax(int... dimension) { - validateNumericalArray("amax", false); - return Nd4j.getExecutioner().exec(new AMax(this, dimension)); - } - - @Override - public INDArray min(boolean keepDims, int... dimension) { - validateNumericalArray("min", false); - return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension)); - } - - @Override - public INDArray min(int... dimension) { - return min(false, dimension); - } - - @Override - public INDArray amin(int... dimension) { - validateNumericalArray("amin", false); - return Nd4j.getExecutioner().exec(new AMin(this, dimension)); - } - - @Override - public INDArray sum(int... dimension) { - validateNumericalArray("sum", true); - return Nd4j.getExecutioner().exec(new Sum(this, dimension)); - } - - @Override - public INDArray sum(boolean keepDim, int... 
dimension) { - validateNumericalArray("sum", true); - return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension)); - } - - @Override - public INDArray entropy(int... dimension) { - validateNumericalArray("entropy", false); - return Nd4j.getExecutioner().exec(new Entropy(this, dimension)); - } - - @Override - public INDArray shannonEntropy(int... dimension) { - validateNumericalArray("shannonEntropy", false); - return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension)); - } - - @Override - public INDArray logEntropy(int... dimension) { - validateNumericalArray("logEntropy", false); - return Nd4j.getExecutioner().exec(new LogEntropy(this, dimension)); - } - - @Override - public INDArray sum(@NonNull INDArray result, boolean keepDims, int... dimension) { - validateNumericalArray("sum", true); - return Nd4j.getExecutioner().exec(new Sum(this, result, keepDims, dimension)); - } - - @Override - public INDArray sum(@NonNull INDArray result, int... dimension) { - return sum(result, false, dimension); - } - - @Override - public INDArray norm1(int... dimension) { - return norm1(false, dimension); - } - - @Override - public INDArray norm1(boolean keepDims, int... dimension) { - validateNumericalArray("norm1", false); - return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension)); - } - - @Override - public INDArray std(int... dimension) { - return std(true, dimension); - } - - @Override - public INDArray std(boolean biasCorrected, int... dimension) { - return std(biasCorrected, false, dimension); - } - - @Override - public INDArray std(boolean biasCorrected, boolean keepDims, int... 
dimension) { - validateNumericalArray("std", false); - return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected, keepDims, dimension)); - } - - @Override - public Number stdNumber(boolean biasCorrected) { - validateNumericalArray("stdNumber", false); - return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0); - } - - @Override - public INDArray norm2(boolean keepDims, int... dimension) { - validateNumericalArray("norm2", false); - return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension)); - } - - @Override - public INDArray norm2(int... dimension) { - return norm2(false, dimension); - } - - @Override - public int columns() { - if (isMatrix()) - return (int) size(1); - else if (Shape.isColumnVectorShape(shape())) { - return 1; - } else if (Shape.isRowVectorShape(shape())) { - return (int) length(); - } - throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid"); - - - } - - @Override - public int rows() { - if (isMatrix()) - return (int) size(0); - else if (Shape.isRowVectorShape(shape())) { - return 1; - } else if (Shape.isColumnVectorShape(shape())) { - return (int) length(); - } - - throw new IllegalStateException("Rank is " + rank() + " rows() call is not valid"); - } - - @Override - public INDArray ravel(char ordering) { - Nd4j.getCompressor().autoDecompress(this); - if(ordering == this.ordering() && Shape.hasDefaultStridesForShape(this)){ - return reshape(ordering, length()); - } - return dup(ordering).reshape(ordering, length()); - } - - @Override - public INDArray ravel() { - return reshape(length()); - } - - @Override - public void sliceVectors(List list) { - if (isVector()) - list.add(this); - else { - for (int i = 0; i < slices(); i++) { - slice(i).sliceVectors(list); - } - } - } - - @Override - public INDArray reshape(long newRows, long newColumns) { - return reshape(new long[] {newRows, newColumns}); - } - - @Override - public INDArray getColumn(long c) { - 
Nd4j.getCompressor().autoDecompress(this); - - if (isColumnVector() && c == 0) - return this; - else if (isColumnVector() && c > 0) - throw new IllegalArgumentException("Illegal index for column"); - Preconditions.checkArgument(this.rank() == 2, "getColumn() can be called on 2D arrays only"); - return tensorAlongDimension(c, 0); - } - - @Override - public INDArray getColumn(long c, boolean keepDim) { - INDArray col = getColumn(c); - if(!keepDim) - return col; - return col.reshape(col.length(), 1); - } - - @Override - public INDArray getRows(int[] rindices) { - Nd4j.getCompressor().autoDecompress(this); - - if (!isMatrix() && !isVector()) - throw new IllegalArgumentException("Unable to get columns from a non matrix or vector"); - if (isVector()) - return Nd4j.pullRows(this, 1, rindices); - else { - INDArray ret = Nd4j.createUninitialized(this.dataType(), rindices.length, columns()); - for (int i = 0; i < rindices.length; i++) - ret.putRow(i, getRow(rindices[i])); - return ret; - } - } - - @Override - public INDArray get(INDArrayIndex... indexes) { - Nd4j.getCompressor().autoDecompress(this); - - int numPoint = 0; - int numInterval = 0; - int numAll = 0; - int numNewAxis = 0; - int numSpecified = 0; - for(INDArrayIndex i : indexes){ - if(i instanceof PointIndex){ - numPoint++; - } else if(i instanceof NDArrayIndexAll){ - numAll++; - } else if(i instanceof IntervalIndex){ - numInterval++; - } else if(i instanceof NewAxis){ - numNewAxis++; - } else if(i instanceof SpecifiedIndex){ - numSpecified++; + if (newShape == null || newShape.length < 1) { + throw new ND4JIllegalStateException( + "Can't reshape(long...) without shape arguments. Got empty shape instead."); + } + + // TODO: maybe toFlatten() makes more sense here? 
+ // reshape(-1) special case + if (newShape.length == 1 && newShape[0] == -1) { + newShape[0] = this.length(); + } + + int numberNegativesOnes = 0; + long[] shape = ArrayUtil.copy(newShape); + + for (int i = 0; i < shape.length; i++) { + if (shape[i] < 0) { + if (numberNegativesOnes >= 1) { + throw new IllegalArgumentException( + "Only one dimension can be negative ones. Got shape " + + Arrays.toString(newShape)); + } + + numberNegativesOnes++; + + int shapeLength = 1; + for (int j = 0; j < shape.length; j++) { + if (shape[j] >= 1) { + shapeLength *= shape[j]; + } + } + long realShape = Math.abs(length() / shapeLength); + long[] thisNewShape = new long[shape.length]; + for (int j = 0; j < shape.length; j++) { + if (i != j) { + thisNewShape[j] = shape[j]; } else { - throw new IllegalStateException("Unknown index: " + i); + thisNewShape[j] = realShape; } } - // Padding remaining dimensions with all() index if too few indices provided - if (indexes.length - numNewAxis < this.rank()) { - val newIndexes = new INDArrayIndex[this.rank() + numNewAxis]; - System.arraycopy(indexes, 0, newIndexes, 0, indexes.length); + shape = thisNewShape; + break; - for (int e = indexes.length; e < newIndexes.length; e++) { - numAll++; - newIndexes[e] = NDArrayIndex.all(); - } - - indexes = newIndexes; - } - - Preconditions.checkState((numPoint + numInterval + numAll + numSpecified) == rank(), "Illegal set of indices for array: need at least" + - " %s point/interval/all/specified indices for rank %s array (%ndShape), got indices %s", rank(), rank(), this, indexes); - - int outRank = rank() + numNewAxis - numPoint; - Preconditions.checkState(outRank >= 0, "Illegal set of indices for array: %ndShape, %s", this, indexes); - - - //To work out sub-array, we need to work out 3 things: offset, shape and strides. 
We calculate all of these - long[] outShape = new long[outRank]; - long[] outStrides = new long[outRank]; - long offset = offset(); //Start with existing offset if view - - int outIdx = 0; //Axis number counter for output array - int inIdx = 0; //Axis number counter for input array - for( int i=0; i= size(inIdx)) { - throw new IllegalStateException("Indices are out of range: Cannot get interval index " + indexes[i] + - " on array with size(" + inIdx + ")=" + size(inIdx) + ". Array shape: " + Arrays.toString(shape()) + - ", indices: " + Arrays.toString(indexes)); - } - long stride = ii.stride(); - long length = (endInc - start)/stride + 1; - - offset += ii.offset() * stride(inIdx); - outShape[outIdx] = length; - outStrides[outIdx] = ii.stride() * stride(inIdx); - inIdx++; - outIdx++; - } else if(indexes[i] instanceof NewAxis) { - //New axis: appends a 1 in shape. Axis not present in input, but is present in output - outShape[outIdx] = 1; - if (outIdx > 0) { //Stride doesn't matter for 1 size axis anyway... - outStrides[outIdx] = outStrides[outIdx - 1]; - } else { - outStrides[outIdx] = 1; - } - outIdx++; - } else if(indexes[i] instanceof SpecifiedIndex){ - //Specified index: axis present in both input and output - SpecifiedIndex si = (SpecifiedIndex)indexes[i]; - outShape[outIdx++] = si.length(); - inIdx++; - //Don't care about strides for specified index, as result won't be a view - } else { - throw new IllegalStateException("Unknown index type: " + i); //Should never happen - } - } - - - //Note: If we have specified indices, we can't return a view. Instead, we copy the specified sub-arrays from - // the input array to the output array. - //How? 
Create the output array, then do loop over the specified indices only, and copy sub-arrays for all other axes - if (numSpecified > 0) { - INDArray out = Nd4j.create(dataType(), outShape); - - //Need to copy subsets here - long[] specifiedSizes = new long[numSpecified]; - SpecifiedIndex[] si = new SpecifiedIndex[numSpecified]; - int j=0; - for( int i=0; i replace with loop + point - // ii. new axis indices -> ignore/exclude (don't appear in input) - // iii. interval indices -> replace with all - //(2) Get from output: requested indices, except for: - // i. point indices -> ignore/exclude (don't appear in output) - // ii. new axis indices -> replace with point(0) - - - INDArrayIndex[] pointIdxsIn = new INDArrayIndex[indexes.length - numNewAxis]; //Indices for source (this array) - int[] specifiedAxisIn = new int[numSpecified]; - int specCount = 0; - j = 0; - for( int i=0; i + * Note here that one dimension can be -1. The dimension that is -1 will be inferred from the + * shape and the length of the ndarray + * + * @param shape the shape of the ndarray. + * @return the new reshaped nd array + */ + + @Override + public INDArray reshape(int[] shape) { + return reshape(Nd4j.order(), shape); + } + + @Override + public INDArray reshape(long... shape) { + return reshape(Nd4j.order(), shape); + } + + @Override + public INDArray prod(boolean keepDims, int... dimension) { + validateNumericalArray("prod", false); + return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension)); + } + + @Override + public INDArray prod(int... dimension) { + return prod(false, dimension); + } + + @Override + public INDArray mean(boolean keepDims, int... dimension) { + validateNumericalArray("mean", false); + return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension)); + } + + @Override + public INDArray mean(int... dimension) { + return mean(false, dimension); + } + + @Override + public INDArray amean(int... 
dimension) { + validateNumericalArray("amean", false); + return Nd4j.getExecutioner().exec(new AMean(this, dimension)); + } + + @Override + public INDArray mean(@NonNull INDArray result, boolean keepDims, int... dimension) { + validateNumericalArray("mean", false); + return Nd4j.getExecutioner().exec(new Mean(this, result, keepDims, dimension)); + } + + @Override + public INDArray mean(@NonNull INDArray result, int... dimension) { + return mean(result, false, dimension); + } + + @Override + public INDArray var(int... dimension) { + validateNumericalArray("var", false); + return Nd4j.getExecutioner().exec(new Variance(this, dimension)); + } + + @Override + public INDArray var(boolean biasCorrected, int... dimension) { + validateNumericalArray("var", false); + return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension)); + } + + @Override + public INDArray max(boolean keepDims, int... dimension) { + validateNumericalArray("max", false); + return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension)); + } + + @Override + public INDArray max(int... dimension) { + return max(false, dimension); + } + + @Override + public INDArray amax(int... dimension) { + validateNumericalArray("amax", false); + return Nd4j.getExecutioner().exec(new AMax(this, dimension)); + } + + @Override + public INDArray min(boolean keepDims, int... dimension) { + validateNumericalArray("min", false); + return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension)); + } + + @Override + public INDArray min(int... dimension) { + return min(false, dimension); + } + + @Override + public INDArray amin(int... dimension) { + validateNumericalArray("amin", false); + return Nd4j.getExecutioner().exec(new AMin(this, dimension)); + } + + @Override + public INDArray sum(int... dimension) { + validateNumericalArray("sum", true); + return Nd4j.getExecutioner().exec(new Sum(this, dimension)); + } + + @Override + public INDArray sum(boolean keepDim, int... 
dimension) { + validateNumericalArray("sum", true); + return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension)); + } + + @Override + public INDArray entropy(int... dimension) { + validateNumericalArray("entropy", false); + return Nd4j.getExecutioner().exec(new Entropy(this, dimension)); + } + + @Override + public INDArray shannonEntropy(int... dimension) { + validateNumericalArray("shannonEntropy", false); + return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension)); + } + + @Override + public INDArray logEntropy(int... dimension) { + validateNumericalArray("logEntropy", false); + return Nd4j.getExecutioner().exec(new LogEntropy(this, dimension)); + } + + @Override + public INDArray sum(@NonNull INDArray result, boolean keepDims, int... dimension) { + validateNumericalArray("sum", true); + return Nd4j.getExecutioner().exec(new Sum(this, result, keepDims, dimension)); + } + + @Override + public INDArray sum(@NonNull INDArray result, int... dimension) { + return sum(result, false, dimension); + } + + @Override + public INDArray norm1(int... dimension) { + return norm1(false, dimension); + } + + @Override + public INDArray norm1(boolean keepDims, int... dimension) { + validateNumericalArray("norm1", false); + return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension)); + } + + @Override + public INDArray std(int... dimension) { + return std(true, dimension); + } + + @Override + public INDArray std(boolean biasCorrected, int... dimension) { + return std(biasCorrected, false, dimension); + } + + @Override + public INDArray std(boolean biasCorrected, boolean keepDims, int... 
dimension) { + validateNumericalArray("std", false); + return Nd4j.getExecutioner() + .exec(new StandardDeviation(this, biasCorrected, keepDims, dimension)); + } + + @Override + public Number stdNumber(boolean biasCorrected) { + validateNumericalArray("stdNumber", false); + return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0); + } + + @Override + public INDArray norm2(boolean keepDims, int... dimension) { + validateNumericalArray("norm2", false); + return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension)); + } + + @Override + public INDArray norm2(int... dimension) { + return norm2(false, dimension); + } + + @Override + public int columns() { + if (isMatrix()) { + return (int) size(1); + } else if (Shape.isColumnVectorShape(shape())) { + return 1; + } else if (Shape.isRowVectorShape(shape())) { + return (int) length(); + } + throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid"); + + + } + + @Override + public int rows() { + if (isMatrix()) { + return (int) size(0); + } else if (Shape.isRowVectorShape(shape())) { + return 1; + } else if (Shape.isColumnVectorShape(shape())) { + return (int) length(); + } + + throw new IllegalStateException("Rank is " + rank() + " rows() call is not valid"); + } + + @Override + public INDArray ravel(char ordering) { + Nd4j.getCompressor().autoDecompress(this); + if (ordering == this.ordering() && Shape.hasDefaultStridesForShape(this)) { + return reshape(ordering, length()); + } + return dup(ordering).reshape(ordering, length()); + } + + @Override + public INDArray ravel() { + return reshape(length()); + } + + @Override + public void sliceVectors(List list) { + if (isVector()) { + list.add(this); + } else { + for (int i = 0; i < slices(); i++) { + slice(i).sliceVectors(list); + } + } + } + + @Override + public INDArray reshape(long newRows, long newColumns) { + return reshape(new long[]{newRows, newColumns}); + } + + @Override + public INDArray 
getColumn(long c) { + Nd4j.getCompressor().autoDecompress(this); + + if (isColumnVector() && c == 0) { + return this; + } else if (isColumnVector() && c > 0) { + throw new IllegalArgumentException("Illegal index for column"); + } + Preconditions.checkArgument(this.rank() == 2, "getColumn() can be called on 2D arrays only"); + return tensorAlongDimension(c, 0); + } + + @Override + public INDArray getColumn(long c, boolean keepDim) { + INDArray col = getColumn(c); + if (!keepDim) { + return col; + } + return col.reshape(col.length(), 1); + } + + @Override + public INDArray getRows(int[] rindices) { + Nd4j.getCompressor().autoDecompress(this); + + if (!isMatrix() && !isVector()) { + throw new IllegalArgumentException("Unable to get columns from a non matrix or vector"); + } + if (isVector()) { + return Nd4j.pullRows(this, 1, rindices); + } else { + INDArray ret = Nd4j.createUninitialized(this.dataType(), rindices.length, columns()); + for (int i = 0; i < rindices.length; i++) { + ret.putRow(i, getRow(rindices[i])); + } + return ret; + } + } + + @Override + public INDArray get(INDArrayIndex... 
indexes) { + Nd4j.getCompressor().autoDecompress(this); + + int numPoint = 0; + int numInterval = 0; + int numAll = 0; + int numNewAxis = 0; + int numSpecified = 0; + for (INDArrayIndex i : indexes) { + if (i instanceof PointIndex) { + numPoint++; + } else if (i instanceof NDArrayIndexAll) { + numAll++; + } else if (i instanceof IntervalIndex) { + numInterval++; + } else if (i instanceof NewAxis) { + numNewAxis++; + } else if (i instanceof SpecifiedIndex) { + numSpecified++; + } else { + throw new IllegalStateException("Unknown index: " + i); + } + } + + // Padding remaining dimensions with all() index if too few indices provided + if (indexes.length - numNewAxis < this.rank()) { + val newIndexes = new INDArrayIndex[this.rank() + numNewAxis]; + System.arraycopy(indexes, 0, newIndexes, 0, indexes.length); + + for (int e = indexes.length; e < newIndexes.length; e++) { + numAll++; + newIndexes[e] = NDArrayIndex.all(); + } + + indexes = newIndexes; + } + + Preconditions.checkState((numPoint + numInterval + numAll + numSpecified) == rank(), + "Illegal set of indices for array: need at least" + + " %s point/interval/all/specified indices for rank %s array (%ndShape), got indices %s", + rank(), rank(), this, indexes); + + int outRank = rank() + numNewAxis - numPoint; + Preconditions.checkState(outRank >= 0, "Illegal set of indices for array: %ndShape, %s", this, + indexes); + + //To work out sub-array, we need to work out 3 things: offset, shape and strides. 
We calculate all of these + long[] outShape = new long[outRank]; + long[] outStrides = new long[outRank]; + long offset = offset(); //Start with existing offset if view + + int outIdx = 0; //Axis number counter for output array + int inIdx = 0; //Axis number counter for input array + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] instanceof PointIndex) { + //Point indexes don't appear in output + PointIndex pi = (PointIndex) indexes[i]; + offset += pi.offset() * stride(inIdx); + inIdx++; + } else if (indexes[i] instanceof NDArrayIndexAll) { + //All index: doesn't change offset. Axis is in both in and output arrays + outShape[outIdx] = size(inIdx); + outStrides[outIdx] = stride(inIdx); + inIdx++; + outIdx++; + } else if (indexes[i] instanceof IntervalIndex) { + //Interval index: Axis is in both in and output arrays, but output might be smaller + IntervalIndex ii = (IntervalIndex) indexes[i]; + long start = ii.offset(); + long endInc = ii.end() - (ii.isInclusive() ? 0 : 1); + if (endInc >= size(inIdx)) { + throw new IllegalStateException( + "Indices are out of range: Cannot get interval index " + indexes[i] + + " on array with size(" + inIdx + ")=" + size(inIdx) + ". Array shape: " + + Arrays.toString(shape()) + + ", indices: " + Arrays.toString(indexes)); + } + long stride = ii.stride(); + long length = (endInc - start) / stride + 1; + + offset += ii.offset() * stride(inIdx); + outShape[outIdx] = length; + outStrides[outIdx] = ii.stride() * stride(inIdx); + inIdx++; + outIdx++; + } else if (indexes[i] instanceof NewAxis) { + //New axis: appends a 1 in shape. Axis not present in input, but is present in output + outShape[outIdx] = 1; + if (outIdx > 0) { //Stride doesn't matter for 1 size axis anyway... 
+ outStrides[outIdx] = outStrides[outIdx - 1]; } else { - INDArray ret = Nd4j.createUninitialized(this.dataType(), rows(), cindices.length); - for (int i = 0; i < cindices.length; i++) - ret.putColumn(i, getColumn(cindices[i])); - return ret; + outStrides[outIdx] = 1; } - + outIdx++; + } else if (indexes[i] instanceof SpecifiedIndex) { + //Specified index: axis present in both input and output + SpecifiedIndex si = (SpecifiedIndex) indexes[i]; + outShape[outIdx++] = si.length(); + inIdx++; + //Don't care about strides for specified index, as result won't be a view + } else { + throw new IllegalStateException("Unknown index type: " + i); //Should never happen + } } - protected INDArray create(int rows, int length) { - return create(new int[] {rows, length}); - } + //Note: If we have specified indices, we can't return a view. Instead, we copy the specified sub-arrays from + // the input array to the output array. + //How? Create the output array, then do loop over the specified indices only, and copy sub-arrays for all other axes + if (numSpecified > 0) { + INDArray out = Nd4j.create(dataType(), outShape); - @Override - public INDArray getRow(long r) { - if (isRowVector() && r == 0) - return this; - else if (isRowVector() && r > 0) - throw new IllegalArgumentException("Illegal index for row: requested row " + r + " but this.size(0)=" + this.size(0)); + //Need to copy subsets here + long[] specifiedSizes = new long[numSpecified]; + SpecifiedIndex[] si = new SpecifiedIndex[numSpecified]; + int j = 0; + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] instanceof SpecifiedIndex) { + specifiedSizes[j] = indexes[i].length(); + si[j] = (SpecifiedIndex) indexes[i]; + j++; + } + } + NdIndexIterator iter = new NdIndexIterator(specifiedSizes); - Preconditions.checkArgument(rank() == 2, "getRow() can be called on 2D arrays only"); - Preconditions.checkArgument(r < rows(), "Row index must be smaller than total number of rows"); + //What we need to do here: Iterate over 
sub-arrays for both input and output + //(1) Get from input: requested indices, except for: + // i. specified indices -> replace with loop + point + // ii. new axis indices -> ignore/exclude (don't appear in input) + // iii. interval indices -> replace with all + //(2) Get from output: requested indices, except for: + // i. point indices -> ignore/exclude (don't appear in output) + // ii. new axis indices -> replace with point(0) - return tensorAlongDimension(r, 1); - } + INDArrayIndex[] pointIdxsIn = new INDArrayIndex[indexes.length + - numNewAxis]; //Indices for source (this array) + int[] specifiedAxisIn = new int[numSpecified]; + int specCount = 0; + j = 0; + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] instanceof NewAxis) { + continue; //Skip new axis in source dims + } + if (indexes[i] instanceof SpecifiedIndex) { + specifiedAxisIn[specCount++] = j; + } + pointIdxsIn[j++] = indexes[i]; + } - @Override - public INDArray getRow(long r, boolean keepDim) { - INDArray row = getRow(r); - if(!keepDim) - return row; - return row.reshape(1, row.length()); - } + INDArrayIndex[] pointIdxsOut = new INDArrayIndex[indexes.length + - numPoint]; //Indices for destination (output array) + j = 0; + specCount = 0; + int[] specifiedAxisOut = new int[numSpecified]; + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] instanceof NewAxis) { + pointIdxsOut[j++] = NDArrayIndex.point(0); + continue; + } else if (indexes[i] instanceof PointIndex) { + continue; + } else if (indexes[i] instanceof SpecifiedIndex) { + specifiedAxisOut[specCount++] = j; + } else if (indexes[i] instanceof IntervalIndex) { + pointIdxsOut[j++] = NDArrayIndex.all(); + continue; + } + pointIdxsOut[j++] = indexes[i]; + } - public boolean equalsWithEps(Object o, double eps) { - Nd4j.getCompressor().autoDecompress(this); - - - if (o == null) - return false; - - if (!(o instanceof INDArray)) - return false; - - INDArray n = (INDArray) o; - Nd4j.getCompressor().autoDecompress(n); - - if (n == 
this) - return true; - - if (this.rank() != n.rank()) - return false; - - if (this.length() != n.length()) - return false; - - if (this.isEmpty() != n.isEmpty()) - return false; - - if (this.isEmpty() && n.isEmpty()) - return Shape.shapeEquals(this.shape(), n.shape()); - - if (this.dataType() != n.dataType()) - return false; - - // meh - if (this.dataType() == DataType.UTF8 && n.dataType() == DataType.UTF8) { - for (long e = 0; e < this.length(); e++) { - val str1 = this.getString(e); - val str2 = n.getString(e); - - if (!str1.equals(str2)) - return false; - } - - return true; + //Iterate over sub-arrays; copy from source to destination + while (iter.hasNext()) { + long[] specifiedIdxs = iter.next(); + for (int i = 0; i < specifiedIdxs.length; i++) { + long sourceIdx = si[i].getIndexes()[(int) specifiedIdxs[i]]; + pointIdxsIn[specifiedAxisIn[i]] = NDArrayIndex.point(sourceIdx); + int outI = (int) specifiedIdxs[i]; + pointIdxsOut[specifiedAxisOut[i]] = NDArrayIndex.point(outI); } - //epsilon equals - if (isScalar() && n.isScalar()) { - if (isZ()) { - val val = getLong(0); - val val2 = n.getLong(0); + out.get(pointIdxsOut).assign(get(pointIdxsIn)); + } - return val == val2; - } else if (isR()) { - val val = getDouble(0); - val val2 = n.getDouble(0); + return out; + } - if (Double.isNaN(val) != Double.isNaN(val2)) - return false; + char order = Shape.getOrder(outShape, outStrides, -1); + INDArray out = create(data, outShape, outStrides, offset, order); + return out; + } - return Math.abs(val - val2) < eps; - } else if (isB()) { - val val = getInt(0); - val val2 = n.getInt(0); - - return val == val2; - } - - } else if (isVector() && n.isVector()) { - val op = new EqualsWithEps(this, n, eps); - Nd4j.exec(op); - val diff = op.z().getDouble(0); - - return diff < 0.5; + @Override + public INDArray getColumns(int... 
cindices) { + if (!isMatrix() && !isVector()) { + throw new IllegalArgumentException("Unable to get columns from a non matrix or vector"); + } + if (isVector()) { + return Nd4j.pullRows(this, 0, cindices, this.ordering()); + } else { + INDArray ret = Nd4j.createUninitialized(this.dataType(), rows(), cindices.length); + for (int i = 0; i < cindices.length; i++) { + ret.putColumn(i, getColumn(cindices[i])); } + return ret; + } - if (!Arrays.equals(this.shape(), n.shape())) - return false; + } + + protected INDArray create(int rows, int length) { + return create(new int[]{rows, length}); + } + + @Override + public INDArray getRow(long r) { + if (isRowVector() && r == 0) { + return this; + } else if (isRowVector() && r > 0) { + throw new IllegalArgumentException( + "Illegal index for row: requested row " + r + " but this.size(0)=" + this.size(0)); + } + + Preconditions.checkArgument(rank() == 2, "getRow() can be called on 2D arrays only"); + Preconditions.checkArgument(r < rows(), "Row index must be smaller than total number of rows"); + + return tensorAlongDimension(r, 1); + } + + @Override + public INDArray getRow(long r, boolean keepDim) { + INDArray row = getRow(r); + if (!keepDim) { + return row; + } + return row.reshape(1, row.length()); + } + + public boolean equalsWithEps(Object o, double eps) { + Nd4j.getCompressor().autoDecompress(this); + + if (o == null) { + return false; + } + + if (!(o instanceof INDArray)) { + return false; + } + + INDArray n = (INDArray) o; + Nd4j.getCompressor().autoDecompress(n); + + if (n == this) { + return true; + } + + if (this.rank() != n.rank()) { + return false; + } + + if (this.length() != n.length()) { + return false; + } + + if (this.isEmpty() != n.isEmpty()) { + return false; + } + + if (this.isEmpty() && n.isEmpty()) { + return Shape.shapeEquals(this.shape(), n.shape()); + } + + if (this.dataType() != n.dataType()) { + return false; + } + + // meh + if (this.dataType() == DataType.UTF8 && n.dataType() == DataType.UTF8) { + 
for (long e = 0; e < this.length(); e++) { + val str1 = this.getString(e); + val str2 = n.getString(e); + + if (!str1.equals(str2)) { + return false; + } + } + + return true; + } + + //epsilon equals + if (isScalar() && n.isScalar()) { + if (isZ()) { + val val = getLong(0); + val val2 = n.getLong(0); + + return val == val2; + } else if (isR()) { + val val = getDouble(0); + val val2 = n.getDouble(0); + + if (Double.isNaN(val) != Double.isNaN(val2)) { + return false; + } + + return Math.abs(val - val2) < eps; + } else if (isB()) { + val val = getInt(0); + val val2 = n.getInt(0); + + return val == val2; + } + + } else if (isVector() && n.isVector()) { + val op = new EqualsWithEps(this, n, eps); + Nd4j.exec(op); + val diff = op.z().getDouble(0); + + return diff < 0.5; + } + + if (!Arrays.equals(this.shape(), n.shape())) { + return false; + } + + if (!Shape.shapeEquals(shape(), n.shape())) { + return false; + } + + if (slices() != n.slices()) { + return false; + } + + if (n.ordering() == ordering()) { + EqualsWithEps op = new EqualsWithEps(this, n, eps); + Nd4j.getExecutioner().exec(op); + double diff = op.z().getDouble(0); + + return diff < 0.5; + } else { + EqualsWithEps op = new EqualsWithEps(this, n, eps); + Nd4j.getExecutioner().exec(op); + double diff = op.z().getDouble(0); + + return diff < 0.5; + } + } + + @Override + public boolean equalShapes(@NonNull INDArray other) { + if (isEmpty() != other.isEmpty()) { + return false; + } + if (rank() != other.rank()) { + return false; + } + for (int i = 0; i < rank(); i++) { + if (size(i) != other.size(i)) { + return false; + } + } + return true; + } + + /** + * Compare two matrices. Returns true if and only if other is also a DoubleMatrix which has the + * same size and the maximal absolute difference in matrix elements is smaller than 1e-5. 
+ * + * @param o + */ + @Override + public boolean equals(Object o) { + return equalsWithEps(o, Nd4j.EPS_THRESHOLD); + } + + @Override + public int hashCode() { + val longHash = Nd4j.exec(new HashCode(this))[0].getLong(0); + return Math.abs(longHash) <= Integer.MAX_VALUE ? (int) longHash + : (int) (longHash % Integer.MAX_VALUE); + } + + @Override + public DataBuffer shapeInfoDataBuffer() { + return shapeInformation; + } + + @Override + public LongBuffer shapeInfo() { + return shapeInformation.asNioLong(); + } + + public long[] shape() { + return jvmShapeInfo.shape; + } + + @Override + public String shapeInfoToString() { + return Shape.shapeToString(this); + } + + @Override + public long[] stride() { + return jvmShapeInfo.stride; + } - if (!Shape.shapeEquals(shape(), n.shape())) { - return false; - } + @Override + public long offset() { + return data().offset(); + } + @Override + public char ordering() { + return jvmShapeInfo.order; + } - if (slices() != n.slices()) - return false; + @Override + public long size(int dimension) { + if (dimension < 0) { + dimension += jvmShapeInfo.rank; + } - if (n.ordering() == ordering()) { - EqualsWithEps op = new EqualsWithEps(this, n, eps); - Nd4j.getExecutioner().exec(op); - double diff = op.z().getDouble(0); - - return diff < 0.5; + if (isScalar()) { + if (dimension == 0 || dimension == 1 || dimension < 0) { + return length(); } else { - EqualsWithEps op = new EqualsWithEps(this, n, eps); - Nd4j.getExecutioner().exec(op); - double diff = op.z().getDouble(0); - - return diff < 0.5; + throw new IllegalArgumentException("Illegal dimension for scalar " + dimension); } } - @Override - public boolean equalShapes(@NonNull INDArray other){ - if(isEmpty() != other.isEmpty()) - return false; - if(rank() != other.rank()) - return false; - for( int i=0; i= rank()) { + throw new IllegalArgumentException( + "Invalid size: cannot get size of dimension " + dimension + " for rank " + + rank() + " NDArray (array shape: " + 
Arrays.toString(this.shape()) + ")"); + } + + return jvmShapeInfo.shape[dimension]; + } + + @Override + public int rank() { + return jvmShapeInfo.rank; + } + + @Override + public long length() { + if (isEmpty()) { + return 0; + } + return jvmShapeInfo.length; + } + + @Override + public INDArray broadcast(INDArray result) { + Nd4j.getCompressor().autoDecompress(this); + + val shape = result.shape(); + + if (Shape.shapeEquals(shape, shape())) { + return this; + } + + // if we're on scalar, we can just create new array + if (this.isScalar()) { + return Nd4j.createUninitialized(this.dataType(), shape).assign(this.getDouble(0)); + } + + boolean compatible = true; + int count = shape.length - 1; + int thisCount = jvmShapeInfo.rank - 1; + for (int i = shape.length - 1; i > 0; i--) { + if (count < 0 || thisCount < 0) { + break; } - return true; + if (shape[count] != shape()[thisCount] && shape[count] != 1 && shape()[thisCount] != 1) { + compatible = false; + break; + } + + count--; + thisCount--; } - /** - * Compare two matrices. Returns true if and only if other is also a - * DoubleMatrix which has the same size and the maximal absolute - * difference in matrix elements is smaller than 1e-5. - * - * @param o - */ - @Override - public boolean equals(Object o) { - return equalsWithEps(o, Nd4j.EPS_THRESHOLD); - } + if (!compatible) { + throw new IllegalArgumentException( + "Incompatible broadcast from " + Arrays.toString(shape()) + " to " + + Arrays.toString(shape)); + } - @Override - public int hashCode() { - val longHash = Nd4j.exec(new HashCode(this))[0].getLong(0); - return Math.abs(longHash) <= Integer.MAX_VALUE ? 
(int) longHash : (int) (longHash % Integer.MAX_VALUE); - } - - @Override - public DataBuffer shapeInfoDataBuffer() { - return shapeInformation; - } - - @Override - public LongBuffer shapeInfo() { - return shapeInformation.asNioLong(); - } - - public long[] shape() { - return jvmShapeInfo.shape; - } - - @Override - public String shapeInfoToString() { - return Shape.shapeToString(this); - } - - @Override - public long[] stride() { - return jvmShapeInfo.stride; - } - - - @Override - public long offset() { - return data().offset(); - } - - @Override - public char ordering() { - return jvmShapeInfo.order; - } - - @Override - public long size(int dimension) { - if (dimension < 0) - dimension += jvmShapeInfo.rank; - - if (isScalar()) { - if (dimension == 0 || dimension == 1 || dimension < 0) - return length(); - else - throw new IllegalArgumentException("Illegal dimension for scalar " + dimension); - } - - if (dimension >= rank()) - throw new IllegalArgumentException("Invalid size: cannot get size of dimension " + dimension + " for rank " - + rank() + " NDArray (array shape: " + Arrays.toString(this.shape()) + ")"); - - return jvmShapeInfo.shape[dimension]; - } - - @Override - public int rank() { - return jvmShapeInfo.rank; - } - - @Override - public long length() { - if (isEmpty()) - return 0; - return jvmShapeInfo.length; - } - - @Override - public INDArray broadcast(INDArray result) { - Nd4j.getCompressor().autoDecompress(this); - - val shape = result.shape(); - - if (Shape.shapeEquals(shape, shape())) - return this; - - // if we're on scalar, we can just create new array - if (this.isScalar()) - return Nd4j.createUninitialized(this.dataType(), shape).assign(this.getDouble(0)); - - - - - boolean compatible = true; - int count = shape.length - 1; - int thisCount = jvmShapeInfo.rank - 1; - for (int i = shape.length - 1; i > 0; i--) { - if (count < 0 || thisCount < 0) - break; - if (shape[count] != shape()[thisCount] && shape[count] != 1 && shape()[thisCount] != 1) { - 
compatible = false; - break; - } - - count--; - thisCount--; - } - - if (!compatible) - throw new IllegalArgumentException("Incompatible broadcast from " + Arrays.toString(shape()) + " to " - + Arrays.toString(shape)); - - - - long[] retShape = new long[shape.length]; - List broadCastDimensions = new ArrayList<>(); - List nonBroadCastDimensions = new ArrayList<>(); - for (int i = 0; i < retShape.length; i++) { - if (shape().length == 1) { - if (i == 0) { - if (i < shape().length) - retShape[i] = Math.max(1, shape[i]); - else - retShape[i] = shape[i]; - } else { - if (i < shape().length) - retShape[i] = Math.max(shape[i], size(i)); - else - retShape[i] = shape[i]; - } + long[] retShape = new long[shape.length]; + List broadCastDimensions = new ArrayList<>(); + List nonBroadCastDimensions = new ArrayList<>(); + for (int i = 0; i < retShape.length; i++) { + if (shape().length == 1) { + if (i == 0) { + if (i < shape().length) { + retShape[i] = Math.max(1, shape[i]); } else { - if (i < rank() && size(i) == 1) - broadCastDimensions.add(i); - else - nonBroadCastDimensions.add(i); - if (i < shape().length) - retShape[i] = Math.max(shape[i], size(i)); - else - retShape[i] = shape[i]; + retShape[i] = shape[i]; } - - } - - - if (isRowVector()) { - //number of times to repeat each value - for (int i = 0; i < result.slices(); i++) { - result.putSlice(i, this); - } - } else if (isColumnVector()) { - for (int i = 0; i < result.columns(); i++) { - result.putColumn(i, this); + } else { + if (i < shape().length) { + retShape[i] = Math.max(shape[i], size(i)); + } else { + retShape[i] = shape[i]; } } - - else { - int[] repeat = new int[shape.length]; - for(int i = 0; i < shape.length; i++) { - if(i < rank()) { - if(size(i) == 1) - repeat[i] = (int) shape[i]; - else { - repeat[i] = 1; - } - } - - else { - repeat[i] = (int) shape[i]; - } - } - - if (this.isView()) { - Nd4j.getExecutioner().execAndReturn(new Tile(new INDArray[]{this.dup(this.ordering())},new INDArray[]{result},repeat)); 
- } else - Nd4j.getExecutioner().execAndReturn(new Tile(new INDArray[]{this},new INDArray[]{result},repeat)); - } - return result; + } else { + if (i < rank() && size(i) == 1) { + broadCastDimensions.add(i); + } else { + nonBroadCastDimensions.add(i); + } + if (i < shape().length) { + retShape[i] = Math.max(shape[i], size(i)); + } else { + retShape[i] = shape[i]; + } + } } - @Override - public INDArray broadcast(long... shape) { - return broadcast(Nd4j.createUninitialized(this.dataType(), shape, this.ordering())); + if (isRowVector()) { + //number of times to repeat each value + for (int i = 0; i < result.slices(); i++) { + result.putSlice(i, this); + } + } else if (isColumnVector()) { + for (int i = 0; i < result.columns(); i++) { + result.putColumn(i, this); + } + } else { + int[] repeat = new int[shape.length]; + for (int i = 0; i < shape.length; i++) { + if (i < rank()) { + if (size(i) == 1) { + repeat[i] = (int) shape[i]; + } else { + repeat[i] = 1; + } + } else { + repeat[i] = (int) shape[i]; + } + } + + if (this.isView()) { + Nd4j.getExecutioner().execAndReturn( + new Tile(new INDArray[]{this.dup(this.ordering())}, new INDArray[]{result}, + repeat)); + } else { + Nd4j.getExecutioner() + .execAndReturn(new Tile(new INDArray[]{this}, new INDArray[]{result}, repeat)); + } + } + return result; + + } + + @Override + public INDArray broadcast(long... shape) { + return broadcast(Nd4j.createUninitialized(this.dataType(), shape, this.ordering())); + } + + @Deprecated + @Override + public INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable) { + return dimShuffle(rearrange, ArrayUtil.toLongArray(newOrder), broadCastable); + } + + /** + * Dimshuffle: an extension of permute that adds the ability to broadcast various dimensions. + *

      + * See theano for more examples. This will only accept integers and xs. + *

      + * An x indicates a dimension should be broadcasted rather than permuted. + * + * @param rearrange the dimensions to swap to + * @return the newly permuted array + */ + @Override + public INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable) { + Nd4j.getCompressor().autoDecompress(this); + + if (broadCastable.length != jvmShapeInfo.rank) { + throw new IllegalArgumentException( + "The broadcastable dimensions must be the same length as the current shape"); + } + + boolean broadcast = false; + Set set = new HashSet<>(); + for (int i = 0; i < rearrange.length; i++) { + set.add(rearrange[i]); + if (rearrange[i] instanceof Integer) { + Integer j = (Integer) rearrange[i]; + if (j >= broadCastable.length) { + throw new IllegalArgumentException( + "Illegal dimension, dimension must be < broadcastable.length (aka the real dimensions"); + } + } else if (rearrange[i] instanceof Character) { + Character c = (Character) rearrange[i]; + if (c != 'x') { + throw new IllegalArgumentException("Illegal input: Must be x"); + } + broadcast = true; + + } else { + throw new IllegalArgumentException("Only characters and integers allowed"); + } } - @Deprecated - @Override - public INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable) { - return dimShuffle(rearrange, ArrayUtil.toLongArray(newOrder), broadCastable); + //just do permute + if (!broadcast) { + int[] ret = new int[rearrange.length]; + for (int i = 0; i < ret.length; i++) { + ret[i] = (Integer) rearrange[i]; + } + return permute(ret); + } else { + List drop = new ArrayList<>(); + for (int i = 0; i < broadCastable.length; i++) { + if (!set.contains(i)) { + if (broadCastable[i]) { + drop.add(i); + } else { + throw new IllegalArgumentException( + "We can't drop the given dimension because its not broadcastable"); + } + } + + } + + //list of dimensions to keep + int[] shuffle = new int[broadCastable.length]; + int count = 0; + for (int i = 0; i < rearrange.length; i++) { + 
if (rearrange[i] instanceof Integer) { + shuffle[count++] = (Integer) rearrange[i]; + } + } + + List augment = new ArrayList<>(); + for (int i = 0; i < rearrange.length; i++) { + if (rearrange[i] instanceof Character) { + augment.add(i); + } + } + + Integer[] augmentDims = augment.toArray(new Integer[1]); + + count = 0; + + int dropIdx = 0; + int[] newShape = new int[shuffle.length + drop.size()]; + for (int i = 0; i < newShape.length; i++) { + if (i < shuffle.length) { + newShape[count++] = shuffle[i]; + } else { + newShape[count++] = drop.get(dropIdx++); + } + } + + INDArray ret; //TODO is this correct? This was old behaviour before adding permute input check + if (newShape.length == this.rank()) { + ret = permute(newShape); + } else { + ret = dup(); + } + List newDims = new ArrayList<>(); + long[] shape = Arrays.copyOfRange(ret.shape(), 0, shuffle.length); + for (int i = 0; i < shape.length; i++) { + newDims.add(shape[i]); + } + + for (int i = 0; i < augmentDims.length; i++) { + newDims.add(augmentDims[i], 1L); + } + + long[] toReshape = ArrayUtil.toArrayLong(newDims); + + ret = ret.reshape(toReshape); + return ret; + } - /** - * Dimshuffle: an extension of permute that adds the ability - * to broadcast various dimensions. - *

      - * See theano for more examples. - * This will only accept integers and xs. - *

      - * An x indicates a dimension should be broadcasted rather than permuted. - * - * @param rearrange the dimensions to swap to - * @return the newly permuted array - */ - @Override - public INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable) { - Nd4j.getCompressor().autoDecompress(this); - if (broadCastable.length != jvmShapeInfo.rank) + } + + @Override + public INDArray permute(int... rearrange) { + Preconditions.checkArgument(rearrange.length == rank(), + "Incorrect number of arguments for permute function:" + + " got arguments %s for rank %s array. Number of arguments must equal array rank", + rearrange, rank()); + Nd4j.getCompressor().autoDecompress(this); + boolean alreadyInOrder = true; + //IntBuffer shapeInfo = shapeInfo(); + int rank = jvmShapeInfo.rank; + for (int i = 0; i < rank; i++) { + if (rearrange[i] != i) { + alreadyInOrder = false; + break; + } + } + + if (alreadyInOrder) { + return this; + } + + checkArrangeArray(rearrange); + val newShape = doPermuteSwap(shape(), rearrange); + val newStride = doPermuteSwap(stride(), rearrange); + + char newOrder = Shape.getOrder(newShape, newStride, 1); + + INDArray value = create(data(), newShape, newStride, offset(), newOrder); + return value; + } + + @Override + public INDArray permutei(int... rearrange) { + Preconditions.checkArgument(rearrange.length == rank(), + "Incorrect number of arguments for permute function:" + + " got arguments %s for rank %s array. 
Number of arguments must equal array rank", + rearrange, rank()); + boolean alreadyInOrder = true; + val shapeInfo = shapeInfo(); + int rank = jvmShapeInfo.rank; + for (int i = 0; i < rank; i++) { + if (rearrange[i] != i) { + alreadyInOrder = false; + break; + } + } + + if (alreadyInOrder) { + return this; + } + + checkArrangeArray(rearrange); + val newShape = doPermuteSwap(shape(), rearrange); + val newStride = doPermuteSwap(stride(), rearrange); + char newOrder = Shape.getOrder(newShape, newStride, 1); + + val ews = shapeInfo.get(2 * rank + 2); + + val si = Nd4j.getShapeInfoProvider() + .createShapeInformation(newShape, newStride, ews, newOrder, dataType(), isEmpty()); + setShapeInformation(si); + + if (shapeInfo.get(2 * rank + 2) > 0) { + //for the backend to work - no ews for permutei + //^^ not true anymore? Not sure here. Marking this for raver + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(newShape, newStride, 0, newOrder, dataType(), isEmpty())); + } + + //this.shape = null; + //this.stride = null; + + return this; + } + + + @Deprecated + protected long[] doPermuteSwap(LongBuffer shape, int[] rearrange) { + val ret = new long[rearrange.length]; + for (int i = 0; i < rearrange.length; i++) { + ret[i] = shape.get(rearrange[i]); + } + return ret; + } + + @Deprecated + protected int[] doPermuteSwap(IntBuffer shape, int[] rearrange) { + int[] ret = new int[rearrange.length]; + for (int i = 0; i < rearrange.length; i++) { + ret[i] = shape.get(rearrange[i]); + } + return ret; + } + + @Deprecated + protected int[] doPermuteSwap(DataBuffer shape, int[] rearrange) { + int[] ret = new int[rearrange.length]; + for (int i = 0; i < rearrange.length; i++) { + ret[i] = shape.getInt(rearrange[i]); + } + return ret; + } + + protected long[] doPermuteSwap(long[] shape, int[] rearrange) { + val ret = new long[rearrange.length]; + for (int i = 0; i < rearrange.length; i++) { + ret[i] = shape[rearrange[i]]; + } + + return ret; + } + + + protected 
void checkArrangeArray(int[] arr) { + Preconditions.checkArgument(arr.length == jvmShapeInfo.rank, + "Invalid rearrangement: number of arrangement (%s) != rank (%s)", + arr.length, jvmShapeInfo.rank); + for (int i = 0; i < arr.length; i++) { + if (arr[i] >= arr.length) { throw new IllegalArgumentException( - "The broadcastable dimensions must be the same length as the current shape"); - - boolean broadcast = false; - Set set = new HashSet<>(); - for (int i = 0; i < rearrange.length; i++) { - set.add(rearrange[i]); - if (rearrange[i] instanceof Integer) { - Integer j = (Integer) rearrange[i]; - if (j >= broadCastable.length) - throw new IllegalArgumentException( - "Illegal dimension, dimension must be < broadcastable.length (aka the real dimensions"); - } else if (rearrange[i] instanceof Character) { - Character c = (Character) rearrange[i]; - if (c != 'x') - throw new IllegalArgumentException("Illegal input: Must be x"); - broadcast = true; - - } else - throw new IllegalArgumentException("Only characters and integers allowed"); + "The specified dimensions can't be swapped. 
Given element " + i + + " was >= number of dimensions"); } - - //just do permute - if (!broadcast) { - int[] ret = new int[rearrange.length]; - for (int i = 0; i < ret.length; i++) - ret[i] = (Integer) rearrange[i]; - return permute(ret); - } else { - List drop = new ArrayList<>(); - for (int i = 0; i < broadCastable.length; i++) { - if (!set.contains(i)) { - if (broadCastable[i]) - drop.add(i); - else - throw new IllegalArgumentException( - "We can't drop the given dimension because its not broadcastable"); - } - - } - - - //list of dimensions to keep - int[] shuffle = new int[broadCastable.length]; - int count = 0; - for (int i = 0; i < rearrange.length; i++) { - if (rearrange[i] instanceof Integer) { - shuffle[count++] = (Integer) rearrange[i]; - } - } - - - List augment = new ArrayList<>(); - for (int i = 0; i < rearrange.length; i++) { - if (rearrange[i] instanceof Character) - augment.add(i); - } - - Integer[] augmentDims = augment.toArray(new Integer[1]); - - count = 0; - - int dropIdx = 0; - int[] newShape = new int[shuffle.length + drop.size()]; - for (int i = 0; i < newShape.length; i++) { - if (i < shuffle.length) { - newShape[count++] = shuffle[i]; - } else - newShape[count++] = drop.get(dropIdx++); - } - - INDArray ret; //TODO is this correct? This was old behaviour before adding permute input check - if(newShape.length == this.rank()){ - ret = permute(newShape); - } else { - ret = dup(); - } - List newDims = new ArrayList<>(); - long[] shape = Arrays.copyOfRange(ret.shape(), 0, shuffle.length); - for (int i = 0; i < shape.length; i++) { - newDims.add(shape[i]); - } - - for (int i = 0; i < augmentDims.length; i++) { - newDims.add(augmentDims[i], 1L); - } - - long[] toReshape = ArrayUtil.toArrayLong(newDims); - - - ret = ret.reshape(toReshape); - return ret; - + if (arr[i] < 0) { + throw new IllegalArgumentException("Invalid dimension: " + i + " : negative value"); } } - @Override - public INDArray permute(int... 
rearrange) { - Preconditions.checkArgument(rearrange.length == rank(), "Incorrect number of arguments for permute function:" + - " got arguments %s for rank %s array. Number of arguments must equal array rank", rearrange, rank()); - Nd4j.getCompressor().autoDecompress(this); - boolean alreadyInOrder = true; - //IntBuffer shapeInfo = shapeInfo(); - int rank = jvmShapeInfo.rank; - for (int i = 0; i < rank; i++) { - if (rearrange[i] != i) { - alreadyInOrder = false; - break; - } - } - - if (alreadyInOrder) - return this; - - checkArrangeArray(rearrange); - val newShape = doPermuteSwap(shape(), rearrange); - val newStride = doPermuteSwap(stride(), rearrange); - - char newOrder = Shape.getOrder(newShape, newStride, 1); - - INDArray value = create(data(), newShape, newStride, offset(), newOrder); - return value; + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr.length; j++) { + if (i != j && arr[i] == arr[j]) { + throw new IllegalArgumentException("Permute array must have unique elements"); + } + } } - @Override - public INDArray permutei(int... rearrange) { - Preconditions.checkArgument(rearrange.length == rank(), "Incorrect number of arguments for permute function:" + - " got arguments %s for rank %s array. 
Number of arguments must equal array rank", rearrange, rank()); - boolean alreadyInOrder = true; - val shapeInfo = shapeInfo(); - int rank = jvmShapeInfo.rank; - for (int i = 0; i < rank; i++) { - if (rearrange[i] != i) { - alreadyInOrder = false; - break; - } - } + } - if (alreadyInOrder) - return this; - - checkArrangeArray(rearrange); - val newShape = doPermuteSwap(shape(), rearrange); - val newStride = doPermuteSwap(stride(), rearrange); - char newOrder = Shape.getOrder(newShape, newStride, 1); - - val ews = shapeInfo.get(2 * rank + 2); - - val si = Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, ews, newOrder, dataType(), isEmpty()); - setShapeInformation(si); - - - if (shapeInfo.get(2 * rank + 2) > 0) { - //for the backend to work - no ews for permutei - //^^ not true anymore? Not sure here. Marking this for raver - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, 0, newOrder, dataType(), isEmpty())); - } - - //this.shape = null; - //this.stride = null; - - - return this; - } - - - @Deprecated - protected long[] doPermuteSwap(LongBuffer shape, int[] rearrange) { - val ret = new long[rearrange.length]; - for (int i = 0; i < rearrange.length; i++) { - ret[i] = shape.get(rearrange[i]); - } - return ret; - } - - @Deprecated - protected int[] doPermuteSwap(IntBuffer shape, int[] rearrange) { - int[] ret = new int[rearrange.length]; - for (int i = 0; i < rearrange.length; i++) { - ret[i] = shape.get(rearrange[i]); - } - return ret; - } - - @Deprecated - protected int[] doPermuteSwap(DataBuffer shape, int[] rearrange) { - int[] ret = new int[rearrange.length]; - for (int i = 0; i < rearrange.length; i++) { - ret[i] = shape.getInt(rearrange[i]); - } - return ret; - } - - protected long[] doPermuteSwap(long[] shape, int[] rearrange) { - val ret = new long[rearrange.length]; - for (int i = 0; i < rearrange.length; i++) { - ret[i] = shape[rearrange[i]]; - } - - return ret; - } - - - protected void 
checkArrangeArray(int[] arr) { - Preconditions.checkArgument(arr.length == jvmShapeInfo.rank, "Invalid rearrangement: number of arrangement (%s) != rank (%s)", - arr.length, jvmShapeInfo.rank); - for (int i = 0; i < arr.length; i++) { - if (arr[i] >= arr.length) - throw new IllegalArgumentException("The specified dimensions can't be swapped. Given element " + i - + " was >= number of dimensions"); - if (arr[i] < 0) - throw new IllegalArgumentException("Invalid dimension: " + i + " : negative value"); - - - } - - for (int i = 0; i < arr.length; i++) { - for (int j = 0; j < arr.length; j++) { - if (i != j && arr[i] == arr[j]) - throw new IllegalArgumentException("Permute array must have unique elements"); - } - } - - } - - protected void autoProcessScalarCall() { + protected void autoProcessScalarCall() { /* if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.DISABLED && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.SCOPE_PANIC) OpProfiler.getInstance().processScalarCall();*/ + } + + /** + * Checks whether the matrix is a vector. + */ + @Override + public boolean isVector() { + if (jvmShapeInfo.rank == 1) { + return true; + } + + return isRowVector() || isColumnVector(); + } + + @Override + public boolean isVectorOrScalar() { + return isVector() || isScalar(); + } + + @Override + public boolean isSquare() { + return isMatrix() && rows() == columns(); + } + + @Override + public boolean isRowVector() { + return (rank() == 2 && rows() == 1) && length() > 1 || rank() == 1 && length() > 1; + } + + @Override + public boolean isColumnVector() { + return rank() == 2 && columns() == 1 && length() > 1; + } + + @Override + public boolean isColumnVectorOrScalar() { + return isColumnVector() || isScalar(); + } + + @Override + public boolean isRowVectorOrScalar() { + return isRowVector() || isScalar(); + } + + /** + * Generate string representation of the matrix. 
Printing will switch to scientific notation on a + * per element basis - when abs value is greater than or equal to 10000 - when abs value is less + * than or equal to 0.0001 and not zero + *

      + * If the number of elements in the array is greater than 1000 (by default) only the first and + * last three elements in a dimension are included. This can be changed globally using + * {@link NDArrayStrings#setMaxPrintElements(long)} + */ + @Override + public String toString() { + return toString(new NDArrayStrings()); + } + + + @Override + public String toString(@NonNull NDArrayStrings options) { + if (wasClosed()) { + return ""; + } + if (!isCompressed() && !preventUnpack) { + return options.format(this); + } else if (isCompressed() && compressDebug) { + return "COMPRESSED ARRAY. SYSTEM PROPERTY compressdebug is true. This is to prevent auto decompression from being triggered."; + } else if (preventUnpack) { + return "Array string unpacking is disabled."; + } + return options.format(this); + } + + @Override + public String toString(long maxElements, boolean forceSummarize, int precision) { + return toString(new NDArrayStrings(maxElements, forceSummarize, precision)); + } + + + @Override + public String toStringFull() { + return toString(Long.MAX_VALUE, false, -1 * dataType().precision()); + } + + @Override + public Object element() { + + if (!isScalar()) { + throw new IllegalStateException("Unable to retrieve element from non scalar matrix"); + } + if (data.dataType() == DataType.FLOAT) { + return data.getFloat(0); + } + return data.getDouble(0); + } + + @Override + public INDArray remainder(INDArray denominator) { + if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { + return remainder(denominator, Nd4j.createUninitialized(this.dataType(), + Shape.broadcastOutputShape(this.shape(), denominator.shape()))); + } else { + return remainder(denominator, this.ulike()); + } + } + + @Override + public INDArray remainder(INDArray denominator, INDArray result) { + validateNumericalArray("remainder", false); + Preconditions.checkArgument(Shape.areShapesBroadcastable(this.shape(), denominator.shape()), + "Shapes must be broadcastable"); + + val op 
= new RemainderOp(this, denominator, result); + Nd4j.getExecutioner().exec(op); + return result; + } + + @Override + public INDArray remainder(Number denominator) { + return remainder(denominator, Nd4j.createUninitialized(this.dataType(), this.shape())); + } + + @Override + public INDArray remainder(Number denominator, INDArray result) { + validateNumericalArray("remainder", false); + + ScalarRemainder op = new ScalarRemainder(this, null, result, denominator); + Nd4j.getExecutioner().exec(op); + return result; + } + + @Override + public INDArray remainderi(INDArray denominator) { + validateNumericalArray("remainderi", false); + RemainderOp op = new RemainderOp(this, denominator, this); + Nd4j.getExecutioner().exec(op); + return this; + } + + @Override + public INDArray remainderi(Number denominator) { + validateNumericalArray("remainderi", false); + ScalarRemainder op = new ScalarRemainder(this, null, this, denominator); + Nd4j.getExecutioner().exec(op); + return this; + } + + @Override + public INDArray fmod(INDArray denominator) { + validateNumericalArray("fmod", false); + if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { + return fmod(denominator, Nd4j.createUninitialized(Nd4j.defaultFloatingPointType(), + Shape.broadcastOutputShape(this.shape(), denominator.shape()))); + } else { + return fmod(denominator, this.ulike()); + } + } + + @Override + public INDArray fmod(INDArray denominator, INDArray result) { + validateNumericalArray("fmod", false); + if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { + val outShape = Shape.broadcastOutputShape(this.shape(), denominator.shape()); + Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), + "Result shape doesn't match expectations: " + Arrays.toString(result.shape())); + + Nd4j.exec(new FloorModOp(new INDArray[]{this, denominator}, new INDArray[]{result})); + + return result; + } else { + FModOp op = new FModOp(this, denominator, result); + 
Nd4j.getExecutioner().exec(op); + return result; + } + } + + @Override + public INDArray fmod(Number denominator) { + return fmod(denominator, Nd4j.createUninitialized(this.dataType(), this.shape())); + } + + @Override + public INDArray fmod(Number denominator, INDArray result) { + validateNumericalArray("fmod", false); + ScalarFMod op = new ScalarFMod(this, null, result, denominator); + Nd4j.getExecutioner().exec(op); + return result; + } + + @Override + public INDArray fmodi(INDArray denominator) { + validateNumericalArray("fmodi", false); + FModOp op = new FModOp(this, denominator, this); + Nd4j.getExecutioner().exec(op); + return this; + } + + @Override + public INDArray fmodi(Number denominator) { + validateNumericalArray("fmodi", false); + ScalarFMod op = new ScalarFMod(this, null, this, denominator); + Nd4j.getExecutioner().exec(op); + return this; + } + + @Override + public Iterator iterator() { + return new FirstAxisIterator(this); + } + + @Override + public long originalOffset() { + if (data().originalOffset() >= Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "Original offset of buffer can not be >= Integer.MAX_VALUE"); + } + + return data().originalOffset(); + } + + private void readObject(ObjectInputStream s) { + try { + s.defaultReadObject(); + read(s); + } catch (Exception e) { + throw new RuntimeException(e); } - /** - * Checks whether the matrix is a vector. - */ - @Override - public boolean isVector() { - if (jvmShapeInfo.rank == 1) - return true; + } - return isRowVector() || isColumnVector(); + private void writeObject(ObjectOutputStream out) throws IOException { + out.defaultWriteObject(); + write(out); + } + + //Custom serialization for Java serialization + protected void write(ObjectOutputStream out) throws IOException { + if (this.isView()) { + //As per Nd4j.write, duplicate before writing to the output stream + //BaseDataBuffer.write(...) doesn't know about strides etc, so dup (or equiv. 
strategy) is necessary here + //Furthermore, because we only want to save the *actual* data for a view (not the full data), the shape info + // (mainly strides, offset, element-wise stride) may be different in the duped array vs. the view array + INDArray copy = this.dup(); + copy.shapeInfoDataBuffer().write(out); + copy.data().write(out); + } else { + shapeInformation.write(out); + data().write(out); } + } - @Override - public boolean isVectorOrScalar() { - return isVector() || isScalar(); - } + //Custom deserialization for Java serialization + protected void read(ObjectInputStream s) { + val headerShape = BaseDataBuffer.readHeader(s); - @Override - public boolean isSquare() { - return isMatrix() && rows() == columns(); - } + shapeInformation = Nd4j.createBuffer(new int[Shape.shapeInfoLength(rank())]); + shapeInformation.read(s, headerShape.getLeft(), headerShape.getMiddle(), + headerShape.getRight()); - @Override - public boolean isRowVector() { - return (rank() == 2 && rows() == 1) && length() > 1 || rank() == 1 && length() > 1; - } + setShapeInformation(Pair.create(shapeInformation, shapeInformation.asLong())); - @Override - public boolean isColumnVector() { - return rank() == 2 && columns() == 1 && length() > 1; - } + val headerData = BaseDataBuffer.readHeader(s); + data = Nd4j.createBuffer(headerData.getRight(), headerData.getMiddle(), false); + data().read(s, headerData.getLeft(), headerData.getMiddle(), headerData.getRight()); + } - @Override - public boolean isColumnVectorOrScalar() { - return isColumnVector() || isScalar(); - } + @Override + public INDArray argMax(int... dimension) { + return Nd4j.argMax(this, dimension); + } - @Override - public boolean isRowVectorOrScalar() { - return isRowVector() || isScalar(); - } + @Override + public boolean isAttached() { + if (isEmpty()) { + return false; + } - /** - * Generate string representation of the matrix. 
- * Printing will switch to scientific notation on a per element basis - * - when abs value is greater than or equal to 10000 - * - when abs value is less than or equal to 0.0001 and not zero - * - * If the number of elements in the array is greater than 1000 (by default) only the first and last three elements - * in a dimension are included. This can be changed globally using {@link NDArrayStrings#setMaxPrintElements(long)} - * - * - */ - @Override - public String toString() { - return toString(new NDArrayStrings()); - } + Preconditions.checkArgument(!(data == null && !isEmpty()), "Array has no buffer!"); + return data.isAttached() || + (data.underlyingDataBuffer() != null && data.underlyingDataBuffer().isAttached()) || + (data.originalDataBuffer() != null && data.originalDataBuffer().isAttached()); + } - @Override - public String toString(@NonNull NDArrayStrings options) { - if(wasClosed()) - return ""; - if (!isCompressed() && !preventUnpack) - return options.format(this); - else if (isCompressed() && compressDebug) - return "COMPRESSED ARRAY. SYSTEM PROPERTY compressdebug is true. 
This is to prevent auto decompression from being triggered."; - else if (preventUnpack) - return "Array string unpacking is disabled."; - return options.format(this); - } + @Override + public boolean isInScope() { + if (!isAttached()) { + return true; + } - @Override - public String toString(long maxElements, boolean forceSummarize, int precision){ - return toString(new NDArrayStrings(maxElements, forceSummarize, precision)); - } + return data.isInScope(); + } + @Override + public INDArray detach() { + if (!isAttached()) { + return this; + } - @Override - public String toStringFull(){ - return toString(Long.MAX_VALUE, false, -1 * dataType().precision()); - } + WorkspaceUtils.assertValidArray(this, "Cannot detach INDArray"); - @Override - public Object element() { - - if (!isScalar()) - throw new IllegalStateException("Unable to retrieve element from non scalar matrix"); - if (data.dataType() == DataType.FLOAT) - return data.getFloat(0); - return data.getDouble(0); - } - - @Override - public INDArray remainder(INDArray denominator) { - if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { - return remainder(denominator, Nd4j.createUninitialized(this.dataType(), Shape.broadcastOutputShape(this.shape(), denominator.shape()))); - } else - return remainder(denominator, this.ulike()); - } - - @Override - public INDArray remainder(INDArray denominator, INDArray result) { - validateNumericalArray("remainder", false); - Preconditions.checkArgument(Shape.areShapesBroadcastable(this.shape(), denominator.shape()),"Shapes must be broadcastable"); - - val op = new RemainderOp(this, denominator, result); - Nd4j.getExecutioner().exec(op); - return result; - } - - @Override - public INDArray remainder(Number denominator) { - return remainder(denominator, Nd4j.createUninitialized(this.dataType(), this.shape())); - } - - @Override - public INDArray remainder(Number denominator, INDArray result) { - validateNumericalArray("remainder", false); - - ScalarRemainder op = 
new ScalarRemainder(this, null, result, denominator); - Nd4j.getExecutioner().exec(op); - return result; - } - - @Override - public INDArray remainderi(INDArray denominator) { - validateNumericalArray("remainderi", false); - RemainderOp op = new RemainderOp(this, denominator, this); - Nd4j.getExecutioner().exec(op); - return this; - } - - @Override - public INDArray remainderi(Number denominator) { - validateNumericalArray("remainderi", false); - ScalarRemainder op = new ScalarRemainder(this, null, this, denominator); - Nd4j.getExecutioner().exec(op); - return this; - } - - @Override - public INDArray fmod(INDArray denominator) { - validateNumericalArray("fmod", false); - if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { - return fmod(denominator, Nd4j.createUninitialized(Nd4j.defaultFloatingPointType(), Shape.broadcastOutputShape(this.shape(), denominator.shape()))); - } else - return fmod(denominator, this.ulike()); - } - - @Override - public INDArray fmod(INDArray denominator, INDArray result) { - validateNumericalArray("fmod", false); - if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { - val outShape = Shape.broadcastOutputShape(this.shape(), denominator.shape()); - Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape())); - - Nd4j.exec(new FloorModOp(new INDArray[]{this, denominator}, new INDArray[]{result})); - - return result; - } else { - FModOp op = new FModOp(this, denominator, result); - Nd4j.getExecutioner().exec(op); - return result; - } - } - - @Override - public INDArray fmod(Number denominator) { - return fmod(denominator, Nd4j.createUninitialized(this.dataType(), this.shape())); - } - - @Override - public INDArray fmod(Number denominator, INDArray result) { - validateNumericalArray("fmod", false); - ScalarFMod op = new ScalarFMod(this, null, result, denominator); - Nd4j.getExecutioner().exec(op); - return result; 
- } - - @Override - public INDArray fmodi(INDArray denominator) { - validateNumericalArray("fmodi", false); - FModOp op = new FModOp(this, denominator, this); - Nd4j.getExecutioner().exec(op); - return this; - } - - @Override - public INDArray fmodi(Number denominator) { - validateNumericalArray("fmodi", false); - ScalarFMod op = new ScalarFMod(this, null, this, denominator); - Nd4j.getExecutioner().exec(op); - return this; - } - - @Override - public Iterator iterator() { - return new FirstAxisIterator(this); - } - - @Override - public long originalOffset() { - if (data().originalOffset() >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Original offset of buffer can not be >= Integer.MAX_VALUE"); - - return data().originalOffset(); - } - - private void readObject(ObjectInputStream s) { - try { - s.defaultReadObject(); - read(s); - } catch (Exception e) { - throw new RuntimeException(e); - } - - } - - private void writeObject(ObjectOutputStream out) throws IOException { - out.defaultWriteObject(); - write(out); - } - - //Custom serialization for Java serialization - protected void write(ObjectOutputStream out) throws IOException { - if (this.isView()) { - //As per Nd4j.write, duplicate before writing to the output stream - //BaseDataBuffer.write(...) doesn't know about strides etc, so dup (or equiv. strategy) is necessary here - //Furthermore, because we only want to save the *actual* data for a view (not the full data), the shape info - // (mainly strides, offset, element-wise stride) may be different in the duped array vs. 
the view array - INDArray copy = this.dup(); - copy.shapeInfoDataBuffer().write(out); - copy.data().write(out); - } else { - shapeInformation.write(out); - data().write(out); - } - } - - //Custom deserialization for Java serialization - protected void read(ObjectInputStream s) { - val headerShape = BaseDataBuffer.readHeader(s); - - shapeInformation = Nd4j.createBuffer(new int[Shape.shapeInfoLength(rank())]); - shapeInformation.read(s, headerShape.getLeft(), headerShape.getMiddle(), headerShape.getRight()); - - setShapeInformation(Pair.create(shapeInformation, shapeInformation.asLong())); - - val headerData = BaseDataBuffer.readHeader(s); - data = Nd4j.createBuffer(headerData.getRight(), headerData.getMiddle(), false); - data().read(s, headerData.getLeft(), headerData.getMiddle(), headerData.getRight()); - } - - @Override - public INDArray argMax(int... dimension) { - return Nd4j.argMax(this, dimension); - } - - @Override - public boolean isAttached() { - if (isEmpty()) - return false; - - Preconditions.checkArgument(!(data == null && !isEmpty()), "Array has no buffer!"); - - return data.isAttached() || - (data.underlyingDataBuffer() != null && data.underlyingDataBuffer().isAttached()) || - (data.originalDataBuffer() != null && data.originalDataBuffer().isAttached()); - } - - @Override - public boolean isInScope() { - if (!isAttached()) - return true; - - return data.isInScope(); - } - - @Override - public INDArray detach() { - if (!isAttached()) - return this; - - WorkspaceUtils.assertValidArray(this, "Cannot detach INDArray"); - - Nd4j.getExecutioner().commit(); + Nd4j.getExecutioner().commit(); /* two options here 1) we're within some workspace 2) we're out of any workspace */ - if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) { - if (!isView()) { - Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); + if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) { + if (!isView()) { + 
Nd4j.getExecutioner().commit(); + DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); - Nd4j.getMemoryManager().memcpy(buffer, this.data()); - - return Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); - } else { - INDArray copy = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); - copy.assign(this); - Nd4j.getExecutioner().commit(); - - return copy; - } - } else { - MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - INDArray copy = null; - - if (!isView()) { - Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); - - //Pointer.memcpy(buffer.pointer(), this.data.pointer(), this.lengthLong() * Nd4j.sizeOfDataType(this.data.dataType())); - Nd4j.getMemoryManager().memcpy(buffer, this.data()); - - copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); //this.dup(this.ordering()); - - - } else { - copy = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); - copy.assign(this); - Nd4j.getExecutioner().commit(); - } - - Nd4j.getMemoryManager().setCurrentWorkspace(workspace); - - return copy; - } - } - - @Override - public INDArray leverage() { - WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); - if (!isAttached()) - return this; - - MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); - if (workspace == null) { - return this.detach(); - } - - MemoryWorkspace parentWorkspace = workspace.getParentWorkspace(); - - if (this.data.getParentWorkspace() == parentWorkspace) - return this; - - // if there's no parent ws - just detach - if (parentWorkspace == null) - return this.detach(); - else { - Nd4j.getExecutioner().commit(); - - // temporary set parent ws as current ws - Nd4j.getMemoryManager().setCurrentWorkspace(parentWorkspace); - - INDArray copy = null; - if 
(!this.isView()) { - Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.length(), false); - Nd4j.getMemoryManager().memcpy(buffer, this.data()); - - copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); - } else { - copy = this.dup(this.ordering()); - Nd4j.getExecutioner().commit(); - } - - // restore current ws - Nd4j.getMemoryManager().setCurrentWorkspace(workspace); - return copy; - } - } - - @Override - public INDArray leverageTo(String id) { - return leverageTo(id, false); - } - - @Override - public INDArray leverageTo(String id, boolean enforceExistence) throws Nd4jNoSuchWorkspaceException { - WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); - if (!isAttached()) - return this; - - if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExists(id)) { - if(enforceExistence){ - throw new Nd4jNoSuchWorkspaceException(id); - } else { - return this; - } - } - - MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace(); - MemoryWorkspace target = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(id); - - if (this.data.getParentWorkspace() == target) - return this; - - Nd4j.getMemoryManager().setCurrentWorkspace(target); - INDArray copy = null; - if (!this.isView()) { - Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); - Nd4j.getMemoryManager().memcpy(buffer, this.data()); - - copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); - } else { - copy = this.dup(this.ordering()); - Nd4j.getExecutioner().commit(); - } - - Nd4j.getMemoryManager().setCurrentWorkspace(current); - - return copy; - } - - public INDArray leverageOrDetach(String id){ - if(!isAttached()){ - return this; - } - - if(!Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(id)){ - return detach(); - } - return leverageTo(id); - } - - @Override - public INDArray migrate() { - return migrate(false); - } - - @Override - 
public INDArray migrate(boolean detachOnNoWs){ - WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); - - MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace(); - - if (current == null) { - if(detachOnNoWs){ - return detach(); - } else { - return this; - } - } - - INDArray copy = null; - - if (!this.isView()) { - Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); - Nd4j.getMemoryManager().memcpy(buffer, this.data()); - - copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); - } else { - copy = this.dup(this.ordering()); - Nd4j.getExecutioner().commit(); - } - - return copy; - } - - @Override - public Number percentileNumber(Number quantile) { - validateNumericalArray("percentileNumber", false); - if (quantile.intValue() < 0 || quantile.intValue() > 100) - throw new ND4JIllegalStateException("Percentile value should be in 0...100 range"); - - if (isScalar()) - return this.getDouble(0); - - INDArray sorted = Nd4j.sort(this.dup(this.ordering()), true); - - return getPercentile(quantile, sorted); - } - - @Override - public Number medianNumber() { - validateNumericalArray("medianNumber", false); - if(isScalar()) - return getNumber(0); - return percentileNumber(50); - } - - @Override - public INDArray median(int... dimension) { - validateNumericalArray("median", false); - //Check edge case: size 1 element. 
No dimension == full array - if(dimension.length == 0){ - return Nd4j.scalar(dataType(), medianNumber().doubleValue()); - } - long shapeProd = 1; - for (int d : dimension) { - shapeProd *= size(d); - } - if (shapeProd == 1) { - long[] newShape = ArrayUtil.removeIndex(shape(), dimension); - return dup('c').reshape('c', newShape); - } - return percentile(50, dimension); - } - - protected double getPercentile(Number quantile, INDArray sorted) { - validateNumericalArray("getPercentile", false); - if (quantile.intValue() == 0) - return sorted.getDouble(0); - else if (quantile.intValue() == 100) - return sorted.getDouble(sorted.length() - 1); - - double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1); - if (pos < 1) - return sorted.getDouble(0); - else if (pos >= sorted.length()) - return sorted.getDouble(sorted.length() - 1); - - double fposition = FastMath.floor(pos); - int position = (int)fposition; - - double diff = pos - fposition; - - double lower = sorted.getDouble(position-1); - double upper = sorted.getDouble(position); - - return lower + diff * (upper - lower); - } - - @Override - public INDArray percentile(Number quantile, int... dimension) { - validateNumericalArray("percentile", false); - if (quantile.doubleValue() < 0 || quantile.doubleValue() > 100) - throw new ND4JIllegalStateException("Percentile value should be in 0...100 range"); - - if (isScalar()) - return Nd4j.scalar(this.getDouble(0)); - - INDArray sorted = Nd4j.getNDArrayFactory().sort(this.dup(this.ordering()), false, dimension); - - // there's no practical sense doing this on GPU, stride will be just size of TAD. 
- INDArray ret = Nd4j.createUninitialized(Nd4j.defaultFloatingPointType(), sorted.tensorsAlongDimension(dimension)); - for (int i = 0; i < ret.length(); i++) { - ret.putScalar(i, getPercentile(quantile, sorted.tensorAlongDimension(i, dimension))); - } - - return ret; - - } - - protected abstract int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer); - - @Override - public int toFlatArray(FlatBufferBuilder builder) { - if(isView()){ - return dup(this.ordering()).toFlatArray(builder); - } - int shape = FlatArray.createShapeVector(builder, this.shapeInfoDataBuffer().asLong()); - int buffer = this.isEmpty() ? 0 : this.dataType() == DataType.UTF8 ? stringBuffer(builder, this.data()) : FlatArray.createBufferVector(builder, this.data().asBytes()); - val type = this.isEmpty() ? FlatBuffersMapper.getDataTypeAsByte(this.dataType()) : FlatBuffersMapper.getDataTypeAsByte(this.data().dataType()); - int array = FlatArray.createFlatArray(builder, shape, buffer, type, ByteOrder.BE); - - return array; - } - - protected static DataTypeEx convertType(DataType type) { - if (type == DataType.HALF) { - return DataTypeEx.FLOAT16; - } else if (type == DataType.FLOAT) { - return DataTypeEx.FLOAT; - } else if (type == DataType.DOUBLE) { - return DataTypeEx.DOUBLE; - - } else if(type == DataType.INT) { - return DataTypeEx.INT8; - } else if(type == DataType.LONG) { - return DataTypeEx.INT16; - - } else - throw new IllegalStateException("Unknown dataType: [" + type + "]"); - } - - @Override - public boolean isEmpty() { - return Shape.isEmpty(jvmShapeInfo.javaShapeInformation); - } - - @Override - public long[] shapeInfoJava() { - return jvmShapeInfo.javaShapeInformation; - } - - @Override - public DataType dataType() { - if (data != null) - return data.dataType(); - - val e = Shape.extras(jvmShapeInfo.javaShapeInformation); - - if (e != 0) { - val t = ArrayOptionsHelper.dataType(jvmShapeInfo.javaShapeInformation); - return t; - } - - return DataType.UNKNOWN; - } - - @Override - public 
boolean isR() { - val dtype = dataType(); - return dtype == DataType.FLOAT || dtype == DataType.DOUBLE || dtype == DataType.HALF || dtype == DataType.BFLOAT16; - } - - @Override - public boolean isZ() { - return !isR() && !isB() && !isS(); - } - - @Override - public boolean isB() { - return dataType() == DataType.BOOL; - } - - @Override - public boolean isS() { - return dataType() == DataType.UTF8; - } - - @Override - public INDArray castTo(DataType dataType) { - if(dataType == dataType()) //No-op if correct datatype - return this; - if(isEmpty() && rank() == 0){ - return Nd4j.empty(dataType); - } - val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering()); - result.assign(this); - return result; - } - - @Override - public boolean all() { - val r = Nd4j.getExecutioner().exec(new All(this)); - return r.getDouble(0) != 0.0; - } - - @Override - public boolean any() { - val r = Nd4j.getExecutioner().exec(new Any(this)); - return r.getDouble(0) != 0.0; - } - - @Override - public boolean none() { - return !any(); - } - - - /** - * Validate that the operation is being applied on a numerical array (not boolean or utf8). - * Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays - * @param opName Operation name to print in the exception - */ - protected void validateNumericalArray(String opName, boolean allowEmpty){ - if(dataType() == DataType.BOOL || dataType() == DataType.UTF8) - throw new IllegalStateException("Cannot apply operation " + opName + " to array with " + dataType() + " datatype. 
Array shape: " + Arrays.toString(shape())); - if(!allowEmpty && isEmpty()) - throw new IllegalStateException("Cannot perform operation " + opName + " on empty array with datatype " + dataType()); - } - - @Override - public boolean closeable() { - if (released || isAttached()) - return false; - - // empty arrays have no buffer at all - if (isEmpty()) - return true; - - if (isView()) - return false; - - return data.closeable(); - } - - @Override - public void close() { - // empty arrays have no buffer at all - if (released || isEmpty()) - return; + Nd4j.getMemoryManager().memcpy(buffer, this.data()); + return Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); + } else { + INDArray copy = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); + copy.assign(this); Nd4j.getExecutioner().commit(); - if (!closeable()) - throw new ND4JIllegalStateException("Can't release this INDArray"); + return copy; + } + } else { + MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); + Nd4j.getMemoryManager().setCurrentWorkspace(null); + INDArray copy = null; - data.close(); + if (!isView()) { + Nd4j.getExecutioner().commit(); + DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); - released = true; + //Pointer.memcpy(buffer.pointer(), this.data.pointer(), this.lengthLong() * Nd4j.sizeOfDataType(this.data.dataType())); + Nd4j.getMemoryManager().memcpy(buffer, this.data()); + + copy = Nd4j.createArrayFromShapeBuffer(buffer, + this.shapeInfoDataBuffer()); //this.dup(this.ordering()); + + + } else { + copy = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); + copy.assign(this); + Nd4j.getExecutioner().commit(); + } + + Nd4j.getMemoryManager().setCurrentWorkspace(workspace); + + return copy; + } + } + + @Override + public INDArray leverage() { + WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); + if (!isAttached()) { + return this; + } + + 
MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); + if (workspace == null) { + return this.detach(); } - @Override - public INDArray like() { - return Nd4j.create(this.dataType(), this.shape(), Nd4j.getStrides(this.shape(), this.ordering()), this.ordering()); + MemoryWorkspace parentWorkspace = workspace.getParentWorkspace(); + + if (this.data.getParentWorkspace() == parentWorkspace) { + return this; + } + + // if there's no parent ws - just detach + if (parentWorkspace == null) { + return this.detach(); + } else { + Nd4j.getExecutioner().commit(); + + // temporary set parent ws as current ws + Nd4j.getMemoryManager().setCurrentWorkspace(parentWorkspace); + + INDArray copy = null; + if (!this.isView()) { + Nd4j.getExecutioner().commit(); + DataBuffer buffer = Nd4j.createBuffer(this.length(), false); + Nd4j.getMemoryManager().memcpy(buffer, this.data()); + + copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); + } else { + copy = this.dup(this.ordering()); + Nd4j.getExecutioner().commit(); + } + + // restore current ws + Nd4j.getMemoryManager().setCurrentWorkspace(workspace); + return copy; + } + } + + @Override + public INDArray leverageTo(String id) { + return leverageTo(id, false); + } + + @Override + public INDArray leverageTo(String id, boolean enforceExistence) + throws Nd4jNoSuchWorkspaceException { + WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); + if (!isAttached()) { + return this; + } + + if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExists(id)) { + if (enforceExistence) { + throw new Nd4jNoSuchWorkspaceException(id); + } else { + return this; + } } - @Override - public INDArray ulike() { - return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); + MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace(); + MemoryWorkspace target = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(id); + + if (this.data.getParentWorkspace() 
== target) { + return this; + } + + Nd4j.getMemoryManager().setCurrentWorkspace(target); + INDArray copy = null; + if (!this.isView()) { + Nd4j.getExecutioner().commit(); + DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); + Nd4j.getMemoryManager().memcpy(buffer, this.data()); + + copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); + } else { + copy = this.dup(this.ordering()); + Nd4j.getExecutioner().commit(); } - @Override - public boolean wasClosed() { - // data can be null if that's empty array - return released || (data() != null && data().wasClosed()); + Nd4j.getMemoryManager().setCurrentWorkspace(current); + + return copy; + } + + public INDArray leverageOrDetach(String id) { + if (!isAttached()) { + return this; } - @Override - public long getId(){ - return arrayId; + if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(id)) { + return detach(); } + return leverageTo(id); + } + + @Override + public INDArray migrate() { + return migrate(false); + } + + @Override + public INDArray migrate(boolean detachOnNoWs) { + WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); + + MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace(); + + if (current == null) { + if (detachOnNoWs) { + return detach(); + } else { + return this; + } + } + + INDArray copy = null; + + if (!this.isView()) { + Nd4j.getExecutioner().commit(); + DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); + Nd4j.getMemoryManager().memcpy(buffer, this.data()); + + copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); + } else { + copy = this.dup(this.ordering()); + Nd4j.getExecutioner().commit(); + } + + return copy; + } + + @Override + public Number percentileNumber(Number quantile) { + validateNumericalArray("percentileNumber", false); + if (quantile.intValue() < 0 || quantile.intValue() > 100) { + throw new ND4JIllegalStateException("Percentile 
value should be in 0...100 range"); + } + + if (isScalar()) { + return this.getDouble(0); + } + + INDArray sorted = Nd4j.sort(this.dup(this.ordering()), true); + + return getPercentile(quantile, sorted); + } + + @Override + public Number medianNumber() { + validateNumericalArray("medianNumber", false); + if (isScalar()) { + return getNumber(0); + } + return percentileNumber(50); + } + + @Override + public INDArray median(int... dimension) { + validateNumericalArray("median", false); + //Check edge case: size 1 element. No dimension == full array + if (dimension.length == 0) { + return Nd4j.scalar(dataType(), medianNumber().doubleValue()); + } + long shapeProd = 1; + for (int d : dimension) { + shapeProd *= size(d); + } + if (shapeProd == 1) { + long[] newShape = ArrayUtil.removeIndex(shape(), dimension); + return dup('c').reshape('c', newShape); + } + return percentile(50, dimension); + } + + protected double getPercentile(Number quantile, INDArray sorted) { + validateNumericalArray("getPercentile", false); + if (quantile.intValue() == 0) { + return sorted.getDouble(0); + } else if (quantile.intValue() == 100) { + return sorted.getDouble(sorted.length() - 1); + } + + double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1); + if (pos < 1) { + return sorted.getDouble(0); + } else if (pos >= sorted.length()) { + return sorted.getDouble(sorted.length() - 1); + } + + double fposition = FastMath.floor(pos); + int position = (int) fposition; + + double diff = pos - fposition; + + double lower = sorted.getDouble(position - 1); + double upper = sorted.getDouble(position); + + return lower + diff * (upper - lower); + } + + @Override + public INDArray percentile(Number quantile, int... 
dimension) { + validateNumericalArray("percentile", false); + if (quantile.doubleValue() < 0 || quantile.doubleValue() > 100) { + throw new ND4JIllegalStateException("Percentile value should be in 0...100 range"); + } + + if (isScalar()) { + return Nd4j.scalar(this.getDouble(0)); + } + + INDArray sorted = Nd4j.getNDArrayFactory().sort(this.dup(this.ordering()), false, dimension); + + // there's no practical sense doing this on GPU, stride will be just size of TAD. + INDArray ret = Nd4j.createUninitialized(Nd4j.defaultFloatingPointType(), + sorted.tensorsAlongDimension(dimension)); + for (int i = 0; i < ret.length(); i++) { + ret.putScalar(i, getPercentile(quantile, sorted.tensorAlongDimension(i, dimension))); + } + + return ret; + + } + + protected abstract int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer); + + @Override + public int toFlatArray(FlatBufferBuilder builder) { + if (isView()) { + return dup(this.ordering()).toFlatArray(builder); + } + int shape = FlatArray.createShapeVector(builder, this.shapeInfoDataBuffer().asLong()); + int buffer = this.isEmpty() ? 0 + : this.dataType() == DataType.UTF8 ? stringBuffer(builder, this.data()) + : FlatArray.createBufferVector(builder, this.data().asBytes()); + val type = this.isEmpty() ? 
FlatBuffersMapper.getDataTypeAsByte(this.dataType()) + : FlatBuffersMapper.getDataTypeAsByte(this.data().dataType()); + int array = FlatArray.createFlatArray(builder, shape, buffer, type, ByteOrder.BE); + + return array; + } + + protected static DataTypeEx convertType(DataType type) { + if (type == DataType.HALF) { + return DataTypeEx.FLOAT16; + } else if (type == DataType.FLOAT) { + return DataTypeEx.FLOAT; + } else if (type == DataType.DOUBLE) { + return DataTypeEx.DOUBLE; + + } else if (type == DataType.INT) { + return DataTypeEx.INT8; + } else if (type == DataType.LONG) { + return DataTypeEx.INT16; + + } else { + throw new IllegalStateException("Unknown dataType: [" + type + "]"); + } + } + + @Override + public boolean isEmpty() { + return Shape.isEmpty(jvmShapeInfo.javaShapeInformation); + } + + @Override + public long[] shapeInfoJava() { + return jvmShapeInfo.javaShapeInformation; + } + + @Override + public DataType dataType() { + if (data != null) { + return data.dataType(); + } + + val e = Shape.extras(jvmShapeInfo.javaShapeInformation); + + if (e != 0) { + val t = ArrayOptionsHelper.dataType(jvmShapeInfo.javaShapeInformation); + return t; + } + + return DataType.UNKNOWN; + } + + @Override + public boolean isR() { + val dtype = dataType(); + return dtype == DataType.FLOAT || dtype == DataType.DOUBLE || dtype == DataType.HALF + || dtype == DataType.BFLOAT16; + } + + @Override + public boolean isZ() { + return !isR() && !isB() && !isS(); + } + + @Override + public boolean isB() { + return dataType() == DataType.BOOL; + } + + @Override + public boolean isS() { + return dataType() == DataType.UTF8; + } + + @Override + public INDArray castTo(DataType dataType) { + if (dataType == dataType()) //No-op if correct datatype + { + return this; + } + if (isEmpty() && rank() == 0) { + return Nd4j.empty(dataType); + } + val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering()); + result.assign(this); + return result; + } + + @Override + public 
boolean all() { + val r = Nd4j.getExecutioner().exec(new All(this)); + return r.getDouble(0) != 0.0; + } + + @Override + public boolean any() { + val r = Nd4j.getExecutioner().exec(new Any(this)); + return r.getDouble(0) != 0.0; + } + + @Override + public boolean none() { + return !any(); + } + + + /** + * Validate that the operation is being applied on a numerical array (not boolean or utf8). Some + * operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 + * arrays + * + * @param opName Operation name to print in the exception + */ + protected void validateNumericalArray(String opName, boolean allowEmpty) { + if (dataType() == DataType.BOOL || dataType() == DataType.UTF8) { + throw new IllegalStateException( + "Cannot apply operation " + opName + " to array with " + dataType() + + " datatype. Array shape: " + Arrays.toString(shape())); + } + if (!allowEmpty && isEmpty()) { + throw new IllegalStateException( + "Cannot perform operation " + opName + " on empty array with datatype " + dataType()); + } + } + + @Override + public boolean closeable() { + if (released || isAttached()) { + return false; + } + + // empty arrays have no buffer at all + if (isEmpty()) { + return true; + } + + if (isView()) { + return false; + } + + return data.closeable(); + } + + @Override + public void close() { + // empty arrays have no buffer at all + if (released || isEmpty()) { + return; + } + + Nd4j.getExecutioner().commit(); + + if (!closeable()) { + throw new ND4JIllegalStateException("Can't release this INDArray"); + } + + data.close(); + + released = true; + } + + @Override + public INDArray like() { + return Nd4j.create(this.dataType(), this.shape(), + Nd4j.getStrides(this.shape(), this.ordering()), this.ordering()); + } + + @Override + public INDArray ulike() { + return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); + } + + @Override + public boolean wasClosed() { + // data can be null if that's empty array + 
return released || (data() != null && data().wasClosed()); + } + + @Override + public long getId() { + return arrayId; + } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseShapeInfoProvider.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseShapeInfoProvider.java index 98a8f95ab..aeb91f0b7 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseShapeInfoProvider.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseShapeInfoProvider.java @@ -47,7 +47,7 @@ public abstract class BaseShapeInfoProvider implements ShapeInfoProvider { } /** - * This method creates shapeInformation buffer, based on shape & order being passed in + * This method creates shapeInformation buffer, based on shape and order being passed in * * @param shape * @param order diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index 61b53e23d..f4d4b200e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -2216,7 +2216,7 @@ public interface INDArray extends Serializable, AutoCloseable { * Dimshuffle: an extension of permute that adds the ability * to broadcast various dimensions. * This will only accept integers and xs. - *

      + *

      * An x indicates a dimension should be broadcasted rather than permuted. * * Examples originally from the theano docs: @@ -2226,15 +2226,15 @@ public interface INDArray extends Serializable, AutoCloseable { A few examples of patterns and their effect: - ('x') -> make a 0d (scalar) into a 1d vector - (0, 1) -> identity for 2d vectors - (1, 0) -> inverts the first and second dimensions - ('x', 0) -> make a row out of a 1d vector (N to 1xN) - (0, 'x') -> make a column out of a 1d vector (N to Nx1) - (2, 0, 1) -> AxBxC to CxAxB - (0, 'x', 1) -> AxB to Ax1xB - (1, 'x', 0) -> AxB to Bx1xA - (1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A) + ('x') -> make a 0d (scalar) into a 1d vector + (0, 1) -> identity for 2d vectors + (1, 0) -> inverts the first and second dimensions + ('x', 0) -> make a row out of a 1d vector (N to 1xN) + (0, 'x') -> make a column out of a 1d vector (N to Nx1) + (2, 0, 1) -> AxBxC to CxAxB + (0, 'x', 1) -> AxB to Ax1xB + (1, 'x', 0) -> AxB to Bx1xA + (1,) -> This remove dimensions 0. 
It must be a broadcastable dimension (1xA to A) * @param rearrange the dimensions to swap to * @param newOrder the new order (think permute) @@ -2244,7 +2244,7 @@ public interface INDArray extends Serializable, AutoCloseable { INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable); /** - * See {@link #dimShuffle(Object[], int[], boolean[]) + * See {@link #dimShuffle(Object[], int[], boolean[])} */ INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/ShapeInfoProvider.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/ShapeInfoProvider.java index 1f3768038..33f8378f1 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/ShapeInfoProvider.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/ShapeInfoProvider.java @@ -33,7 +33,7 @@ public interface ShapeInfoProvider { Pair createShapeInformation(long[] shape, DataType dataType); /** - * This method creates long shapeInformation buffer, based on shape & order being passed in + * This method creates long shapeInformation buffer, based on shape and order being passed in * @param shape * @return */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java index d258e4b3a..a54f4700a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java @@ -65,7 +65,8 @@ public interface OpContext extends AutoCloseable { /** * This method sets root-level seed for rng - * @param seed + * @param rootState + * @param nodeState */ void setRngStates(long rootState, long nodeState); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/Random.java 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/Random.java index d3740509c..30e349e94 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/Random.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/Random.java @@ -251,7 +251,7 @@ public interface Random extends AutoCloseable { * The reason for this is due to ints * having the same space usage as floats. * This also plays nice with blas. - *

      + *

      * If the data opType is set to double, * then these will be whole doubles. * @@ -272,7 +272,7 @@ public interface Random extends AutoCloseable { * The reason for this is due to ints * having the same space usage as floats. * This also plays nice with blas. - *

      + *

      * If the data opType is set to double, * then these will be whole doubles. * diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/BaseDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/BaseDistribution.java index 9015301ea..801eb8d89 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/BaseDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/BaseDistribution.java @@ -35,233 +35,236 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.Iterator; public abstract class BaseDistribution implements Distribution { - protected Random random; - protected double solverAbsoluteAccuracy; + + protected Random random; + protected double solverAbsoluteAccuracy; - public BaseDistribution(Random rng) { - this.random = rng; + public BaseDistribution(Random rng) { + this.random = rng; + } + + + public BaseDistribution() { + this(Nd4j.getRandom()); + } + + /** + * For a random variable {@code X} whose values are distributed according to this distribution, + * this method returns {@code P(x0 < X <= x1)}. + * + * @param x0 Lower bound (excluded). + * @param x1 Upper bound (included). + * @return the probability that a random variable with this distribution takes a value between + * {@code x0} and {@code x1}, excluding the lower and including the upper endpoint. + * @throws org.apache.commons.math3.exception.NumberIsTooLargeException if {@code x0 > x1}. + *

      + * The default implementation + * uses the identity + * {@code P(x0 < X <= x1) = + * P(X <= x1) - P(X <= x0)} + * @since 3.1 + */ + + public double probability(double x0, double x1) { + if (x0 > x1) { + throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0, + x1, true); } + return cumulativeProbability(x1) - cumulativeProbability(x0); + } - - public BaseDistribution() { - this(Nd4j.getRandom()); - } - - /** - * For a random variable {@code X} whose values are distributed according - * to this distribution, this method returns {@code P(x0 < X <= x1)}. + /** + * {@inheritDoc} + *

      + * The default implementation returns + *

        + *
      • {@link #getSupportLowerBound()} for {@code p = 0},
      • + *
      • {@link #getSupportUpperBound()} for {@code p = 1}.
      • + *
      + */ + @Override + public double inverseCumulativeProbability(final double p) throws OutOfRangeException { + /* + * IMPLEMENTATION NOTES + * -------------------- + * Where applicable, use is made of the one-sided Chebyshev inequality + * to bracket the root. This inequality states that + * P(X - mu >= k * sig) <= 1 / (1 + k^2), + * mu: mean, sig: standard deviation. Equivalently + * 1 - P(X < mu + k * sig) <= 1 / (1 + k^2), + * F(mu + k * sig) >= k^2 / (1 + k^2). * - * @param x0 Lower bound (excluded). - * @param x1 Upper bound (included). - * @return the probability that a random variable with this distribution - * takes a value between {@code x0} and {@code x1}, excluding the lower - * and including the upper endpoint. - * @throws org.apache.commons.math3.exception.NumberIsTooLargeException if {@code x0 > x1}. - *

      - * The default implementation uses the identity - * {@code P(x0 < X <= x1) = P(X <= x1) - P(X <= x0)} - * @since 3.1 + * For k = sqrt(p / (1 - p)), we find + * F(mu + k * sig) >= p, + * and (mu + k * sig) is an upper-bound for the root. + * + * Then, introducing Y = -X, mean(Y) = -mu, sd(Y) = sig, and + * P(Y >= -mu + k * sig) <= 1 / (1 + k^2), + * P(-X >= -mu + k * sig) <= 1 / (1 + k^2), + * P(X <= mu - k * sig) <= 1 / (1 + k^2), + * F(mu - k * sig) <= 1 / (1 + k^2). + * + * For k = sqrt((1 - p) / p), we find + * F(mu - k * sig) <= p, + * and (mu - k * sig) is a lower-bound for the root. + * + * In cases where the Chebyshev inequality does not apply, geometric + * progressions 1, 2, 4, ... and -1, -2, -4, ... are used to bracket + * the root. */ - - public double probability(double x0, double x1) { - if (x0 > x1) { - throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0, x1, true); - } - return cumulativeProbability(x1) - cumulativeProbability(x0); + if (p < 0.0 || p > 1.0) { + throw new OutOfRangeException(p, 0, 1); } - /** - * {@inheritDoc} - *

      - * The default implementation returns - *

        - *
      • {@link #getSupportLowerBound()} for {@code p = 0},
      • - *
      • {@link #getSupportUpperBound()} for {@code p = 1}.
      • - *
      - */ - @Override - public double inverseCumulativeProbability(final double p) throws OutOfRangeException { - /* - * IMPLEMENTATION NOTES - * -------------------- - * Where applicable, use is made of the one-sided Chebyshev inequality - * to bracket the root. This inequality states that - * P(X - mu >= k * sig) <= 1 / (1 + k^2), - * mu: mean, sig: standard deviation. Equivalently - * 1 - P(X < mu + k * sig) <= 1 / (1 + k^2), - * F(mu + k * sig) >= k^2 / (1 + k^2). - * - * For k = sqrt(p / (1 - p)), we find - * F(mu + k * sig) >= p, - * and (mu + k * sig) is an upper-bound for the root. - * - * Then, introducing Y = -X, mean(Y) = -mu, sd(Y) = sig, and - * P(Y >= -mu + k * sig) <= 1 / (1 + k^2), - * P(-X >= -mu + k * sig) <= 1 / (1 + k^2), - * P(X <= mu - k * sig) <= 1 / (1 + k^2), - * F(mu - k * sig) <= 1 / (1 + k^2). - * - * For k = sqrt((1 - p) / p), we find - * F(mu - k * sig) <= p, - * and (mu - k * sig) is a lower-bound for the root. - * - * In cases where the Chebyshev inequality does not apply, geometric - * progressions 1, 2, 4, ... and -1, -2, -4, ... are used to bracket - * the root. - */ - if (p < 0.0 || p > 1.0) { - throw new OutOfRangeException(p, 0, 1); + double lowerBound = getSupportLowerBound(); + if (p == 0.0) { + return lowerBound; + } + + double upperBound = getSupportUpperBound(); + if (p == 1.0) { + return upperBound; + } + + final double mu = getNumericalMean(); + final double sig = FastMath.sqrt(getNumericalVariance()); + final boolean chebyshevApplies; + chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) || Double.isInfinite(sig) + || Double.isNaN(sig)); + + if (lowerBound == Double.NEGATIVE_INFINITY) { + if (chebyshevApplies) { + lowerBound = mu - sig * FastMath.sqrt((1. 
- p) / p); + } else { + lowerBound = -1.0; + while (cumulativeProbability(lowerBound) >= p) { + lowerBound *= 2.0; } + } + } - double lowerBound = getSupportLowerBound(); - if (p == 0.0) { - return lowerBound; + if (upperBound == Double.POSITIVE_INFINITY) { + if (chebyshevApplies) { + upperBound = mu + sig * FastMath.sqrt(p / (1. - p)); + } else { + upperBound = 1.0; + while (cumulativeProbability(upperBound) < p) { + upperBound *= 2.0; } + } + } - double upperBound = getSupportUpperBound(); - if (p == 1.0) { - return upperBound; - } + final UnivariateFunction toSolve = new UnivariateFunction() { - final double mu = getNumericalMean(); - final double sig = FastMath.sqrt(getNumericalVariance()); - final boolean chebyshevApplies; - chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) || Double.isInfinite(sig) || Double.isNaN(sig)); + public double value(final double x) { + return cumulativeProbability(x) - p; + } + }; - if (lowerBound == Double.NEGATIVE_INFINITY) { - if (chebyshevApplies) { - lowerBound = mu - sig * FastMath.sqrt((1. - p) / p); + double x = UnivariateSolverUtils.solve(toSolve, lowerBound, upperBound, + getSolverAbsoluteAccuracy()); + + if (!isSupportConnected()) { + /* Test for plateau. */ + final double dx = getSolverAbsoluteAccuracy(); + if (x - dx >= getSupportLowerBound()) { + double px = cumulativeProbability(x); + if (cumulativeProbability(x - dx) == px) { + upperBound = x; + while (upperBound - lowerBound > dx) { + final double midPoint = 0.5 * (lowerBound + upperBound); + if (cumulativeProbability(midPoint) < px) { + lowerBound = midPoint; } else { - lowerBound = -1.0; - while (cumulativeProbability(lowerBound) >= p) { - lowerBound *= 2.0; - } + upperBound = midPoint; } + } + return upperBound; } - - if (upperBound == Double.POSITIVE_INFINITY) { - if (chebyshevApplies) { - upperBound = mu + sig * FastMath.sqrt(p / (1. 
- p)); - } else { - upperBound = 1.0; - while (cumulativeProbability(upperBound) < p) { - upperBound *= 2.0; - } - } - } - - final UnivariateFunction toSolve = new UnivariateFunction() { - - public double value(final double x) { - return cumulativeProbability(x) - p; - } - }; - - double x = UnivariateSolverUtils.solve(toSolve, lowerBound, upperBound, getSolverAbsoluteAccuracy()); - - if (!isSupportConnected()) { - /* Test for plateau. */ - final double dx = getSolverAbsoluteAccuracy(); - if (x - dx >= getSupportLowerBound()) { - double px = cumulativeProbability(x); - if (cumulativeProbability(x - dx) == px) { - upperBound = x; - while (upperBound - lowerBound > dx) { - final double midPoint = 0.5 * (lowerBound + upperBound); - if (cumulativeProbability(midPoint) < px) { - lowerBound = midPoint; - } else { - upperBound = midPoint; - } - } - return upperBound; - } - } - } - return x; + } } + return x; + } - /** - * Returns the solver absolute accuracy for inverse cumulative computation. - * You can override this method in order to use a Brent solver with an - * absolute accuracy different from the default. - * - * @return the maximum absolute error in inverse cumulative probability estimates - */ - protected double getSolverAbsoluteAccuracy() { - return solverAbsoluteAccuracy; - } + /** + * Returns the solver absolute accuracy for inverse cumulative computation. You can override this + * method in order to use a Brent solver with an absolute accuracy different from the default. + * + * @return the maximum absolute error in inverse cumulative probability estimates + */ + protected double getSolverAbsoluteAccuracy() { + return solverAbsoluteAccuracy; + } - /** - * {@inheritDoc} - */ - @Override - public void reseedRandomGenerator(long seed) { - random.setSeed(seed); - } + /** + * {@inheritDoc} + */ + @Override + public void reseedRandomGenerator(long seed) { + random.setSeed(seed); + } - /** - * {@inheritDoc} - *

      - * The default implementation uses the - * - * inversion method. - * - */ - @Override - public double sample() { - return inverseCumulativeProbability(random.nextDouble()); - } + /** + * {@inheritDoc} + * The default implementation uses the + * + * inversion method. + * + */ + @Override + public double sample() { + return inverseCumulativeProbability(random.nextDouble()); + } - /** - * {@inheritDoc} - *

      - * The default implementation generates the sample by calling - * {@link #sample()} in a loop. - */ - @Override - public double[] sample(long sampleSize) { - if (sampleSize <= 0) { - throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); - } - double[] out = new double[(int) sampleSize]; - for (int i = 0; i < sampleSize; i++) { - out[i] = sample(); - } - return out; + /** + * {@inheritDoc} + *

      + * The default implementation generates the sample by calling {@link #sample()} in a loop. + */ + @Override + public double[] sample(long sampleSize) { + if (sampleSize <= 0) { + throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); } + double[] out = new double[(int) sampleSize]; + for (int i = 0; i < sampleSize; i++) { + out[i] = sample(); + } + return out; + } - /** - * {@inheritDoc} - * - * @return zero. - * @since 3.1 - */ - @Override - public double probability(double x) { - return 0d; - } + /** + * {@inheritDoc} + * + * @return zero. + * @since 3.1 + */ + @Override + public double probability(double x) { + return 0d; + } - @Override - public INDArray sample(int[] shape) { - INDArray ret = Nd4j.create(shape); - return sample(ret); - } + @Override + public INDArray sample(int[] shape) { + INDArray ret = Nd4j.create(shape); + return sample(ret); + } - @Override - public INDArray sample(long[] shape) { - INDArray ret = Nd4j.create(shape); - return sample(ret); - } + @Override + public INDArray sample(long[] shape) { + INDArray ret = Nd4j.create(shape); + return sample(ret); + } - @Override - public INDArray sample(INDArray target) { - Iterator idxIter = new NdIndexIterator(target.shape()); //For consistent values irrespective of c vs. fortran ordering - long len = target.length(); - for (long i = 0; i < len; i++) { - target.putScalar(idxIter.next(), sample()); - } - return target; + @Override + public INDArray sample(INDArray target) { + Iterator idxIter = new NdIndexIterator( + target.shape()); //For consistent values irrespective of c vs. 
fortran ordering + long len = target.length(); + for (long i = 0; i < len; i++) { + target.putScalar(idxIter.next(), sample()); } + return target; + } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/Distribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/Distribution.java index e224f5866..3375d57df 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/Distribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/Distribution.java @@ -89,8 +89,8 @@ public interface Distribution { * variable {@code X} distributed according to this distribution, the * returned value is *

        - *
      • inf{x in R | P(X<=x) >= p} for {@code 0 < p <= 1},
      • - *
      • inf{x in R | P(X<=x) > 0} for {@code p = 0}.
      • + *
      • {@code inf{x in R | P(X<=x) >= p}} for {@code 0 < p <= 1},
      • + *
      • {@code inf{x in R | P(X<=x) > 0}} for {@code p = 0}.
      • *
      * * @param p the cumulative probability @@ -122,7 +122,7 @@ public interface Distribution { * Access the lower bound of the support. This method must return the same * value as {@code inverseCumulativeProbability(0)}. In other words, this * method must return - *

      inf {x in R | P(X <= x) > 0}.

      + *

      {@code inf {x in R | P(X <= x) > 0}}.

      * * @return lower bound of the support (might be * {@code Double.NEGATIVE_INFINITY}) @@ -133,7 +133,7 @@ public interface Distribution { * Access the upper bound of the support. This method must return the same * value as {@code inverseCumulativeProbability(1)}. In other words, this * method must return - *

      inf {x in R | P(X <= x) = 1}.

      + *

      {@code inf {x in R | P(X <= x) = 1}}.

      * * @return upper bound of the support (might be * {@code Double.POSITIVE_INFINITY}) diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/BinomialDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/BinomialDistribution.java index 2d295d53f..2d50a779d 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/BinomialDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/BinomialDistribution.java @@ -166,7 +166,7 @@ public class BinomialDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For {@code n} trials and probability parameter {@code p}, the mean is * {@code n * p}. */ @@ -177,7 +177,7 @@ public class BinomialDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For {@code n} trials and probability parameter {@code p}, the variance is * {@code n * p * (1 - p)}. */ @@ -189,7 +189,7 @@ public class BinomialDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The lower bound of the support is always 0 except for the probability * parameter {@code p = 1}. * @@ -203,7 +203,7 @@ public class BinomialDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The upper bound of the support is the number of trials except for the * probability parameter {@code p = 0}. * @@ -227,7 +227,7 @@ public class BinomialDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The support of this distribution is connected. * * @return {@code true} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/ConstantDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/ConstantDistribution.java index b56722c30..cb4136829 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/ConstantDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/ConstantDistribution.java @@ -83,7 +83,7 @@ public class ConstantDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * If {@code x} is more than 40 standard deviations from the mean, 0 or 1 * is returned, as in these cases the actual value is within * {@code Double.MIN_VALUE} of 0 or 1. @@ -131,7 +131,7 @@ public class ConstantDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For mean parameter {@code mu}, the mean is {@code mu}. */ public double getNumericalMean() { @@ -140,7 +140,7 @@ public class ConstantDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For standard deviation parameter {@code s}, the variance is {@code s^2}. */ public double getNumericalVariance() { @@ -150,7 +150,7 @@ public class ConstantDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The lower bound of the support is always negative infinity * no matter the parameters. * @@ -163,7 +163,7 @@ public class ConstantDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The upper bound of the support is always positive infinity * no matter the parameters. * @@ -190,7 +190,7 @@ public class ConstantDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The support of this distribution is connected. * * @return {@code true} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/LogNormalDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/LogNormalDistribution.java index f0c9aa396..8788539fd 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/LogNormalDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/LogNormalDistribution.java @@ -172,7 +172,6 @@ public class LogNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * If {@code x} is more than 40 standard deviations from the mean, 0 or 1 * is returned, as in these cases the actual value is within * {@code Double.MIN_VALUE} of 0 or 1. @@ -238,7 +237,6 @@ public class LogNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * For mean parameter {@code mu}, the mean is {@code mu}. */ public double getNumericalMean() { @@ -247,7 +245,6 @@ public class LogNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * For standard deviation parameter {@code s}, the variance is {@code s^2}. */ public double getNumericalVariance() { @@ -257,7 +254,6 @@ public class LogNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * The lower bound of the support is always negative infinity * no matter the parameters. * @@ -270,7 +266,7 @@ public class LogNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The upper bound of the support is always positive infinity * no matter the parameters. * @@ -297,7 +293,7 @@ public class LogNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The support of this distribution is connected. * * @return {@code true} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/NormalDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/NormalDistribution.java index a7ccc5caf..6cb7b5995 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/NormalDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/NormalDistribution.java @@ -176,7 +176,6 @@ public class NormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * If {@code x} is more than 40 standard deviations from the mean, 0 or 1 * is returned, as in these cases the actual value is within * {@code Double.MIN_VALUE} of 0 or 1. @@ -242,7 +241,6 @@ public class NormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * For mean parameter {@code mu}, the mean is {@code mu}. */ public double getNumericalMean() { @@ -251,7 +249,6 @@ public class NormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * For standard deviation parameter {@code s}, the variance is {@code s^2}. */ public double getNumericalVariance() { @@ -261,7 +258,6 @@ public class NormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * The lower bound of the support is always negative infinity * no matter the parameters. * @@ -274,7 +270,6 @@ public class NormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * The upper bound of the support is always positive infinity * no matter the parameters. * @@ -301,7 +296,6 @@ public class NormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * The support of this distribution is connected. * * @return {@code true} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java index 3b1faaf71..455388aa8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java @@ -34,27 +34,28 @@ import org.nd4j.common.util.ArrayUtil; @Slf4j public class OrthogonalDistribution extends BaseDistribution { - /** - * Default inverse cumulative probability accuracy. - * - * @since 2.1 - */ - public static final double DEFAULT_INVERSE_ABSOLUTE_ACCURACY = 1e-9; - /** - * Serializable version identifier. - */ - private static final long serialVersionUID = 8589540077390120676L; - /** - * Mean of this distribution. - */ - private final double gain; - private INDArray gains; + /** + * Default inverse cumulative probability accuracy. + * + * @since 2.1 + */ + public static final double DEFAULT_INVERSE_ABSOLUTE_ACCURACY = 1e-9; + /** + * Serializable version identifier. + */ + private static final long serialVersionUID = 8589540077390120676L; - public OrthogonalDistribution(double gain) { - this.gain = gain; - this.random = Nd4j.getRandom(); - } + /** + * Mean of this distribution. + */ + private final double gain; + private INDArray gains; + + public OrthogonalDistribution(double gain) { + this.gain = gain; + this.random = Nd4j.getRandom(); + } /* max doesn't want this distripution public OrthogonalDistribution(@NonNull INDArray gains) { @@ -62,196 +63,192 @@ public class OrthogonalDistribution extends BaseDistribution { this.random = Nd4j.getRandom(); } */ - /** - * Access the mean. - * - * @return the mean for this distribution. 
- */ - public double getMean() { - throw new UnsupportedOperationException(); + + /** + * Access the mean. + * + * @return the mean for this distribution. + */ + public double getMean() { + throw new UnsupportedOperationException(); + } + + /** + * Access the standard deviation. + * + * @return the standard deviation for this distribution. + */ + public double getStandardDeviation() { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} + */ + public double density(double x) { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} If {@code x} is more than 40 standard deviations from the mean, 0 or 1 is + * returned, as in these cases the actual value is within {@code Double.MIN_VALUE} of 0 or 1. + */ + public double cumulativeProbability(double x) { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} + * + * @since 3.2 + */ + @Override + public double inverseCumulativeProbability(final double p) throws OutOfRangeException { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} + * + * @deprecated See + * {@link org.apache.commons.math3.distribution.RealDistribution#cumulativeProbability(double, + * double)} + */ + @Override + @Deprecated + public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} + */ + @Override + public double probability(double x0, double x1) throws NumberIsTooLargeException { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} + */ + @Override + protected double getSolverAbsoluteAccuracy() { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} For mean parameter {@code mu}, the mean is {@code mu}. + */ + public double getNumericalMean() { + return getMean(); + } + + /** + * {@inheritDoc} For standard deviation parameter {@code s}, the variance is {@code s^2}. 
+ */ + public double getNumericalVariance() { + final double s = getStandardDeviation(); + return s * s; + } + + /** + * {@inheritDoc} The lower bound of the support is always negative infinity no matter the + * parameters. + * + * @return lower bound of the support (always {@code Double.NEGATIVE_INFINITY}) + */ + public double getSupportLowerBound() { + return Double.NEGATIVE_INFINITY; + } + + /** + * {@inheritDoc} + *

      + * The upper bound of the support is always positive infinity no matter the parameters. + * + * @return upper bound of the support (always {@code Double.POSITIVE_INFINITY}) + */ + public double getSupportUpperBound() { + return Double.POSITIVE_INFINITY; + } + + /** + * {@inheritDoc} + */ + public boolean isSupportLowerBoundInclusive() { + return false; + } + + /** + * {@inheritDoc} + */ + public boolean isSupportUpperBoundInclusive() { + return false; + } + + /** + * {@inheritDoc} + *

      + * The support of this distribution is connected. + * + * @return {@code true} + */ + public boolean isSupportConnected() { + return true; + } + + /** + * {@inheritDoc} + */ + @Override + public double sample() { + throw new UnsupportedOperationException(); + } + + @Override + public INDArray sample(int[] shape) { + return sample(ArrayUtil.toLongArray(shape)); + } + + @Override + public INDArray sample(long[] shape) { + long numRows = 1; + for (int i = 0; i < shape.length - 1; i++) { + numRows *= shape[i]; } + long numCols = shape[shape.length - 1]; - /** - * Access the standard deviation. - * - * @return the standard deviation for this distribution. - */ - public double getStandardDeviation() { - throw new UnsupportedOperationException(); + val dtype = Nd4j.defaultFloatingPointType(); + + val flatShape = new long[]{numRows, numCols}; + val flatRng = Nd4j.getExecutioner().exec( + new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, + 1.0), random); + + val m = flatRng.rows(); + val n = flatRng.columns(); + + val s = Nd4j.create(dtype, m < n ? m : n); + val u = Nd4j.create(dtype, m, m); + val v = Nd4j.create(dtype, new long[]{n, n}, 'f'); + + Nd4j.exec(new Svd(flatRng, true, s, u, v)); + + if (gains == null) { + if (u.rows() >= numRows && u.columns() >= numCols) { + return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain) + .reshape(shape); + } else { + return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain) + .reshape(shape); + } + } else { + throw new UnsupportedOperationException(); } + } - /** - * {@inheritDoc} - */ - public double density(double x) { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc} - *

      - * If {@code x} is more than 40 standard deviations from the mean, 0 or 1 - * is returned, as in these cases the actual value is within - * {@code Double.MIN_VALUE} of 0 or 1. - */ - public double cumulativeProbability(double x) { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc} - * - * @since 3.2 - */ - @Override - public double inverseCumulativeProbability(final double p) throws OutOfRangeException { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc} - * - * @deprecated See {@link org.apache.commons.math3.distribution.RealDistribution#cumulativeProbability(double, double)} - */ - @Override - @Deprecated - public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc} - */ - @Override - public double probability(double x0, double x1) throws NumberIsTooLargeException { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc} - */ - @Override - protected double getSolverAbsoluteAccuracy() { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc} - *

      - * For mean parameter {@code mu}, the mean is {@code mu}. - */ - public double getNumericalMean() { - return getMean(); - } - - /** - * {@inheritDoc} - *

      - * For standard deviation parameter {@code s}, the variance is {@code s^2}. - */ - public double getNumericalVariance() { - final double s = getStandardDeviation(); - return s * s; - } - - /** - * {@inheritDoc} - *

      - * The lower bound of the support is always negative infinity - * no matter the parameters. - * - * @return lower bound of the support (always - * {@code Double.NEGATIVE_INFINITY}) - */ - public double getSupportLowerBound() { - return Double.NEGATIVE_INFINITY; - } - - /** - * {@inheritDoc} - *

      - * The upper bound of the support is always positive infinity - * no matter the parameters. - * - * @return upper bound of the support (always - * {@code Double.POSITIVE_INFINITY}) - */ - public double getSupportUpperBound() { - return Double.POSITIVE_INFINITY; - } - - /** - * {@inheritDoc} - */ - public boolean isSupportLowerBoundInclusive() { - return false; - } - - /** - * {@inheritDoc} - */ - public boolean isSupportUpperBoundInclusive() { - return false; - } - - /** - * {@inheritDoc} - *

      - * The support of this distribution is connected. - * - * @return {@code true} - */ - public boolean isSupportConnected() { - return true; - } - - /** - * {@inheritDoc} - */ - @Override - public double sample() { - throw new UnsupportedOperationException(); - } - - @Override - public INDArray sample(int[] shape) { - return sample(ArrayUtil.toLongArray(shape)); - } - - @Override - public INDArray sample(long[] shape){ - long numRows = 1; - for (int i = 0; i < shape.length - 1; i++) - numRows *= shape[i]; - long numCols = shape[shape.length - 1]; - - val dtype = Nd4j.defaultFloatingPointType(); - - val flatShape = new long[]{numRows, numCols}; - val flatRng = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, 1.0), random); - - val m = flatRng.rows(); - val n = flatRng.columns(); - - val s = Nd4j.create(dtype, m < n ? m : n); - val u = Nd4j.create(dtype, m, m); - val v = Nd4j.create(dtype, new long[] {n, n}, 'f'); - - Nd4j.exec(new Svd(flatRng, true, s, u, v)); - - if (gains == null) { - if (u.rows() >= numRows && u.columns() >= numCols) { - return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape); - } else { - return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape); - } - } else { - throw new UnsupportedOperationException(); - } - } - - @Override - public INDArray sample(INDArray target){ - return target.assign(sample(target.shape())); - } + @Override + public INDArray sample(INDArray target) { + return target.assign(sample(target.shape())); + } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/SaddlePointExpansion.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/SaddlePointExpansion.java index 446c0c264..6b547e091 100644 --- 
a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/SaddlePointExpansion.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/SaddlePointExpansion.java @@ -84,7 +84,6 @@ public class SaddlePointExpansion { * href="http://mathworld.wolfram.com/StirlingsSeries.html"> * http://mathworld.wolfram.com/StirlingsSeries.html * - *

      * * @param z the value. * @return the Striling's series error. @@ -117,7 +116,6 @@ public class SaddlePointExpansion { * href="http://www.herine.net/stat/papers/dbinom.pdf"> * http://www.herine.net/stat/papers/dbinom.pdf * - *

      * * @param x the x value. * @param mu the average. diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/TruncatedNormalDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/TruncatedNormalDistribution.java index 3043c9ebf..75cb216c4 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/TruncatedNormalDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/TruncatedNormalDistribution.java @@ -172,7 +172,7 @@ public class TruncatedNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * If {@code x} is more than 40 standard deviations from the mean, 0 or 1 * is returned, as in these cases the actual value is within * {@code Double.MIN_VALUE} of 0 or 1. @@ -238,7 +238,7 @@ public class TruncatedNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For mean parameter {@code mu}, the mean is {@code mu}. */ public double getNumericalMean() { @@ -247,7 +247,7 @@ public class TruncatedNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For standard deviation parameter {@code s}, the variance is {@code s^2}. */ public double getNumericalVariance() { @@ -257,7 +257,7 @@ public class TruncatedNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The lower bound of the support is always negative infinity * no matter the parameters. * @@ -270,7 +270,7 @@ public class TruncatedNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The upper bound of the support is always positive infinity * no matter the parameters. * @@ -297,7 +297,7 @@ public class TruncatedNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The support of this distribution is connected. * * @return {@code true} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/UniformDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/UniformDistribution.java index 07627f05c..bd5e1635e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/UniformDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/UniformDistribution.java @@ -105,7 +105,7 @@ public class UniformDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For lower bound {@code lower} and upper bound {@code upper}, the mean is * {@code 0.5 * (lower + upper)}. */ @@ -115,7 +115,7 @@ public class UniformDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For lower bound {@code lower} and upper bound {@code upper}, the * variance is {@code (upper - lower)^2 / 12}. */ @@ -126,7 +126,7 @@ public class UniformDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The lower bound of the support is equal to the lower bound parameter * of the distribution. * @@ -138,7 +138,7 @@ public class UniformDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The upper bound of the support is equal to the upper bound parameter * of the distribution. * @@ -164,7 +164,7 @@ public class UniformDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The support of this distribution is connected. * * @return {@code true} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java index 1665ca165..0e0d5c4b0 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java @@ -58,7 +58,7 @@ public class NDArrayCreationUtil { } - /** Get an array of INDArrays (2d) all with the specified shape. Pair returned to aid + /** Get an array of INDArrays (2d) all with the specified shape. {@code Pair} returned to aid * debugging: String contains information on how to reproduce the matrix (i.e., which function, and arguments) * Each NDArray in the returned array has been obtained by applying an operation such as transpose, tensorAlongDimension, * etc to an original array. @@ -88,7 +88,7 @@ public class NDArrayCreationUtil { * eg. 
rank 2: 1,1; 1,2; 2,1; 2,2; 3,4 * Motivated by TADs that often hit bugs when a "1" occurs as the size of a dimension * - * @param rank any rank including true scalars i.e rank >= 0 + * @param rank any rank including true scalars i.e rank >= 0 * @param order what order array to return i.e 'c' or 'f' order arrays * @return List of arrays and the shapes as strings */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java index e57f29072..b05689764 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java @@ -355,7 +355,7 @@ public class AsyncDataSetIterator implements DataSetIterator { * yet been called, or the {@code remove} method has already * been called after the last call to the {@code next} * method - * @implSpec The default implementation throws an instance of + * The default implementation throws an instance of * {@link UnsupportedOperationException} and performs no other action. 
*/ @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java index 822fa3ce2..5d372309b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java @@ -299,7 +299,7 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { * yet been called, or the {@code remove} method has already * been called after the last call to the {@code next} * method - * @implSpec The default implementation throws an instance of + * The default implementation throws an instance of * {@link UnsupportedOperationException} and performs no other action. */ @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java index f66afa29f..222990cc5 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java @@ -560,7 +560,6 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet { /** - * @Deprecated * Subtract by the column means and divide by the standard deviation */ @Deprecated diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/KFoldIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/KFoldIterator.java index 00e81c22f..6af1d6bcd 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/KFoldIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/KFoldIterator.java @@ -117,7 +117,6 @@ public class KFoldIterator implements DataSetIterator { /** * Shuffles the dataset and resets to the first fold * - * @return 
void */ @Override public void reset() { @@ -129,7 +128,7 @@ public class KFoldIterator implements DataSetIterator { /** * The number of examples in every fold is (N / k), - * except when (N % k) > 0, when the first (N % k) folds contain (N / k) + 1 examples + * except when (N % k) > 0, when the first (N % k) folds contain (N / k) + 1 examples * * @return examples in a fold */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java index ff82d068c..1ebe44dca 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java @@ -49,7 +49,6 @@ public class TestDataSetIterator implements DataSetIterator { * Initializes with a default batch of 5 * * @param dataset the dataset to make the iterator from - * @param batch the batchsize for the iterator */ public TestDataSetIterator(DataSet dataset) { this(dataset, 5); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/RandomProjection.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/RandomProjection.java index 326b9d45f..6ec22e018 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/RandomProjection.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/RandomProjection.java @@ -65,9 +65,9 @@ public class RandomProjection { * The minimum number n' of components to guarantee the eps-embedding is * given by: * - * n' >= 4 log(n) / (eps² / 2 - eps³ / 3) + * {@code n' >= 4 log(n) / (eps² / 2 - eps³ / 3)} * - * see http://cseweb.ucsd.edu/~dasgupta/papers/jl.pdf §2.1 + * http://cseweb.ucsd.edu/~dasgupta/papers/jl.pdf §2.1 * @param n Number of samples. 
If an array is given, it will compute * a safe number of components array-wise. * @param eps Maximum distortion rate as defined by the Johnson-Lindenstrauss lemma. diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/env/EnvironmentalAction.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/env/EnvironmentalAction.java index e369b61e2..ee9a94719 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/env/EnvironmentalAction.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/env/EnvironmentalAction.java @@ -30,7 +30,6 @@ public interface EnvironmentalAction { /** * This method will be executed with corresponding Env Var value * - * @param name * @param value */ void process(String value); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java index 3458ed06b..d4d11fa50 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java @@ -276,7 +276,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { * Rotate a matrix 90 degrees * * @param toRotate the matrix to rotate - * @return the rotated matrix */ @Override public void rot90(INDArray toRotate) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BlasWrapper.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BlasWrapper.java index b4655fd6f..1965883af 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BlasWrapper.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BlasWrapper.java @@ -33,7 +33,7 @@ public interface BlasWrapper { */ /** - * Compute x <-> y (swap two matrices) + * Compute {@code x <-> y} (swap two matrices) */ INDArray swap(INDArray x, INDArray y); @@ -69,14 +69,14 @@ public interface 
BlasWrapper { INDArray scal(double alpha, INDArray x); /** - * Compute x <- alpha * x (scale a matrix) + * Compute {@code x <- alpha * x} (scale a matrix) */ @Deprecated INDArray scal(float alpha, INDArray x); /** - * Compute y <- x (copy a matrix) + * Compute {@code y <- x} (copy a matrix) */ INDArray copy(INDArray x, INDArray y); @@ -84,13 +84,13 @@ public interface BlasWrapper { INDArray axpy(double da, INDArray dx, INDArray dy); /** - * Compute y <- alpha * x + y (elementwise addition) + * Compute {@code y <- alpha * x + y }(elementwise addition) */ @Deprecated INDArray axpy(float da, INDArray dx, INDArray dy); /** - * Compute y <- y + x * alpha + * Compute {@code y <- y + x * alpha} * @param da the alpha to multiply by * @param dx * @param dy @@ -130,7 +130,7 @@ public interface BlasWrapper { INDArray gemv(double alpha, INDArray a, INDArray x, double beta, INDArray y); /** - * Compute y <- alpha*op(a)*x + beta * y (general matrix vector + * Compute {@code y <- alpha*op(a)*x + beta * y} (general matrix vector * multiplication) */ @Deprecated @@ -142,7 +142,7 @@ public interface BlasWrapper { INDArray ger(double alpha, INDArray x, INDArray y, INDArray a); /** - * Compute A <- alpha * x * y^T + A (general rank-1 update) + * Compute {@code A <- alpha * x * y^T + A} (general rank-1 update) */ INDArray ger(float alpha, INDArray x, INDArray y, INDArray a); @@ -193,14 +193,14 @@ public interface BlasWrapper { /** * Generalized Least Squares via *GELSD. - *

      + *

      * Note that B must be padded to contain the solution matrix. This occurs when A has fewer rows * than columns. - *

      - * For example: in A * X = B, A is (m,n), X is (n,k) and B is (m,k). Now if m < n, since B is overwritten to contain + *

      + * For example: in A * X = B, A is (m,n), X is (n,k) and B is (m,k). Now if m < n, since B is overwritten to contain * the solution (in classical LAPACK style), B needs to be padded to be an (n,k) matrix. - *

      - * Likewise, if m > n, the solution consists only of the first n rows of B. + *

      + * Likewise, if m > n, the solution consists only of the first n rows of B. * * @param A an (m,n) matrix * @param B an (max(m,n), k) matrix (well, at least) diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java index fe162c8e2..a2b91fb15 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java @@ -193,7 +193,6 @@ public interface NDArrayFactory { * Rotate a matrix 90 degrees * * @param toRotate the matrix to rotate - * @return the rotated matrix */ void rot90(INDArray toRotate); @@ -340,7 +339,6 @@ public interface NDArrayFactory { * * @param array the ndarray to shuffle * @param dimension the dimension to do the shuffle - * @return */ void shuffle(INDArray array, Random rnd, int... dimension); @@ -350,7 +348,6 @@ public interface NDArrayFactory { * * @param array the ndarray to shuffle * @param dimension the dimension to do the shuffle - * @return */ void shuffle(Collection array, Random rnd, int... dimension); @@ -360,7 +357,6 @@ public interface NDArrayFactory { * * @param array the ndarray to shuffle * @param dimensions the dimensions to do the shuffle - * @return */ void shuffle(List array, Random rnd, List dimensions); @@ -1370,9 +1366,9 @@ public interface NDArrayFactory { INDArray createFromNpyFile(File file); /** - * Create a Map from given npz file. + * Create a {@code Map} from given npz file. 
* @param file the file to create the map from - * @return Map + * @return {@code Map} */ Map createFromNpzFile(File file) throws Exception; @@ -1386,7 +1382,7 @@ public interface NDArrayFactory { * * * @param array the array to convert - * @returnthe created pointer representing + * @return the created pointer representing * a pointer to a numpy header */ Pointer convertToNumpy(INDArray array); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 2dfff5fbd..f542e3cce 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -1441,7 +1441,7 @@ public class Nd4j { } /** - * See {@link #createBuffer(DataType dataType, long length, boolean initialize) with default datatype. + * See {@link #createBuffer(DataType dataType, long length, boolean initialize)} with default datatype. */ public static DataBuffer createBuffer(long length, boolean initialize) { return createBuffer(Nd4j.dataType(), length, initialize); @@ -2828,7 +2828,7 @@ public class Nd4j { } /** - * @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...)) + * @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...)} */ @Deprecated public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) { @@ -3306,7 +3306,7 @@ public class Nd4j { * Generate an array with random values generated according to a binomial distribution with the specified * number of trials and probability * - * @param nTrials Number of trials. Must be >= 0 + * @param nTrials Number of trials. Must be >= 0 * @param p Probability. 
Must be in range 0 to 1 * @param shape Shape of the result array * @return Result array @@ -3319,7 +3319,7 @@ public class Nd4j { * Fill the target array with random values generated according to a binomial distribution with the specified * number of trials and probability * - * @param nTrials Number of trials. Must be >= 0 + * @param nTrials Number of trials. Must be >= 0 * @param p Probability. Must be in range 0 to 1 * @param target Result array * @return Result array @@ -3333,7 +3333,7 @@ public class Nd4j { /** * Exponential distribution: P(x) = lambda * exp(-lambda * x) * - * @param lambda Must be > 0 + * @param lambda Must be > 0 * @param shape Shape of the array to generate */ public static INDArray randomExponential(double lambda, long... shape) { @@ -3341,9 +3341,9 @@ public class Nd4j { } /** - * Exponential distribution: P(x) = lambda * exp(-lambda * x) + * Exponential distribution: {@code P(x) = lambda * exp(-lambda * x)} * - * @param lambda Must be > 0 + * @param lambda Must be > 0 * @param target Array to hold the result */ public static INDArray randomExponential(double lambda, INDArray target) { @@ -3925,7 +3925,7 @@ public class Nd4j { } /** - * See {@link @see #create(int, int, int[], char)} + * See {@link Nd4j#create(int, int, int[], char)} */ public static INDArray zeros(int rows, int columns, int[] stride) { return create(rows, columns, stride, order()); @@ -4630,7 +4630,7 @@ public class Nd4j { /** * Concatenates two matrices vertically. Matrices must have identical numbers of columns.
      - * Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3] + * Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3] * * @param arrs Arrays to vstack */ @@ -4646,7 +4646,7 @@ public class Nd4j { /** * Concatenates two matrices vertically. Matrices must have identical numbers of columns.
      - * Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3] + * Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3] * * @param arrs Arrays to vstack */ @@ -5462,7 +5462,7 @@ public class Nd4j { Examples -------- - >>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1) + {@code >>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1) array([[ 1, 2, 3], [ 4, 5, 6], [ 0, 8, 9], @@ -5473,6 +5473,7 @@ public class Nd4j { mask = tri(*m.shape[-2:], k=k-1, dtype=bool) return where(mask, zeros(1, m.dtype), m) + } * @param m source array * @param k to zero below the k-th diagonal @@ -5517,8 +5518,8 @@ public class Nd4j { * @param n number of rows in the array * @param m number of columns in the array ( can be just equal to n) * @param k The sub-diagonal at and below which the array is filled. - `k` = 0 is the main diagonal, while `k` < 0 is below it, - and `k` > 0 is above. The default is 0. + `k` = 0 is the main diagonal, while `k` > 0 is below it, + and `k` > 0 is above. The default is 0. * @return array with ones at and below the given diagonal and zeros elsewhere */ public static INDArray tri(int n,int m,int k) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java index b56410cd3..0b72cd3a2 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java @@ -269,14 +269,14 @@ public abstract class Nd4jBackend { /** * Constructs a new exception with the specified cause and a detail - * message of (cause==null ? null : cause.toString()) (which - * typically contains the class and detail message of cause). + * message of {@code (cause==null ? 
null : cause.toString())} (which + * typically contains the class and detail message of cause). * This constructor is useful for exceptions that are little more than * wrappers for other throwables (for example, {@link * PrivilegedActionException}). * * @param cause the cause (which is saved for later retrieval by the - * {@link #getCause()} method). (A null value is + * {@link #getCause()} method). (A null value is * permitted, and indicates that the cause is nonexistent or * unknown.) * @since 1.4 diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java index be72896aa..e007ef168 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java @@ -30,176 +30,210 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Condition; public class NDBase { + public NDBase() { } /** * Boolean and array reduction operation, optionally along specified dimensions
      * - * @param x Input variable (NDARRAY type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NDARRAY type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public INDArray all(INDArray x, int... dimensions) { - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.All(x, dimensions)); } /** * Boolean or array reduction operation, optionally along specified dimensions
      * - * @param x Input variable (NDARRAY type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NDARRAY type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public INDArray any(INDArray x, int... dimensions) { - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(x, dimensions)); } /** - * Argmax array reduction operation, optionally along specified dimensions.
      - * Output values are the index of the maximum value of each slice along the specified dimension.
      - * + * Argmax array reduction operation, optionally along specified dimensions.
      Output values are + * the index of the maximum value of each slice along the specified dimension.
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param in Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) - * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or - * of rank (input rank) if keepdims = true (NUMERIC type) + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of + * rank (input rank) if keepdims = true (NUMERIC type) */ public INDArray argmax(INDArray in, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("argmax", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, keepDims, dimensions))[0]; + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, keepDims, dimensions))[0]; } /** - * Argmax array reduction operation, optionally along specified dimensions.
      - * Output values are the index of the maximum value of each slice along the specified dimension.
      - * + * Argmax array reduction operation, optionally along specified dimensions.
      Output values are + * the index of the maximum value of each slice along the specified dimension.
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) - * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or - * of rank (input rank) if keepdims = true (NUMERIC type) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of + * rank (input rank) if keepdims = true (NUMERIC type) */ public INDArray argmax(INDArray in, int... dimensions) { NDValidation.validateNumerical("argmax", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, false, dimensions))[0]; + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, false, dimensions))[0]; } /** - * Argmin array reduction operation, optionally along specified dimensions.
      - * Output values are the index of the minimum value of each slice along the specified dimension.
      - * + * Argmin array reduction operation, optionally along specified dimensions.
      Output values are + * the index of the minimum value of each slice along the specified dimension.
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * - * @param in Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) - * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of + * rank (input rank) if keepdims = true (NUMERIC type) */ public INDArray argmin(INDArray in, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("argmin", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, keepDims, dimensions))[0]; + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, keepDims, dimensions))[0]; } /** - * Argmin array reduction operation, optionally along specified dimensions.
      - * Output values are the index of the minimum value of each slice along the specified dimension.
      - * + * Argmin array reduction operation, optionally along specified dimensions.
      Output values are + * the index of the minimum value of each slice along the specified dimension.
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) - * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of + * rank (input rank) if keepdims = true (NUMERIC type) */ public INDArray argmin(INDArray in, int... dimensions) { NDValidation.validateNumerical("argmin", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, false, dimensions))[0]; + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, false, dimensions))[0]; } /** * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
      * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
      - * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
      - * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
      + * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) + * instead.
      Likewise, if transposeB is true, matrices from matricesB will have shape (K, + * N).
      *
      - * The result of this operation will be a batch of multiplied matrices. The
      - * result has the same length as both input batches and each output matrix is of shape (M, K).
      + * The result of this operation will be a batch of multiplied matrices. The
      result has the + * same length as both input batches and each output matrix is of shape (M, K).
      * - * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) - * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) * @param transposeA Whether to transpose A arrays or not * @param transposeB Whether to transpose B arrays or not */ public INDArray[] batchMmul(INDArray[] inputsA, INDArray[] inputsB, boolean transposeA, boolean transposeB) { NDValidation.validateNumerical("batchMmul", "inputsA", inputsA); - Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + Preconditions.checkArgument(inputsA.length >= 1, + "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); NDValidation.validateNumerical("batchMmul", "inputsB", inputsB); - Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, transposeA, transposeB)); + Preconditions.checkArgument(inputsB.length >= 1, + "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, transposeA, + transposeB)); } /** * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
      * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
      - * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
      - * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
      + * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) + * instead.
      Likewise, if transposeB is true, matrices from matricesB will have shape (K, + * N).
      *
      - * The result of this operation will be a batch of multiplied matrices. The
      - * result has the same length as both input batches and each output matrix is of shape (M, K).
      + * The result of this operation will be a batch of multiplied matrices. The
      result has the + * same length as both input batches and each output matrix is of shape (M, K).
      * * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) - * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) */ public INDArray[] batchMmul(INDArray[] inputsA, INDArray... inputsB) { NDValidation.validateNumerical("batchMmul", "inputsA", inputsA); - Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + Preconditions.checkArgument(inputsA.length >= 1, + "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); NDValidation.validateNumerical("batchMmul", "inputsB", inputsB); - Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, false, false)); + Preconditions.checkArgument(inputsB.length >= 1, + "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, false, false)); } /** - * Cast the array to a new datatype - for example, Integer -> Float
      + * Cast the array to a new datatype - for example, Integer -> Float
      * - * @param arg Input variable to cast (NDARRAY type) + * @param arg Input variable to cast (NDARRAY type) * @param datatype Datatype to cast to * @return output Output array (after casting) (NDARRAY type) */ @@ -208,119 +242,129 @@ public class NDBase { } /** - * Concatenate a set of inputs along the specified dimension.
      - * Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
      - * For example, if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, x+y, c]
      + * Concatenate a set of inputs along the specified dimension.
      Note that inputs must have + * identical rank and identical dimensions, other than the dimension to stack on.
      For example, + * if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, + * x+y, c]
      + *

      + * Inputs must satisfy the following constraints:
      Input arrays must all be the same datatype: + * isSameType(inputs)
      * - * Inputs must satisfy the following constraints:
      - * Input arrays must all be the same datatype: isSameType(inputs)
      - * - * @param inputs Input variables (NUMERIC type) + * @param inputs Input variables (NUMERIC type) * @param dimension Dimension to concatenate on * @return output (NUMERIC type) */ public INDArray concat(int dimension, INDArray... inputs) { NDValidation.validateNumerical("concat", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype"); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Concat(inputs, dimension))[0]; } /** - * Cumulative product operation.
      - * For input: [ a, b, c], output is:
      - * exclusive=false, reverse=false: [a, a*b, a*b*c]
      - * exclusive=true, reverse=false, [0, a, a*b]
      - * exclusive=false, reverse=true: [a*b*c, b*c, c]
      - * exclusive=true, reverse=true: [b*c, c, 0]
      + * Cumulative product operation.
      For input: [ a, b, c], output is:
      exclusive=false, + * reverse=false: [a, a*b, a*b*c]
      exclusive=true, reverse=false, [0, a, a*b]
      + * exclusive=false, reverse=true: [a*b*c, b*c, c]
      exclusive=true, reverse=true: [b*c, c, + * 0]
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param exclusive If true: exclude the first value - * @param reverse If true: reverse the direction of the accumulation - * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumululative sum operations + * along (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public INDArray cumprod(INDArray in, boolean exclusive, boolean reverse, int... axis) { NDValidation.validateNumerical("cumprod", "in", in); - Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, exclusive, reverse, axis))[0]; + Preconditions.checkArgument(axis.length >= 1, + "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, exclusive, reverse, + axis))[0]; } /** - * Cumulative product operation.
      - * For input: [ a, b, c], output is:
      - * exclusive=false, reverse=false: [a, a*b, a*b*c]
      - * exclusive=true, reverse=false, [0, a, a*b]
      - * exclusive=false, reverse=true: [a*b*c, b*c, c]
      - * exclusive=true, reverse=true: [b*c, c, 0]
      + * Cumulative product operation.
      For input: [ a, b, c], output is:
      exclusive=false, + * reverse=false: [a, a*b, a*b*c]
      exclusive=true, reverse=false, [0, a, a*b]
      + * exclusive=false, reverse=true: [a*b*c, b*c, c]
      exclusive=true, reverse=true: [b*c, c, + * 0]
      * - * @param in Input variable (NUMERIC type) - * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along + * (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public INDArray cumprod(INDArray in, int... axis) { NDValidation.validateNumerical("cumprod", "in", in); - Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, false, false, axis))[0]; + Preconditions.checkArgument(axis.length >= 1, + "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, false, false, axis))[0]; } /** - * Cumulative sum operation.
      - * For input: [ a, b, c], output is:
      - * exclusive=false, reverse=false: [a, a+b, a+b+c]
      - * exclusive=true, reverse=false, [0, a, a+b]
      - * exclusive=false, reverse=true: [a+b+c, b+c, c]
      - * exclusive=true, reverse=true: [b+c, c, 0]
      + * Cumulative sum operation.
      For input: [ a, b, c], output is:
      exclusive=false, + * reverse=false: [a, a+b, a+b+c]
      exclusive=true, reverse=false, [0, a, a+b]
      + * exclusive=false, reverse=true: [a+b+c, b+c, c]
      exclusive=true, reverse=true: [b+c, c, + * 0]
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param exclusive If true: exclude the first value - * @param reverse If true: reverse the direction of the accumulation - * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumululative sum operations + * along (Size: AtLeast(min=1)) * @return output (NUMERIC type) */ public INDArray cumsum(INDArray in, boolean exclusive, boolean reverse, int... axis) { NDValidation.validateNumerical("cumsum", "in", in); - Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, exclusive, reverse, axis))[0]; + Preconditions.checkArgument(axis.length >= 1, + "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, exclusive, reverse, axis))[0]; } /** - * Cumulative sum operation.
      - * For input: [ a, b, c], output is:
      - * exclusive=false, reverse=false: [a, a+b, a+b+c]
      - * exclusive=true, reverse=false, [0, a, a+b]
      - * exclusive=false, reverse=true: [a+b+c, b+c, c]
      - * exclusive=true, reverse=true: [b+c, c, 0]
      + * Cumulative sum operation.
      For input: [ a, b, c], output is:
      exclusive=false, + * reverse=false: [a, a+b, a+b+c]
      exclusive=true, reverse=false, [0, a, a+b]
      + * exclusive=false, reverse=true: [a+b+c, b+c, c]
      exclusive=true, reverse=true: [b+c, c, + * 0]
      * - * @param in Input variable (NUMERIC type) - * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along + * (Size: AtLeast(min=1)) * @return output (NUMERIC type) */ public INDArray cumsum(INDArray in, int... axis) { NDValidation.validateNumerical("cumsum", "in", in); - Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, false, false, axis))[0]; + Preconditions.checkArgument(axis.length >= 1, + "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, false, false, axis))[0]; } /** - * Pairwise dot product reduction along dimension
      - * output = sum(i=0 ... size(dim)-1) x[i] * y[i]
      + * Pairwise dot product reduction along dimension
      output = sum(i=0 ... size(dim)-1) x[i] * + * y[i]
      * - * @param x first input (NUMERIC type) - * @param y second input (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x first input (NUMERIC type) + * @param y second input (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output output variable (NUMERIC type) */ public INDArray dot(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("dot", "x", x); NDValidation.validateNumerical("dot", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.Dot(x, y, dimensions)); } /** - * Dynamically partition the input variable values into the specified number of paritions, using the indices.
      - * Example:
      + * Dynamically partition the input variable values into the specified number of paritions, using + * the indices.
      Example:
      *


      * input = [1,2,3,4,5]
      * numPartitions = 2
      @@ -329,39 +373,47 @@ public class NDBase { * out[1] = [1,4] }
      *

      * - * @param x Input variable (NUMERIC type) - * @param partitions 1D input with values 0 to numPartitions-1 (INT type) - * @param numPartitions Number of partitions, >= 1 + * @param x Input variable (NUMERIC type) + * @param partitions 1D input with values 0 to numPartitions-1 (INT type) + * @param numPartitions Number of partitions, >= 1 */ public INDArray[] dynamicPartition(INDArray x, INDArray partitions, int numPartitions) { NDValidation.validateNumerical("dynamicPartition", "x", x); NDValidation.validateInteger("dynamicPartition", "partitions", partitions); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, numPartitions)); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, + numPartitions)); } /** - * Dynamically merge the specified input arrays into a single array, using the specified indices
      + * Dynamically merge the specified input arrays into a single array, using the specified + * indices
      * - * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) - * @param x Input variables. (NUMERIC type) + * @param indices Indices to use when merging. Must be >= 1, same length as input variables + * (INT type) + * @param x Input variables. (NUMERIC type) * @return output Merged output variable (NUMERIC type) */ public INDArray dynamicStitch(INDArray[] indices, INDArray... x) { NDValidation.validateInteger("dynamicStitch", "indices", indices); - Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + Preconditions.checkArgument(indices.length >= 1, + "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); NDValidation.validateNumerical("dynamicStitch", "x", x); - Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(indices, x))[0]; + Preconditions.checkArgument(x.length >= 1, + "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(indices, x))[0]; } /** * Equals operation: elementwise x == y
      - * + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input array (NUMERIC type) * @param y Double value argument to use in operation - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray eq(INDArray x, double y) { NDValidation.validateNumerical("eq", "x", x); @@ -369,18 +421,20 @@ public class NDBase { } /** - * Equal to operation: elementwise x == y
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * + * Equal to operation: elementwise x == y
      If x and y arrays have equal shape, the output shape + * is the same as these inputs.
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input 1 (NUMERIC type) * @param y Input 2 (NUMERIC type) - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray eq(INDArray x, INDArray y) { NDValidation.validateNumerical("eq", "x", x); @@ -389,13 +443,11 @@ public class NDBase { } /** - * Reshape the input by adding a 1 at the specified location.
      - * For example, if input has shape [a, b], then output shape is:
      - * axis = 0: [1, a, b]
      - * axis = 1: [a, 1, b]
      - * axis = 2: [a, b, 1]
      + * Reshape the input by adding a 1 at the specified location.
      For example, if input has shape + * [a, b], then output shape is:
      axis = 0: [1, a, b]
      axis = 1: [a, 1, b]
      axis = 2: [a, + * b, 1]
      * - * @param x Input variable (NDARRAY type) + * @param x Input variable (NDARRAY type) * @param axis Axis to expand * @return output Output variable (NUMERIC type) */ @@ -404,40 +456,45 @@ public class NDBase { } /** - * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value
      + * Generate an output variable with the specified (dynamic) shape with all elements set to the + * specified value
      * - * @param shape Shape: must be a 1D array/variable (INT type) + * @param shape Shape: must be a 1D array/variable (INT type) * @param dataType Datatype of the output array - * @param value Value to set all elements to + * @param value Value to set all elements to * @return output Output variable (NUMERIC type) */ public INDArray fill(INDArray shape, DataType dataType, double value) { NDValidation.validateInteger("fill", "shape", shape); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(shape, dataType, value))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(shape, dataType, value))[0]; } /** - * Gather slices from the input variable where the indices are specified as fixed int[] values.
      - * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
      + * Gather slices from the input variable where the indices are specified as fixed int[] + * values.
      Output shape is same as input shape, except for axis dimension, which has size + * equal to indices.length.
      * - * @param df Input variable (NUMERIC type) + * @param df Input variable (NUMERIC type) * @param indices Indices to get (Size: AtLeast(min=1)) - * @param axis Axis that the indices refer to + * @param axis Axis that the indices refer to * @return output Output variable with slices pulled from the specified axis (NUMERIC type) */ public INDArray gather(INDArray df, int[] indices, int axis) { NDValidation.validateNumerical("gather", "df", df); - Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + Preconditions.checkArgument(indices.length >= 1, + "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis))[0]; } /** - * Gather slices from the input variable where the indices are specified as dynamic array values.
      - * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
      + * Gather slices from the input variable where the indices are specified as dynamic array + * values.
      Output shape is same as input shape, except for axis dimension, which has size + * equal to indices.length.
      * - * @param df Input variable (NUMERIC type) + * @param df Input variable (NUMERIC type) * @param indices Indices to get slices for. Rank 0 or 1 input (INT type) - * @param axis Axis that the indices refer to + * @param axis Axis that the indices refer to * @return output Output variable with slices pulled from the specified axis (NUMERIC type) */ public INDArray gather(INDArray df, INDArray indices, int axis) { @@ -449,8 +506,8 @@ public class NDBase { /** * Gather slices from df with shape specified by indices.
      * - * @param df (NUMERIC type) - * @param indices (NUMERIC type) + * @param df (NUMERIC type) + * @param indices (NUMERIC type) * @return output (NUMERIC type) */ public INDArray gatherNd(INDArray df, INDArray indices) { @@ -460,13 +517,14 @@ public class NDBase { } /** - * Greater than operation: elementwise x > y
      - * + * Greater than operation: elementwise x > y
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input array (NUMERIC type) * @param y Double value argument to use in operation - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray gt(INDArray x, double y) { NDValidation.validateNumerical("gt", "x", x); @@ -474,18 +532,20 @@ public class NDBase { } /** - * Greater than operation: elementwise x > y
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * + * Greater than operation: elementwise x > y
      If x and y arrays have equal shape, the output + * shape is the same as these inputs.
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input 1 (NUMERIC type) * @param y Input 2 (NUMERIC type) - * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray gt(INDArray x, INDArray y) { NDValidation.validateNumerical("gt", "x", x); @@ -494,27 +554,30 @@ public class NDBase { } /** - * Greater than or equals operation: elementwise x >= y
      - * + * Greater than or equals operation: elementwise x >= y
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input array (NUMERIC type) * @param y Double value argument to use in operation - * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray gte(INDArray x, double y) { NDValidation.validateNumerical("gte", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual(x, y)); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual(x, y)); } /** - * Greater than or equal to operation: elementwise x >= y
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * + * Greater than or equal to operation: elementwise x >= y
      If x and y arrays have equal + * shape, the output shape is the same as these inputs.
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input 1 (NUMERIC type) @@ -524,7 +587,8 @@ public class NDBase { public INDArray gte(INDArray x, INDArray y) { NDValidation.validateNumerical("gte", "x", x); NDValidation.validateNumerical("gte", "y", y); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(x, y))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(x, y))[0]; } /** @@ -539,20 +603,22 @@ public class NDBase { } /** - * Compute the inverse permutation indices for a permutation operation
      - * Example: if input is [2, 0, 1] then output is [1, 2, 0]
      - * The idea is that x.permute(input).permute(invertPermutation(input)) == x
      + * Compute the inverse permutation indices for a permutation operation
      Example: if input is + * [2, 0, 1] then output is [1, 2, 0]
      The idea is that + * x.permute(input).permute(invertPermutation(input)) == x
      * * @param input 1D indices for permutation (INT type) * @return output 1D inverted permutation (INT type) */ public INDArray invertPermutation(INDArray input) { NDValidation.validateInteger("invertPermutation", "input", input); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(input))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(input))[0]; } /** - * Is the director a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1
      + * Is the director a numeric tensor? In the current version of ND4J/SameDiff, this always returns + * true/1
      * * @param x Input variable (NUMERIC type) * @return output scalar boolean with value true or false (NDARRAY type) @@ -563,26 +629,27 @@ public class NDBase { } /** - * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
      - * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
      + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
      For + * example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
      * * @param dataType Data type of the output array - * @param start Start value - * @param stop Stop value - * @param number Number of values to generate + * @param start Start value + * @param stop Stop value + * @param number Number of values to generate * @return output INDArray with linearly spaced elements (NUMERIC type) */ public INDArray linspace(DataType dataType, double start, double stop, long number) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0]; } /** - * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
      - * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
      + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
      For + * example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
      * - * @param start Start value (NUMERIC type) - * @param stop Stop value (NUMERIC type) - * @param number Number of values to generate (LONG type) + * @param start Start value (NUMERIC type) + * @param stop Stop value (NUMERIC type) + * @param number Number of values to generate (LONG type) * @param dataType Data type of the output array * @return output INDArray with linearly spaced elements (NUMERIC type) */ @@ -590,17 +657,19 @@ public class NDBase { NDValidation.validateNumerical("linspace", "start", start); NDValidation.validateNumerical("linspace", "stop", stop); NDValidation.validateInteger("linspace", "number", number); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(start, stop, number, dataType))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.Linspace(start, stop, number, dataType))[0]; } /** - * Less than operation: elementwise x < y
      - * + * Less than operation: elementwise x < y
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input array (NUMERIC type) * @param y Double value argument to use in operation - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray lt(INDArray x, double y) { NDValidation.validateNumerical("lt", "x", x); @@ -608,18 +677,20 @@ public class NDBase { } /** - * Less than operation: elementwise x < y
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * + * Less than operation: elementwise x < y
      If x and y arrays have equal shape, the output + * shape is the same as these inputs.
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input 1 (NUMERIC type) * @param y Input 2 (NUMERIC type) - * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray lt(INDArray x, INDArray y) { NDValidation.validateNumerical("lt", "x", x); @@ -628,32 +699,36 @@ public class NDBase { } /** - * Less than or equals operation: elementwise x <= y
      - * + * Less than or equals operation: elementwise x <= y
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input array (NUMERIC type) * @param y Double value argument to use in operation - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray lte(INDArray x, double y) { NDValidation.validateNumerical("lte", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual(x, y)); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual(x, y)); } /** - * Less than or equal to operation: elementwise x <= y
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * + * Less than or equal to operation: elementwise x <= y
      If x and y arrays have equal shape, + * the output shape is the same as these inputs.
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input 1 (NUMERIC type) * @param y Input 2 (NUMERIC type) - * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray lte(INDArray x, INDArray y) { NDValidation.validateNumerical("lte", "x", x); @@ -662,21 +737,23 @@ public class NDBase { } /** - * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise
      + * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 + * where satisfied, 0 otherwise
      * - * @param in Input (NUMERIC type) + * @param in Input (NUMERIC type) * @param condition Condition * @return output Boolean mask (NUMERIC type) */ public INDArray matchCondition(INDArray in, Condition condition) { NDValidation.validateNumerical("matchCondition", "in", in); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform(in, condition)); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform(in, condition)); } /** * Returns a count of the number of elements that satisfy the condition
      * - * @param in Input (NUMERIC type) + * @param in Input (NUMERIC type) * @param condition Condition * @return output Number of elements that the condition is satisfied for (NUMERIC type) */ @@ -686,98 +763,115 @@ public class NDBase { } /** - * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
      - * + * Returns a count of the number of elements that satisfy the condition (for each slice along the + * specified dimensions)
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param in Input variable (NUMERIC type) - * @param condition Condition - * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Number of elements that the condition is satisfied for (NUMERIC type) */ public INDArray matchConditionCount(INDArray in, Condition condition, boolean keepDim, int... dimensions) { NDValidation.validateNumerical("matchConditionCount", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition, keepDim, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition, keepDim, + dimensions)); } /** - * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
      - * + * Returns a count of the number of elements that satisfy the condition (for each slice along the + * specified dimensions)
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param in Input variable (NUMERIC type) - * @param condition Condition - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Number of elements that the condition is satisfied for (NUMERIC type) */ public INDArray matchConditionCount(INDArray in, Condition condition, int... dimensions) { NDValidation.validateNumerical("matchConditionCount", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition, false, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition, false, + dimensions)); } /** * Max array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray max(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("max", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Max(x, keepDims, dimensions)); } /** * Max array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray max(INDArray x, int... dimensions) { NDValidation.validateNumerical("max", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Max(x, false, dimensions)); } /** * Element-wise maximum operation: out[i] = max(first[i], second[i])
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * - * @param first First input array (NUMERIC type) + * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) * @return output Output variable (NUMERIC type) */ @@ -789,49 +883,55 @@ public class NDBase { /** * Mean (average) array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray mean(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("mean", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, keepDims, dimensions)); } /** * Mean (average) array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray mean(INDArray x, int... dimensions) { NDValidation.validateNumerical("mean", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, false, dimensions)); } /** - * The merge operation is a control operation that forwards the either of the inputs to the output, when
      - * the first of them becomes available. If both are available, the output is undefined (either input could
      - * be forwarded to the output)
      + * The merge operation is a control operation that forwards the either of the inputs to the + * output, when
      the first of them becomes available. If both are available, the output is + * undefined (either input could
      be forwarded to the output)
      * * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -845,53 +945,59 @@ public class NDBase { /** * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray min(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("min", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Min(x, keepDims, dimensions)); } /** * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray min(INDArray x, int... dimensions) { NDValidation.validateNumerical("min", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Min(x, false, dimensions)); } /** * Element-wise minimum operation: out[i] = min(first[i], second[i])
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * - * @param first First input array (NUMERIC type) + * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) * @return output Second input array (NUMERIC type) */ @@ -902,11 +1008,11 @@ public class NDBase { } /** - * Matrix multiplication: out = mmul(x,y)
      - * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
      + * Matrix multiplication: out = mmul(x,y)
      Supports specifying transpose argument to perform + * operation such as mmul(a^T, b), etc.
      * - * @param x First input variable (NUMERIC type) - * @param y Second input variable (NUMERIC type) + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) * @param transposeX Transpose x (first argument) * @param transposeY Transpose y (second argument) * @param transposeZ Transpose result array @@ -916,12 +1022,13 @@ public class NDBase { boolean transposeZ) { NDValidation.validateNumerical("mmul", "x", x); NDValidation.validateNumerical("mmul", "y", y); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, transposeX, transposeY, transposeZ))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, transposeX, transposeY, transposeZ))[0]; } /** - * Matrix multiplication: out = mmul(x,y)
      - * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
      + * Matrix multiplication: out = mmul(x,y)
      Supports specifying transpose argument to perform + * operation such as mmul(a^T, b), etc.
      * * @param x First input variable (NUMERIC type) * @param y Second input variable (NUMERIC type) @@ -935,12 +1042,13 @@ public class NDBase { /** * Not equals operation: elementwise x != y
      - * + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input array (NUMERIC type) * @param y Double value argument to use in operation - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray neq(INDArray x, double y) { NDValidation.validateNumerical("neq", "x", x); @@ -948,18 +1056,20 @@ public class NDBase { } /** - * Not equal to operation: elementwise x != y
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * + * Not equal to operation: elementwise x != y
      If x and y arrays have equal shape, the output + * shape is the same as these inputs.
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input 1 (NUMERIC type) * @param y Input 2 (NUMERIC type) - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray neq(INDArray x, INDArray y) { NDValidation.validateNumerical("neq", "x", x); @@ -968,180 +1078,192 @@ public class NDBase { } /** - * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
      - * out = sum_i abs(x[i])
      - * + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset + * along the specified dimensions:
      out = sum_i abs(x[i])
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray norm1(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("norm1", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(x, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(x, keepDims, dimensions)); } /** - * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
      - * out = sum_i abs(x[i])
      - * + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset + * along the specified dimensions:
      out = sum_i abs(x[i])
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray norm1(INDArray x, int... dimensions) { NDValidation.validateNumerical("norm1", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(x, false, dimensions)); } /** - * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
      - * out = sqrt(sum_i x[i]^2)
      - * + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset + * along the specified dimensions:
      out = sqrt(sum_i x[i]^2)
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions * @param dimensions dimensions dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray norm2(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("norm2", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(x, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(x, keepDims, dimensions)); } /** - * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
      - * out = sqrt(sum_i x[i]^2)
      - * + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset + * along the specified dimensions:
      out = sqrt(sum_i x[i]^2)
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param dimensions dimensions dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray norm2(INDArray x, int... dimensions) { NDValidation.validateNumerical("norm2", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(x, false, dimensions)); } /** - * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
      - * specified dimensions:
      - * out = max(abs(x[i]))
      - * + * Max norm (infinity norm) reduction operation: The output contains the max norm for each + * tensor/subset along the
      specified dimensions:
      out = max(abs(x[i]))
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray normmax(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("normmax", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(x, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(x, keepDims, dimensions)); } /** - * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
      - * specified dimensions:
      - * out = max(abs(x[i]))
      - * + * Max norm (infinity norm) reduction operation: The output contains the max norm for each + * tensor/subset along the
      specified dimensions:
      out = max(abs(x[i]))
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray normmax(INDArray x, int... dimensions) { NDValidation.validateNumerical("normmax", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(x, false, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(x, false, dimensions)); } /** - * Convert the array to a one-hot array with walues and for each entry
      - * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
      - * with {out[i, ..., j, in[i,...,j]] with other values being set to
      + * Convert the array to a one-hot array with walues and for each entry
      If input has shape [ + * a, ..., n] then output has shape [ a, ..., n, depth],
      with {out[i, ..., j, in[i,...,j]] + * with other values being set to
      * - * @param indices Indices - value 0 to depth-1 (NUMERIC type) - * @param depth Number of classes - * @param axis - * @param on - * @param off + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @param axis + * @param on + * @param off * @param dataType Output data type * @return output Output variable (NUMERIC type) */ public INDArray oneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) { NDValidation.validateNumerical("oneHot", "indices", indices); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, dataType))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, dataType))[0]; } /** - * Convert the array to a one-hot array with walues and for each entry
      - * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
      - * with {out[i, ..., j, in[i,...,j]] with other values being set to
      + * Convert the array to a one-hot array with walues and for each entry
      If input has shape [ + * a, ..., n] then output has shape [ a, ..., n, depth],
      with {out[i, ..., j, in[i,...,j]] + * with other values being set to
      * * @param indices Indices - value 0 to depth-1 (NUMERIC type) - * @param depth Number of classes - * @param axis - * @param on - * @param off + * @param depth Number of classes + * @param axis + * @param on + * @param off * @return output Output variable (NUMERIC type) */ public INDArray oneHot(INDArray indices, int depth, int axis, double on, double off) { NDValidation.validateNumerical("oneHot", "indices", indices); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, DataType.FLOAT))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, + DataType.FLOAT))[0]; } /** - * Convert the array to a one-hot array with walues 0 and 1 for each entry
      - * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
      - * with out[i, ..., j, in[i,...,j]] = 1 with other values being set to 0
      - * see oneHot(SDVariable, int, int, double, double)
      + * Convert the array to a one-hot array with walues 0 and 1 for each entry
      If input has shape + * [ a, ..., n] then output has shape [ a, ..., n, depth],
      with out[i, ..., j, in[i,...,j]] = + * 1 with other values being set to 0
      see oneHot(SDVariable, int, int, double, double)
      * * @param indices Indices - value 0 to depth-1 (NUMERIC type) - * @param depth Number of classes + * @param depth Number of classes * @return output Output variable (NUMERIC type) */ public INDArray oneHot(INDArray indices, int depth) { @@ -1150,8 +1272,9 @@ public class NDBase { } /** - * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic:
      - * if the input shape changes in later execution, the returned variable's shape will also be updated
      + * Return a variable of all 1s, with the same shape as the input variable. Note that this is + * dynamic:
      if the input shape changes in later execution, the returned variable's shape will + * also be updated
      * * @param input Input INDArray (NUMERIC type) * @return output A new INDArray with the same (dynamic) shape as the input (NUMERIC type) @@ -1164,8 +1287,8 @@ public class NDBase { /** * As per onesLike(String, SDVariable) but the output datatype may be specified
      * - * @param input (NUMERIC type) - * @param dataType + * @param input (NUMERIC type) + * @param dataType * @return output (NUMERIC type) */ public INDArray onesLike(INDArray input, DataType dataType) { @@ -1174,10 +1297,11 @@ public class NDBase { } /** - * Array permutation operation: permute the dimensions according to the specified permutation indices.
      - * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
      + * Array permutation operation: permute the dimensions according to the specified permutation + * indices.
      Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape + * [c,a,b]
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param dimensions Permute dimensions (INT type) * @return output Output variable (permuted input) (NUMERIC type) */ @@ -1188,69 +1312,77 @@ public class NDBase { } /** - * Array permutation operation: permute the dimensions according to the specified permutation indices.
      - * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
      + * Array permutation operation: permute the dimensions according to the specified permutation + * indices.
      Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape + * [c,a,b]
      * - * @param x Input variable (NUMERIC type) - * @param dimensions (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) * @return output Output variable (permuted input) (NUMERIC type) */ public INDArray permute(INDArray x, int... dimensions) { NDValidation.validateNumerical("permute", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions))[0]; } /** * Product array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output (NUMERIC type) */ public INDArray prod(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("prod", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(x, keepDims, dimensions)); } /** * Product array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output (NUMERIC type) */ public INDArray prod(INDArray x, int... dimensions) { NDValidation.validateNumerical("prod", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(x, false, dimensions)); } /** * Create a new variable with a 1d array, where the values start at from and increment by step
      - * up to (but not including) limit.
      - * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
      + * up to (but not including) limit.
      For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, + * 2.0, 2.5]
      * - * @param from Initial/smallest value - * @param to Largest value (exclusive) - * @param step Step size - * @param dataType + * @param from Initial/smallest value + * @param to Largest value (exclusive) + * @param step Step size + * @param dataType * @return output INDArray with the specified values (NUMERIC type) */ public INDArray range(double from, double to, double step, DataType dataType) { @@ -1259,13 +1391,13 @@ public class NDBase { /** * Create a new variable with a 1d array, where the values start at from and increment by step
      - * up to (but not including) limit.
      - * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
      + * up to (but not including) limit.
      For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, + * 2.0, 2.5]
      * - * @param from Initial/smallest value (NUMERIC type) - * @param to Largest value (exclusive) (NUMERIC type) - * @param step Step size (NUMERIC type) - * @param dataType + * @param from Initial/smallest value (NUMERIC type) + * @param to Largest value (exclusive) (NUMERIC type) + * @param step Step size (NUMERIC type) + * @param dataType * @return output INDArray with the specified values (NUMERIC type) */ public INDArray range(INDArray from, INDArray to, INDArray step, DataType dataType) { @@ -1276,10 +1408,12 @@ public class NDBase { } /** - * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable
      + * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D + * scalar variable
      * * @param in Input variable (NUMERIC type) - * @return output (scalar) output variable with value equal to the rank of the input variable (NUMERIC type) + * @return output (scalar) output variable with value equal to the rank of the input variable + * (NUMERIC type) */ public INDArray rank(INDArray in) { NDValidation.validateNumerical("rank", "in", in); @@ -1287,42 +1421,45 @@ public class NDBase { } /** - * Element-wise replace where condition:
      - * out[i] = from[i] if condition(update[i]) is satisfied, or
      - * out[i] = update[i] if condition(update[i]) is NOT satisfied
      + * Element-wise replace where condition:
      out[i] = from[i] if condition(update[i]) is + * satisfied, or
      out[i] = update[i] if condition(update[i]) is NOT satisfied
      * - * @param update Source array (NUMERIC type) - * @param from Replacement values array (used conditionally). Must be same shape as 'update' array (NUMERIC type) + * @param update Source array (NUMERIC type) + * @param from Replacement values array (used conditionally). Must be same shape as 'update' + * array (NUMERIC type) * @param condition Condition to check on update array elements * @return output New array with values replaced where condition is satisfied (NUMERIC type) */ public INDArray replaceWhere(INDArray update, INDArray from, Condition condition) { NDValidation.validateNumerical("replaceWhere", "update", update); NDValidation.validateNumerical("replaceWhere", "from", from); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(update, from, condition)); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(update, from, + condition)); } /** - * Element-wise replace where condition:
      - * out[i] = value if condition(update[i]) is satisfied, or
      - * out[i] = update[i] if condition(update[i]) is NOT satisfied
      + * Element-wise replace where condition:
      out[i] = value if condition(update[i]) is satisfied, + * or
      out[i] = update[i] if condition(update[i]) is NOT satisfied
      * - * @param update Source array (NUMERIC type) - * @param value Value to set at the output, if the condition is satisfied + * @param update Source array (NUMERIC type) + * @param value Value to set at the output, if the condition is satisfied * @param condition Condition to check on update array elements * @return output New array with values replaced where condition is satisfied (NUMERIC type) */ public INDArray replaceWhere(INDArray update, double value, Condition condition) { NDValidation.validateNumerical("replaceWhere", "update", update); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(update, value, condition)); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(update, value, + condition)); } /** - * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
      - * input, but with the specified shape.
      - * Note that prod(shape) must match length(input) == prod(input.shape)
      + * Reshape the input variable to the specified (fixed) shape. The output variable will have the + * same values as the
      input, but with the specified shape.
      Note that prod(shape) must + * match length(input) == prod(input.shape)
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param shape New shape for variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ @@ -1333,76 +1470,77 @@ public class NDBase { } /** - * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
      - * input, but with the specified shape.
      - * Note that prod(shape) must match length(input) == prod(input.shape)
      + * Reshape the input variable to the specified (fixed) shape. The output variable will have the + * same values as the
      input, but with the specified shape.
      Note that prod(shape) must + * match length(input) == prod(input.shape)
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param shape New shape for variable (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray reshape(INDArray x, long... shape) { NDValidation.validateNumerical("reshape", "x", x); - Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + Preconditions.checkArgument(shape.length >= 0, + "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0]; } /** - * Reverse the values of an array for the specified dimensions
      - * If input is:
      - * [ 1, 2, 3]
      - * [ 4, 5, 6]
      - * then
      - * reverse(in, 0):
      - * [3, 2, 1]
      - * [6, 5, 4]
      - * reverse(in, 1):
      - * [4, 5, 6]
      - * [1, 2 3]
      + * Reverse the values of an array for the specified dimensions
      If input is:
      [ 1, 2, 3]
      + * [ 4, 5, 6]
      then
      reverse(in, 0):
      [3, 2, 1]
      [6, 5, 4]
      reverse(in, 1):
      [4, + * 5, 6]
      [1, 2 3]
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param dimensions Input variable (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray reverse(INDArray x, int... dimensions) { NDValidation.validateNumerical("reverse", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse(x, dimensions))[0]; } /** - * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
      + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values + * are reversed
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param seq_lengths Length of the sequences (INT type) - * @param seqDim Sequence dimension - * @param batchDim Batch dimension + * @param seqDim Sequence dimension + * @param batchDim Batch dimension * @return output Reversed sequences (NUMERIC type) */ public INDArray reverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) { NDValidation.validateNumerical("reverseSequence", "x", x); NDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, seqDim, batchDim))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, seqDim, + batchDim))[0]; } /** - * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
      + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values + * are reversed
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param seq_lengths Length of the sequences (INT type) * @return output Reversed sequences (NUMERIC type) */ public INDArray reverseSequence(INDArray x, INDArray seq_lengths) { NDValidation.validateNumerical("reverseSequence", "x", x); NDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, -1, 0))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, -1, + 0))[0]; } /** - * Element-wise scalar floor modulus operation: out = floorMod(in, value).
      - * i.e., returns the remainder after division by 'value'
      + * Element-wise scalar floor modulus operation: out = floorMod(in, value).
      i.e., returns the + * remainder after division by 'value'
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param value Scalar value to compare * @return output Output variable (NUMERIC type) */ @@ -1414,7 +1552,7 @@ public class NDBase { /** * Element-wise scalar maximum operation: out = max(in, value)
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param value Scalar value to compare * @return output Scalar value to compare (NUMERIC type) */ @@ -1426,7 +1564,7 @@ public class NDBase { /** * Element-wise scalar minimum operation: out = min(in, value)
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param value Scalar value to compare * @return output Output variable (NUMERIC type) */ @@ -1438,7 +1576,7 @@ public class NDBase { /** * Return a variable with equal shape to the input, but all elements set to value 'set'
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param set Value to set * @return output Output variable (NUMERIC type) */ @@ -1449,13 +1587,15 @@ public class NDBase { /** * Scatter addition operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1469,13 +1609,15 @@ public class NDBase { /** * Scatter division operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1489,13 +1631,15 @@ public class NDBase { /** * Scatter max operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1509,13 +1653,15 @@ public class NDBase { /** * Scatter min operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1529,13 +1675,15 @@ public class NDBase { /** * Scatter multiplication operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1549,13 +1697,15 @@ public class NDBase { /** * Scatter subtraction operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1569,13 +1719,15 @@ public class NDBase { /** * Scatter update operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1584,138 +1736,136 @@ public class NDBase { NDValidation.validateNumerical("scatterUpdate", "ref", ref); NDValidation.validateNumerical("scatterUpdate", "indices", indices); NDValidation.validateNumerical("scatterUpdate", "updates", updates); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(ref, indices, updates))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(ref, indices, updates))[0]; } /** * Segment max operation.
      + *

      + * If data = [3, 6, 1, 4, 9, 2, 8]
      segmentIds = [0, 0, 1, 1, 1, 2, 2]
      then output = + * [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      Note that the segment IDs must be sorted from + * smallest to largest segment.
      See {unsortedSegment (String, SDVariable, SDVariable, int) + * ops
      for the same op without this sorted requirement
      * - * If data = [3, 6, 1, 4, 9, 2, 8]
      - * segmentIds = [0, 0, 1, 1, 1, 2, 2]
      - * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      - * Note that the segment IDs must be sorted from smallest to largest segment.
      - * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
      - * for the same op without this sorted requirement
      - * - * @param data Data to perform segment max on (NDARRAY type) + * @param data Data to perform segment max on (NDARRAY type) * @param segmentIds Variable for the segment IDs (NUMERIC type) * @return output Segment output (NUMERIC type) */ public INDArray segmentMax(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentMax", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(data, segmentIds))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(data, segmentIds))[0]; } /** * Segment mean operation.
      + *

      + * If data = [3, 6, 1, 4, 9, 2, 8]
      segmentIds = [0, 0, 1, 1, 1, 2, 2]
      then output = + * [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      Note that the segment IDs must be sorted from + * smallest to largest segment.
      See {unsortedSegment (String, SDVariable, SDVariable, int) + * ops
      for the same op without this sorted requirement
      * - * If data = [3, 6, 1, 4, 9, 2, 8]
      - * segmentIds = [0, 0, 1, 1, 1, 2, 2]
      - * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      - * Note that the segment IDs must be sorted from smallest to largest segment.
      - * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
      - * for the same op without this sorted requirement
      - * - * @param data Data to perform segment max on (NDARRAY type) + * @param data Data to perform segment max on (NDARRAY type) * @param segmentIds Variable for the segment IDs (NUMERIC type) * @return output Segment output (NUMERIC type) */ public INDArray segmentMean(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentMean", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(data, segmentIds))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(data, + segmentIds))[0]; } /** * Segment min operation.
      + *

      + * If data = [3, 6, 1, 4, 9, 2, 8]
      segmentIds = [0, 0, 1, 1, 1, 2, 2]
      then output = + * [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      Note that the segment IDs must be sorted from + * smallest to largest segment.
      See {unsortedSegment (String, SDVariable, SDVariable, int) + * ops
      for the same op without this sorted requirement
      * - * If data = [3, 6, 1, 4, 9, 2, 8]
      - * segmentIds = [0, 0, 1, 1, 1, 2, 2]
      - * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      - * Note that the segment IDs must be sorted from smallest to largest segment.
      - * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
      - * for the same op without this sorted requirement
      - * - * @param data Data to perform segment max on (NDARRAY type) + * @param data Data to perform segment max on (NDARRAY type) * @param segmentIds Variable for the segment IDs (NUMERIC type) * @return output Segment output (NUMERIC type) */ public INDArray segmentMin(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentMin", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(data, segmentIds))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(data, segmentIds))[0]; } /** * Segment product operation.
      + *

      + * If data = [3, 6, 1, 4, 9, 2, 8]
      segmentIds = [0, 0, 1, 1, 1, 2, 2]
      then output = + * [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      Note that the segment IDs must be sorted from + * smallest to largest segment.
      See {unsortedSegment (String, SDVariable, SDVariable, int) + * ops
      for the same op without this sorted requirement
      * - * If data = [3, 6, 1, 4, 9, 2, 8]
      - * segmentIds = [0, 0, 1, 1, 1, 2, 2]
      - * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      - * Note that the segment IDs must be sorted from smallest to largest segment.
      - * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
      - * for the same op without this sorted requirement
      - * - * @param data Data to perform segment max on (NDARRAY type) + * @param data Data to perform segment max on (NDARRAY type) * @param segmentIds Variable for the segment IDs (NUMERIC type) * @return output Segment output (NUMERIC type) */ public INDArray segmentProd(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentProd", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(data, segmentIds))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(data, + segmentIds))[0]; } /** * Segment sum operation.
      + *

      + * If data = [3, 6, 1, 4, 9, 2, 8]
      segmentIds = [0, 0, 1, 1, 1, 2, 2]
      then output = + * [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      Note that the segment IDs must be sorted from + * smallest to largest segment.
      See {unsortedSegment (String, SDVariable, SDVariable, int) + * ops
      for the same op without this sorted requirement
      * - * If data = [3, 6, 1, 4, 9, 2, 8]
      - * segmentIds = [0, 0, 1, 1, 1, 2, 2]
      - * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      - * Note that the segment IDs must be sorted from smallest to largest segment.
      - * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
      - * for the same op without this sorted requirement
      - * - * @param data Data to perform segment max on (NDARRAY type) + * @param data Data to perform segment max on (NDARRAY type) * @param segmentIds Variable for the segment IDs (NUMERIC type) * @return output Segment output (NUMERIC type) */ public INDArray segmentSum(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentSum", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(data, segmentIds))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(data, segmentIds))[0]; } /** - * Generate a sequence mask (with values 0 or 1) based on the specified lengths
      - * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
      + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
      Specifically, + * out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
      * - * @param lengths Lengths of the sequences (NUMERIC type) - * @param maxLen Maximum sequence length - * @param dataType + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length + * @param dataType * @return output Output variable (NUMERIC type) */ public INDArray sequenceMask(INDArray lengths, int maxLen, DataType dataType) { NDValidation.validateNumerical("sequenceMask", "lengths", lengths); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; } /** - * Generate a sequence mask (with values 0 or 1) based on the specified lengths
      - * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
      + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
      Specifically, + * out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
      * - * @param lengths Lengths of the sequences (NUMERIC type) - * @param maxLen Maximum sequence length (INT type) - * @param dataType + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length (INT type) + * @param dataType * @return output Output variable (NUMERIC type) */ public INDArray sequenceMask(INDArray lengths, INDArray maxLen, DataType dataType) { NDValidation.validateNumerical("sequenceMask", "lengths", lengths); NDValidation.validateInteger("sequenceMask", "maxLen", maxLen); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; } /** * see sequenceMask(String, SDVariable, SDVariable, DataType)
      * * @param lengths (NUMERIC type) - * @param dataType + * @param dataType * @return output (NUMERIC type) */ public INDArray sequenceMask(INDArray lengths, DataType dataType) { @@ -1735,10 +1885,12 @@ public class NDBase { } /** - * Returns the size (number of elements, i.e., prod(shape)) of the specified INDArray as a 0D scalar variable
      + * Returns the size (number of elements, i.e., prod(shape)) of the specified INDArray as a 0D + * scalar variable
      * * @param in Input variable (NUMERIC type) - * @return output 0D (scalar) output variable with value equal to the number of elements in the specified array (NUMERIC type) + * @return output 0D (scalar) output variable with value equal to the number of elements in the + * specified array (NUMERIC type) */ public INDArray size(INDArray in) { NDValidation.validateNumerical("size", "in", in); @@ -1746,10 +1898,10 @@ public class NDBase { } /** - * Returns a rank 0 (scalar) variable for the size of the specified dimension.
      - * For example, if X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30
      + * Returns a rank 0 (scalar) variable for the size of the specified dimension.
      For example, if + * X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param dimension Dimension to get size of * @return output Scalar INDArray for size at specified variable (NUMERIC type) */ @@ -1759,40 +1911,36 @@ public class NDBase { } /** - * Get a subset of the specified input, by specifying the first element and the size of the array.
      - * For example, if input is:
      - * [a, b, c]
      - * [d, e, f]
      - * then slice(input, begin=[0,1], size=[2,1] will return:
      - * [b]
      - * [e]
      - * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
      + * Get a subset of the specified input, by specifying the first element and the size of the + * array.
      For example, if input is:
      [a, b, c]
      [d, e, f]
      then slice(input, + * begin=[0,1], size=[2,1] will return:
      [b]
      [e]
      Note that for each dimension i, + * begin[i] + size[i] <= input.size(i)
      * * @param input input Variable to get subset of (NUMERIC type) - * @param begin Beginning index. Must be same length as rank of input array (Size: AtLeast(min=1)) - * @param size Size of the output array. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @param begin Beginning index. Must be same length as rank of input array (Size: + * AtLeast(min=1)) + * @param size Size of the output array. Must be same length as rank of input array (Size: + * AtLeast(min=1)) * @return output Subset of the input (NUMERIC type) */ public INDArray slice(INDArray input, int[] begin, int... size) { NDValidation.validateNumerical("slice", "input", input); - Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); - Preconditions.checkArgument(size.length >= 1, "size has incorrect size/length. Expected: size.length >= 1, got %s", size.length); + Preconditions.checkArgument(begin.length >= 1, + "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(size.length >= 1, + "size has incorrect size/length. Expected: size.length >= 1, got %s", size.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0]; } /** - * Get a subset of the specified input, by specifying the first element and the size of the array.
      - * For example, if input is:
      - * [a, b, c]
      - * [d, e, f]
      - * then slice(input, begin=[0,1], size=[2,1] will return:
      - * [b]
      - * [e]
      - * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
      + * Get a subset of the specified input, by specifying the first element and the size of the + * array.
      For example, if input is:
      [a, b, c]
      [d, e, f]
      then slice(input, + * begin=[0,1], size=[2,1] will return:
      [b]
      [e]
      Note that for each dimension i, + * begin[i] + size[i] <= input.size(i)
      * * @param input input Variable to get subset of (NUMERIC type) * @param begin Beginning index. Must be same length as rank of input array (INT type) - * @param size Size of the output array. Must be same length as rank of input array (INT type) + * @param size Size of the output array. Must be same length as rank of input array (INT type) * @return output Subset of the input (NUMERIC type) */ public INDArray slice(INDArray input, INDArray begin, INDArray size) { @@ -1804,50 +1952,54 @@ public class NDBase { /** * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x (NUMERIC type) - * @param keepDims - * @param dimensions (Size: AtLeast(min=0)) + * @param x (NUMERIC type) + * @param keepDims + * @param dimensions (Size: AtLeast(min=0)) * @return output (NUMERIC type) */ public INDArray squaredNorm(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("squaredNorm", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(x, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(x, keepDims, dimensions)); } /** * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x (NUMERIC type) - * @param dimensions (Size: AtLeast(min=0)) + * @param x (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) * @return output (NUMERIC type) */ public INDArray squaredNorm(INDArray x, int... dimensions) { NDValidation.validateNumerical("squaredNorm", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(x, false, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(x, false, dimensions)); } /** - * Remove a single dimension of size 1.
      - * For example, if input has shape [a,b,1,c] then squeeze(input, 2) returns an array of shape [a,b,c]
      + * Remove a single dimension of size 1.
      For example, if input has shape [a,b,1,c] then + * squeeze(input, 2) returns an array of shape [a,b,c]
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param axis Size 1 dimension to remove * @return output Output variable (NUMERIC type) */ @@ -1857,168 +2009,198 @@ public class NDBase { } /** - * Stack a set of N INDArray of rank X into one rank X+1 variable.
      - * If inputs have shape [a,b,c] then output has shape:
      - * axis = 0: [N,a,b,c]
      - * axis = 1: [a,N,b,c]
      - * axis = 2: [a,b,N,c]
      - * axis = 3: [a,b,c,N]
      - * see unstack(String[], SDVariable, int, int)
      + * Stack a set of N INDArray of rank X into one rank X+1 variable.
      If inputs have shape + * [a,b,c] then output has shape:
      axis = 0: [N,a,b,c]
      axis = 1: [a,N,b,c]
      axis = 2: + * [a,b,N,c]
      axis = 3: [a,b,c,N]
      see unstack(String[], SDVariable, int, int)
      * * @param values Input variables to stack. Must have the same shape for all inputs (NDARRAY type) - * @param axis Axis to stack on + * @param axis Axis to stack on * @return output Output variable (NDARRAY type) */ public INDArray stack(int axis, INDArray... values) { - Preconditions.checkArgument(values.length >= 1, "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length); + Preconditions.checkArgument(values.length >= 1, + "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis))[0]; } /** * Stardard deviation array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N + * (population stdev) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: + * remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray standardDeviation(INDArray x, boolean biasCorrected, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("standardDeviation", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(x, biasCorrected, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(x, biasCorrected, keepDims, + dimensions)); } /** * Stardard deviation array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N + * (population stdev) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray standardDeviation(INDArray x, boolean biasCorrected, int... dimensions) { NDValidation.validateNumerical("standardDeviation", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(x, biasCorrected, false, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(x, biasCorrected, false, + dimensions)); } /** - * Get a subset of the specified input, by specifying the first element, last element, and the strides.
      - * For example, if input is:
      - * [a, b, c]
      - * [d, e, f]
      - * [g, h, i]
      - * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
      - * [b, c]
      - * [h, i]
      + * Get a subset of the specified input, by specifying the first element, last element, and the + * strides.
      For example, if input is:
      [a, b, c]
      [d, e, f]
      [g, h, i]
      then + * stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
      [b, + * c]
      [h, i]
      * - * @param in Variable to get subset of (NUMERIC type) - * @param begin Beginning index (Size: AtLeast(min=1)) - * @param end End index (Size: AtLeast(min=1)) - * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) - * @param beginMask Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored, and a value of 0 is used instead for the beginning index for that dimension - * @param endMask Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored, and a value of size(i)-1 is used instead for the end index for that dimension - * @param ellipsisMask Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other dimensions are inserted as required at the specified position - * @param newAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is inserted at this point - * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means + * take every second element. 
(Size: AtLeast(min=1)) + * @param beginMask Bit mask: If the ith bit is set to 1, then the value in the begin long[] + * is ignored, and a value of 0 is used instead for the beginning index for + * that dimension + * @param endMask Bit mask: If the ith bit is set to 1, then the value in the end long[] is + * ignored, and a value of size(i)-1 is used instead for the end index for + * that dimension + * @param ellipsisMask Bit mask: only one non-zero value is allowed here. If a non-zero value is + * set, then other dimensions are inserted as required at the specified + * position + * @param newAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values + * are ignored, and a size 1 dimension is inserted at this point + * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values + * are ignored, and a size 1 dimension is removed at this point. Note that + * begin/end/stride values must result in a size 1 output for these + * dimensions * @return output A subset of the input array (NUMERIC type) */ public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { NDValidation.validateNumerical("stridedSlice", "in", in); - Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); - Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); - Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask))[0]; + Preconditions.checkArgument(begin.length >= 1, + "begin has incorrect size/length. 
Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, + "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, + "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, + endMask, ellipsisMask, newAxisMask, shrinkAxisMask))[0]; } /** - * Get a subset of the specified input, by specifying the first element, last element, and the strides.
      - * For example, if input is:
      - * [a, b, c]
      - * [d, e, f]
      - * [g, h, i]
      - * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
      - * [b, c]
      - * [h, i]
      + * Get a subset of the specified input, by specifying the first element, last element, and the + * strides.
      For example, if input is:
      [a, b, c]
      [d, e, f]
      [g, h, i]
      then + * stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
      [b, + * c]
      [h, i]
      * - * @param in Variable to get subset of (NUMERIC type) - * @param begin Beginning index (Size: AtLeast(min=1)) - * @param end End index (Size: AtLeast(min=1)) - * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take + * every second element. (Size: AtLeast(min=1)) * @return output A subset of the input array (NUMERIC type) */ public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long... strides) { NDValidation.validateNumerical("stridedSlice", "in", in); - Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); - Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); - Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, 0))[0]; + Preconditions.checkArgument(begin.length >= 1, + "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, + "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, + "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, + 0))[0]; } /** * Sum array reduction operation, optionally along specified dimensions.
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) - * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of + * rank (input rank) if keepdims = true (NUMERIC type) */ public INDArray sum(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("sum", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(x, keepDims, dimensions)); } /** * Sum array reduction operation, optionally along specified dimensions.
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) - * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of + * rank (input rank) if keepdims = true (NUMERIC type) */ public INDArray sum(INDArray x, int... dimensions) { NDValidation.validateNumerical("sum", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(x, false, dimensions)); } /** - * Switch operation
      - * Predictate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output
      + * Switch operation
      Predictate - if false, values are output to left (first) branch/output; if + * true, to right (second) branch/output
      * - * @param x Input variable (NDARRAY type) - * @param predicate Predictate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output (BOOL type) + * @param x Input variable (NDARRAY type) + * @param predicate Predictate - if false, values are output to left (first) branch/output; if + * true, to right (second) branch/output (BOOL type) */ public INDArray[] switchOp(INDArray x, INDArray predicate) { NDValidation.validateBool("switchOp", "predicate", predicate); @@ -2028,29 +2210,35 @@ public class NDBase { /** * //TODO: Ops must be documented.
      * - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) - * @param transposeX Transpose x (first argument) - * @param transposeY Transpose y (second argument) - * @param transposeZ Transpose result array + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array * @return output Output variable (NUMERIC type) */ public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) { NDValidation.validateNumerical("tensorMmul", "x", x); NDValidation.validateNumerical("tensorMmul", "y", y); - Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); - Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, transposeX, transposeY, transposeZ))[0]; + Preconditions.checkArgument(dimensionsX.length >= 1, + "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", + dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, + "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", + dimensionsY.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, + transposeX, transposeY, transposeZ))[0]; } /** * //TODO: Ops must be documented.
      * - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) @@ -2058,25 +2246,25 @@ public class NDBase { public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int... dimensionsY) { NDValidation.validateNumerical("tensorMmul", "x", x); NDValidation.validateNumerical("tensorMmul", "y", y); - Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); - Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, false, false, false))[0]; + Preconditions.checkArgument(dimensionsX.length >= 1, + "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", + dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, + "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", + dimensionsY.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, false, + false, false))[0]; } /** - * Repeat (tile) the input tensor the specified number of times.
      - * For example, if input is
      - * [1, 2]
      - * [3, 4]
      - * and repeat is [2, 3]
      - * then output is
      - * [1, 2, 1, 2, 1, 2]
      - * [3, 4, 3, 4, 3, 4]
      - * [1, 2, 1, 2, 1, 2]
      - * [3, 4, 3, 4, 3, 4]
      + * Repeat (tile) the input tensor the specified number of times.
      For example, if input is
      + * [1, 2]
      [3, 4]
      and repeat is [2, 3]
      then output is
      [1, 2, 1, 2, 1, 2]
      [3, 4, + * 3, 4, 3, 4]
      [1, 2, 1, 2, 1, 2]
      [3, 4, 3, 4, 3, 4]
      * - * @param x Input variable (NDARRAY type) - * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the input array (INT type) + * @param x Input variable (NDARRAY type) + * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the + * input array (INT type) * @return output Output variable (NDARRAY type) */ public INDArray tile(INDArray x, INDArray repeat) { @@ -2087,12 +2275,13 @@ public class NDBase { /** * see tile(String, SDVariable, int...)
      * - * @param x (NDARRAY type) - * @param repeat (Size: AtLeast(min=1)) + * @param x (NDARRAY type) + * @param repeat (Size: AtLeast(min=1)) * @return output (NDARRAY type) */ public INDArray tile(INDArray x, int... repeat) { - Preconditions.checkArgument(repeat.length >= 1, "repeat has incorrect size/length. Expected: repeat.length >= 1, got %s", repeat.length); + Preconditions.checkArgument(repeat.length >= 1, + "repeat has incorrect size/length. Expected: repeat.length >= 1, got %s", repeat.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat))[0]; } @@ -2107,122 +2296,126 @@ public class NDBase { } /** - * Unsorted segment max operation. As per segmentMax(String, SDVariable, SDVariable) but without
      - * the requirement for the indices to be sorted.
      - * If data = [1, 3, 2, 6, 4, 9, 8]
      - * segmentIds = [1, 0, 2, 0, 1, 1, 2]
      - * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
      + * Unsorted segment max operation. As per segmentMax(String, SDVariable, SDVariable) but + * without
      the requirement for the indices to be sorted.
      If data = [1, 3, 2, 6, 4, 9, + * 8]
      segmentIds = [1, 0, 2, 0, 1, 1, 2]
      then output = [6, 9, 8] = [max(3,6), max(1,4,9), + * max(2,8)]
      * - * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) - * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) * @param numSegments Number of segments * @return output Unsorted segment output (NUMERIC type) */ public INDArray unsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentMax", "data", data); NDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, numSegments))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, + numSegments))[0]; } /** - * Unsorted segment mean operation. As per segmentMean(String, SDVariable, SDVariable) but without
      - * the requirement for the indices to be sorted.
      - * If data = [1, 3, 2, 6, 4, 9, 8]
      - * segmentIds = [1, 0, 2, 0, 1, 1, 2]
      - * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
      + * Unsorted segment mean operation. As per segmentMean(String, SDVariable, SDVariable) but + * without
      the requirement for the indices to be sorted.
      If data = [1, 3, 2, 6, 4, 9, + * 8]
      segmentIds = [1, 0, 2, 0, 1, 1, 2]
      then output = [4.5, 4.666, 5] = [mean(3,6), + * mean(1,4,9), mean(2,8)]
      * - * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) - * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) * @param numSegments Number of segments * @return output Unsorted segment output (NUMERIC type) */ public INDArray unsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentMean", "data", data); NDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, numSegments))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, + numSegments))[0]; } /** - * Unsorted segment min operation. As per segmentMin(String, SDVariable, SDVariable) but without
      - * the requirement for the indices to be sorted.
      - * If data = [1, 3, 2, 6, 4, 9, 8]
      - * segmentIds = [1, 0, 2, 0, 1, 1, 2]
      - * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
      + * Unsorted segment min operation. As per segmentMin(String, SDVariable, SDVariable) but + * without
      the requirement for the indices to be sorted.
      If data = [1, 3, 2, 6, 4, 9, + * 8]
      segmentIds = [1, 0, 2, 0, 1, 1, 2]
      then output = [3, 1, 2] = [min(3,6), min(1,4,9), + * min(2,8)]
      * - * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) - * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) * @param numSegments Number of segments * @return output Unsorted segment output (NUMERIC type) */ public INDArray unsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentMin", "data", data); NDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, numSegments))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, + numSegments))[0]; } /** - * Unsorted segment product operation. As per segmentProd(String, SDVariable, SDVariable) but without
      - * the requirement for the indices to be sorted.
      - * If data = [1, 3, 2, 6, 4, 9, 8]
      - * segmentIds = [1, 0, 2, 0, 1, 1, 2]
      - * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
      + * Unsorted segment product operation. As per segmentProd(String, SDVariable, SDVariable) but + * without
      the requirement for the indices to be sorted.
      If data = [1, 3, 2, 6, 4, 9, + * 8]
      segmentIds = [1, 0, 2, 0, 1, 1, 2]
      then output = [4.5, 4.666, 5] = [mean(3,6), + * mean(1,4,9), mean(2,8)]
      * - * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) - * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) * @param numSegments Number of segments * @return output Unsorted segment output (NUMERIC type) */ public INDArray unsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentProd", "data", data); NDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, numSegments))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, + numSegments))[0]; } /** - * Unsorted segment sqrtN operation. Simply returns the sqrt of the count of the number of values in each segment
      - * If data = [1, 3, 2, 6, 4, 9, 8]
      - * segmentIds = [1, 0, 2, 0, 1, 1, 2]
      - * then output = [1.414, 1.732, 1.414] = [sqrt(2), sqrtN(3), sqrtN(2)]
      + * Unsorted segment sqrtN operation. Simply returns the sqrt of the count of the number of values + * in each segment
      If data = [1, 3, 2, 6, 4, 9, 8]
      segmentIds = [1, 0, 2, 0, 1, 1, + * 2]
      then output = [1.414, 1.732, 1.414] = [sqrt(2), sqrtN(3), sqrtN(2)]
      * - * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) - * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) * @param numSegments Number of segments * @return output Unsorted segment output (NUMERIC type) */ public INDArray unsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data); NDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, numSegments))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, + numSegments))[0]; } /** - * Unsorted segment sum operation. As per segmentSum(String, SDVariable, SDVariable) but without
      - * the requirement for the indices to be sorted.
      - * If data = [1, 3, 2, 6, 4, 9, 8]
      - * segmentIds = [1, 0, 2, 0, 1, 1, 2]
      - * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
      + * Unsorted segment sum operation. As per segmentSum(String, SDVariable, SDVariable) but + * without
      the requirement for the indices to be sorted.
      If data = [1, 3, 2, 6, 4, 9, + * 8]
      segmentIds = [1, 0, 2, 0, 1, 1, 2]
      then output = [9, 14, 10] = [sum(3,6), + * sum(1,4,9), sum(2,8)]
      * - * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) - * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) * @param numSegments Number of segments * @return output Unsorted segment output (NUMERIC type) */ public INDArray unsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentSum", "data", data); NDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, + numSegments))[0]; } /** - * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis.
      - * If input has shape [a,b,c] then output has shape:
      - * axis = 0: [b,c]
      - * axis = 1: [a,c]
      - * axis = 2: [a,b]
      + * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified + * axis.
      If input has shape [a,b,c] then output has shape:
      axis = 0: [b,c]
      axis = 1: + * [a,c]
      axis = 2: [a,b]
      * * @param value Input variable to unstack (NDARRAY type) - * @param axis Axis to unstack on - * @param num Number of output variables + * @param axis Axis to unstack on + * @param num Number of output variables */ public INDArray[] unstack(INDArray value, int axis, int num) { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Unstack(value, axis, num)); @@ -2230,50 +2423,61 @@ public class NDBase { /** * Variance array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N + * (population variance) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: + * remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray variance(INDArray x, boolean biasCorrected, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("variance", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.Variance(x, biasCorrected, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.summarystats.Variance(x, biasCorrected, keepDims, + dimensions)); } /** * Variance array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N + * (population variance) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray variance(INDArray x, boolean biasCorrected, int... dimensions) { NDValidation.validateNumerical("variance", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.Variance(x, biasCorrected, false, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.Variance(x, biasCorrected, false, + dimensions)); } /** - * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
      - * if the input shape changes in later execution, the returned variable's shape will also be updated
      + * Return a variable of all 0s, with the same shape as the input variable. Note that this is + * dynamic:
      if the input shape changes in later execution, the returned variable's shape will + * also be updated
      * * @param input Input (NUMERIC type) * @return output A new Variable with the same (dynamic) shape as the input (NUMERIC type) diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/Indices.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/Indices.java index fb505698a..e9c521eae 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/Indices.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/Indices.java @@ -173,13 +173,13 @@ public class Indices { /** * Fill in the missing indices to be the * same length as the original shape. - *

      + *

      * Think of this as what fills in the indices for numpy or matlab: * Given a which is (4,3,2) in numpy: - *

      + *

      * a[1:3] is filled in by the rest * to give back the full slice - *

      + *

      * This algorithm fills in that delta * * @param shape the original shape @@ -244,7 +244,7 @@ public class Indices { /** * Calculate the shape for the given set of indices. - *

      + *

      * The shape is defined as (for each dimension) * the difference between the end index + 1 and * the begin index @@ -344,12 +344,12 @@ public class Indices { /** * Calculate the shape for the given set of indices and offsets. - *

      + *

      * The shape is defined as (for each dimension) * the difference between the end index + 1 and * the begin index - *

      - * If specified, this will check for whether any of the indices are >= to end - 1 + *

      + * If specified, this will check for whether any of the indices are >= to end - 1 * and if so, prune it down * * @param shape the original shape diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaBeliefUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaBeliefUpdater.java index 7d6dbb16c..857f9e467 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaBeliefUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaBeliefUpdater.java @@ -90,7 +90,6 @@ public class AdaBeliefUpdater implements GradientUpdater { * * @param gradient the gradient to get the update for * @param iteration - * @return the gradient */ @Override public void applyUpdater(INDArray gradient, int iteration, int epoch) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java index 83a740be8..a9a2b44e5 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java @@ -32,78 +32,84 @@ import java.util.Map; @Data public class AdaDeltaUpdater implements GradientUpdater { - public static final String MSG_STATE = "msg"; - public static final String MSDX_STATE = "msdx"; - private final AdaDelta config; + public static final String MSG_STATE = "msg"; + public static final String MSDX_STATE = "msdx"; - private INDArray msg; //E[g^2]_t by arxiv paper, algorithm 1 - private INDArray msdx; //E[delta x^2]_t by arxiv paper, algorithm 1 + private final AdaDelta config; + + private INDArray msg; //E[g^2]_t by arxiv paper, algorithm 1 + private INDArray msdx; //E[delta x^2]_t by arxiv paper, algorithm 1 + public AdaDeltaUpdater(AdaDelta config) { + this.config = config; + } - public AdaDeltaUpdater(AdaDelta config) { - this.config = config; + @Override 
+ public void setState(Map stateMap, boolean initialize) { + if (!stateMap.containsKey(MSG_STATE) || !stateMap.containsKey(MSDX_STATE) + || stateMap.size() != 2) { + throw new IllegalStateException( + "State map should contain only keys [" + MSG_STATE + "," + MSDX_STATE + "] but has keys " + + stateMap.keySet()); } + this.msg = stateMap.get(MSG_STATE); + this.msdx = stateMap.get(MSDX_STATE); + } - @Override - public void setState(Map stateMap, boolean initialize) { - if(!stateMap.containsKey(MSG_STATE) || !stateMap.containsKey(MSDX_STATE) || stateMap.size() != 2){ - throw new IllegalStateException("State map should contain only keys [" + MSG_STATE + "," + MSDX_STATE + "] but has keys " + stateMap.keySet()); - } - this.msg = stateMap.get(MSG_STATE); - this.msdx = stateMap.get(MSDX_STATE); - } + @Override + public Map getState() { + Map r = new HashMap<>(); + r.put(MSG_STATE, msg); + r.put(MSDX_STATE, msdx); + return r; + } - @Override - public Map getState() { - Map r = new HashMap<>(); - r.put(MSG_STATE, msg); - r.put(MSDX_STATE, msdx); - return r; - } + @Override + public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, + boolean initialize) { + if (!viewArray.isRowVector()) { + throw new IllegalArgumentException("Invalid input: expect row vector input"); + } + if (initialize) { + viewArray.assign(0); + } + long length = viewArray.length(); + this.msg = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2)); + this.msdx = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length)); - @Override - public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) { - if (!viewArray.isRowVector()) - throw new IllegalArgumentException("Invalid input: expect row vector input"); - if (initialize) - viewArray.assign(0); - long length = viewArray.length(); - this.msg = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2)); - 
this.msdx = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length)); + //Reshape to match the expected shape of the input gradient arrays + this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, gradientOrder == 'f'); + this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, gradientOrder == 'f'); + if (msg == null || msdx == null) { + throw new IllegalStateException("Could not correctly reshape gradient view arrays"); + } + } - //Reshape to match the expected shape of the input gradient arrays - this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, gradientOrder == 'f'); - this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, gradientOrder == 'f'); - if (msg == null || msdx == null) - throw new IllegalStateException("Could not correctly reshape gradient view arrays"); - } + /** + * Get the updated gradient for the given gradient and also update the state of ada delta. + * + * @param gradient the gradient to get the updated gradient for + * @param iteration + */ + @Override + public void applyUpdater(INDArray gradient, int iteration, int epoch) { + if (msg == null || msdx == null) { + throw new IllegalStateException("Updater has not been initialized with view state"); + } - /** - * Get the updated gradient for the given gradient - * and also update the state of ada delta. 
- * - * @param gradient the gradient to get the - * updated gradient for - * @param iteration - * @return the update gradient - */ - @Override - public void applyUpdater(INDArray gradient, int iteration, int epoch) { - if (msg == null || msdx == null) - throw new IllegalStateException("Updater has not been initialized with view state"); + double rho = config.getRho(); + double epsilon = config.getEpsilon(); - double rho = config.getRho(); - double epsilon = config.getEpsilon(); + //Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf + //E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t + //Calculate update: + //dX = - g * RMS[delta x]_{t-1} / RMS[g]_t + //Note: negative is applied in the DL4J step function: params -= update rather than params += update + //Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2 - //Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf - //E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t - //Calculate update: - //dX = - g * RMS[delta x]_{t-1} / RMS[g]_t - //Note: negative is applied in the DL4J step function: params -= update rather than params += update - //Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2 - - Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho, epsilon)); - } + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho, + epsilon)); + } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java index 7f7d27593..704075e13 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java @@ -92,7 +92,6 @@ public class AdaMaxUpdater implements GradientUpdater { * * @param gradient the gradient to get the update for * @param 
iteration - * @return the gradient */ @Override public void applyUpdater(INDArray gradient, int iteration, int epoch) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java index 6ad7255af..996c97268 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java @@ -93,7 +93,6 @@ public class AdamUpdater implements GradientUpdater { * * @param gradient the gradient to get the update for * @param iteration - * @return the gradient */ @Override public void applyUpdater(INDArray gradient, int iteration, int epoch) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/GradientUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/GradientUpdater.java index 1cfca0e7d..e32b45f66 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/GradientUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/GradientUpdater.java @@ -48,7 +48,6 @@ public interface GradientUpdater { * * @param gradient the gradient to modify * @param iteration - * @return the modified gradient */ void applyUpdater(INDArray gradient, int iteration, int epoch); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java index 1bc3adcf5..cdb2e39ec 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java @@ -92,7 +92,6 @@ public class NadamUpdater implements GradientUpdater { * * @param gradient the gradient to get the update for * @param iteration - * @return the gradient */ @Override public void applyUpdater(INDArray gradient, int 
iteration, int epoch) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java index 891bec8a5..d580d8833 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java @@ -77,7 +77,6 @@ public class NesterovsUpdater implements GradientUpdater { * * @param gradient the gradient to get the update for * @param iteration - * @return */ @Override public void applyUpdater(INDArray gradient, int iteration, int epoch) { diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java index 03ec92701..dfd953b4f 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java @@ -152,12 +152,12 @@ public class MultiDimensionalMap implements Serializable { /** * Returns the value to which the specified key is mapped, * or {@code null} if this map contains no mapping for the key. - *

      + *

      *

      More formally, if this map contains a mapping from a key * {@code k} to a value {@code v} such that {@code (key==null ? k==null : * key.equals(k))}, then this method returns {@code v}; otherwise * it returns {@code null}. (There can be at most one such mapping.) - *

      + *

      *

      If this map permits null values, then a return value of * {@code null} does not necessarily indicate that the map * contains no mapping for the key; it's also possible that the map @@ -214,15 +214,15 @@ public class MultiDimensionalMap implements Serializable { * from key k to value v such that * (key==null ? k==null : key.equals(k)), that mapping * is removed. (The map can contain at most one such mapping.) - *

      + *

      *

      Returns the value to which this map previously associated the key, * or null if the map contained no mapping for the key. - *

      + *

      *

      If this map permits null values, then a return value of * null does not necessarily indicate that the map * contained no mapping for the key; it's also possible that the map * explicitly mapped the key to null. - *

      + *

      *

      The map will not contain a mapping for the specified key once the * call returns. * diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java index d16c190cb..fdd2bff15 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java @@ -108,12 +108,12 @@ public class MultiDimensionalSet implements Set> { * If this applyTransformToDestination makes any guarantees as to what order its elements * are returned by its iterator, this method must return the * elements in the same order. - *

      + *

      *

      The returned array will be "safe" in that no references to it * are maintained by this applyTransformToDestination. (In other words, this method must * allocate a new array even if this applyTransformToDestination is backed by an array). * The caller is thus free to modify the returned array. - *

      + *

      *

      This method acts as bridge between array-based and collection-based * APIs. * @@ -130,27 +130,27 @@ public class MultiDimensionalSet implements Set> { * If the applyTransformToDestination fits in the specified array, it is returned therein. * Otherwise, a new array is allocated with the runtime type of the * specified array and the size of this applyTransformToDestination. - *

      + *

      *

      If this applyTransformToDestination fits in the specified array with room to spare * (i.e., the array has more elements than this applyTransformToDestination), the element in * the array immediately following the end of the applyTransformToDestination is applyTransformToDestination to * null. (This is useful in determining the length of this * applyTransformToDestination only if the caller knows that this applyTransformToDestination does not contain * any null elements.) - *

      + *

      *

      If this applyTransformToDestination makes any guarantees as to what order its elements * are returned by its iterator, this method must return the elements * in the same order. - *

      + *

      *

      Like the {@link #toArray()} method, this method acts as bridge between * array-based and collection-based APIs. Further, this method allows * precise control over the runtime type of the output array, and may, * under certain circumstances, be used to save allocation costs. - *

      + *

      *

      Suppose x is a applyTransformToDestination known to contain only strings. * The following code can be used to dump the applyTransformToDestination into a newly allocated * array of String: - *

      + *

      *

            *     String[] y = x.toArray(new String[0]);
      * @@ -181,7 +181,7 @@ public class MultiDimensionalSet implements Set> { * unchanged and returns false. In combination with the * restriction on constructors, this ensures that sets never contain * duplicate elements. - *

      + *

      *

      The stipulation above does not imply that sets must accept all * elements; sets may refuse to add any particular element, including * null, and throw an exception, as described in the diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArrayUtil.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArrayUtil.java index 13780f3a6..240031440 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArrayUtil.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArrayUtil.java @@ -204,9 +204,9 @@ public class ArrayUtil { /** * Credit to mikio braun from jblas - *

      + *

      * Create a random permutation of the numbers 0, ..., size - 1. - *

      + *

      * see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145 */ public static int[] randomPermutation(int size) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java index eb59a2c5f..dfff491e4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java @@ -64,7 +64,7 @@ public class HelperUtils { if("CUDA".equalsIgnoreCase(backend) && cudnnHelperClassName != null && !cudnnHelperClassName.isEmpty()) { if(DL4JClassLoading.loadClassByName(cudnnHelperClassName) != null) { log.debug("Attempting to initialize cudnn helper {}",cudnnHelperClassName); - helperRet = (LayerHelper) DL4JClassLoading.createNewInstance( + helperRet = DL4JClassLoading.createNewInstance( cudnnHelperClassName, (Class) layerHelperSuperClass, new Object[]{arguments}); @@ -76,7 +76,7 @@ public class HelperUtils { ClassLoader classLoader = DL4JClassLoading.getDl4jClassloader(); DL4JClassLoading.setDl4jClassloaderFromClass(layerHelperSuperClass); try { - helperRet = (LayerHelper) DL4JClassLoading.createNewInstance( + helperRet = DL4JClassLoading.createNewInstance( cudnnHelperClassName, (Class) layerHelperSuperClass, arguments); @@ -99,7 +99,7 @@ public class HelperUtils { } } else if("CPU".equalsIgnoreCase(backend) && oneDnnClassName != null && !oneDnnClassName.isEmpty()) { - helperRet = DL4JClassLoading.createNewInstance( + helperRet = DL4JClassLoading.createNewInstance( oneDnnClassName, arguments); log.trace("Created oneDNN helper: {}, layer {}", oneDnnClassName,layerName); diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 2e587fa8e..b0e7e9b81 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -3,11 +3,14 @@ plugins { id 'maven-publish' } +/* configurations.archives.artifacts.with { archives 
-> + archives.each { println(it.name) } } +*/ dependencies { //Todo clean this @@ -19,7 +22,7 @@ dependencies { //TODO for the two below.. either platform specific uber jars or a single big one with all platforms api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" //api group: "org.bytedeco", name: "javacpp", version: "1.5.7" - api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" + // api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT' rootProject.getAllprojects().each { Project sproj -> if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 0a638ff15..1d083f0ce 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -212,7 +212,7 @@ tasks.withType(org.bytedeco.gradle.javacpp.BuildTask) { // Disable the standard javacpp generated tasks and use own // versions below. 
This allows to build for each variant [javacppBuildParser, javacppBuildCommand, javacppCompileJava, javacppBuildCompiler].each { - it.enabled false; + it.enabled false } chipList.each { thisChip -> diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java index 03ec92701..dfd953b4f 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java @@ -152,12 +152,12 @@ public class MultiDimensionalMap implements Serializable { /** * Returns the value to which the specified key is mapped, * or {@code null} if this map contains no mapping for the key. - *

      + *

      *

      More formally, if this map contains a mapping from a key * {@code k} to a value {@code v} such that {@code (key==null ? k==null : * key.equals(k))}, then this method returns {@code v}; otherwise * it returns {@code null}. (There can be at most one such mapping.) - *

      + *

      *

      If this map permits null values, then a return value of * {@code null} does not necessarily indicate that the map * contains no mapping for the key; it's also possible that the map @@ -214,15 +214,15 @@ public class MultiDimensionalMap implements Serializable { * from key k to value v such that * (key==null ? k==null : key.equals(k)), that mapping * is removed. (The map can contain at most one such mapping.) - *

      + *

      *

      Returns the value to which this map previously associated the key, * or null if the map contained no mapping for the key. - *

      + *

      *

      If this map permits null values, then a return value of * null does not necessarily indicate that the map * contained no mapping for the key; it's also possible that the map * explicitly mapped the key to null. - *

      + *

      *

      The map will not contain a mapping for the specified key once the * call returns. * diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java index d16c190cb..fdd2bff15 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java @@ -108,12 +108,12 @@ public class MultiDimensionalSet implements Set> { * If this applyTransformToDestination makes any guarantees as to what order its elements * are returned by its iterator, this method must return the * elements in the same order. - *

      + *

      *

      The returned array will be "safe" in that no references to it * are maintained by this applyTransformToDestination. (In other words, this method must * allocate a new array even if this applyTransformToDestination is backed by an array). * The caller is thus free to modify the returned array. - *

      + *

      *

      This method acts as bridge between array-based and collection-based * APIs. * @@ -130,27 +130,27 @@ public class MultiDimensionalSet implements Set> { * If the applyTransformToDestination fits in the specified array, it is returned therein. * Otherwise, a new array is allocated with the runtime type of the * specified array and the size of this applyTransformToDestination. - *

      + *

      *

      If this applyTransformToDestination fits in the specified array with room to spare * (i.e., the array has more elements than this applyTransformToDestination), the element in * the array immediately following the end of the applyTransformToDestination is applyTransformToDestination to * null. (This is useful in determining the length of this * applyTransformToDestination only if the caller knows that this applyTransformToDestination does not contain * any null elements.) - *

      + *

      *

      If this applyTransformToDestination makes any guarantees as to what order its elements * are returned by its iterator, this method must return the elements * in the same order. - *

      + *

      *

      Like the {@link #toArray()} method, this method acts as bridge between * array-based and collection-based APIs. Further, this method allows * precise control over the runtime type of the output array, and may, * under certain circumstances, be used to save allocation costs. - *

      + *

      *

      Suppose x is a applyTransformToDestination known to contain only strings. * The following code can be used to dump the applyTransformToDestination into a newly allocated * array of String: - *

      + *

      *

            *     String[] y = x.toArray(new String[0]);
      * @@ -181,7 +181,7 @@ public class MultiDimensionalSet implements Set> { * unchanged and returns false. In combination with the * restriction on constructors, this ensures that sets never contain * duplicate elements. - *

      + *

      *

      The stipulation above does not imply that sets must accept all * elements; sets may refuse to add any particular element, including * null, and throw an exception, as described in the diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArrayUtil.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArrayUtil.java index 13780f3a6..240031440 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArrayUtil.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArrayUtil.java @@ -204,9 +204,9 @@ public class ArrayUtil { /** * Credit to mikio braun from jblas - *

      + *

      * Create a random permutation of the numbers 0, ..., size - 1. - *

      + *

      * see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145 */ public static int[] randomPermutation(int size) { diff --git a/settings.gradle b/settings.gradle index 2e4e68cce..aaf58f336 100644 --- a/settings.gradle +++ b/settings.gradle @@ -148,7 +148,6 @@ include ':cavis-ui:cavis-ui-standalone' include ':cavis-ui:cavis-ui-vertx' include ':cavis-zoo' include ':cavis-zoo:cavis-zoo-models' - include ':brutex-extended-tests' include ':cavis-full'