diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java index 7d5dbf4fc..3487cc216 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java @@ -48,6 +48,10 @@ public class BiasAdd extends DynamicCustomOp { this.nchw = nchw; } + public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, boolean nchw){ + this(input, bias, null, nchw); + } + public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){ super(new INDArray[]{input, bias}, wrapOrNull(output)); bArguments.clear(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java index d2046140a..8d660eba8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java @@ -54,7 +54,12 @@ public class FirstIndex extends BaseIndexAccumulation { public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) { + this(x, condition, false, dimension); + } + + public FirstIndex(INDArray x, @NonNull Condition condition, boolean keepDims, int... dimension) { this(x, condition, Nd4j.EPS_THRESHOLD, dimension); + this.keepDims = keepDims; } public FirstIndex(INDArray x, @NonNull Condition condition, double eps, int... 
dimension) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java index 30f51c56c..4c8465ef7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java @@ -38,7 +38,12 @@ public class IAMax extends BaseIndexAccumulation { public IAMax() {} public IAMax(INDArray x, int... dimensions) { + this(x, false, dimensions); + } + + public IAMax(INDArray x, boolean keepDims, int... dimensions) { this(x, null, dimensions); + this.keepDims = keepDims; } public IAMax(INDArray x, INDArray z, int... dimensions) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java index 5a3e950e1..0a1383a67 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java @@ -41,6 +41,11 @@ public class IAMin extends BaseIndexAccumulation { super(x, dimensions); } + public IAMin(INDArray in, boolean keepDims, int... dimensions){ + super(in, null, dimensions); + this.keepDims = keepDims; + } + public IAMin(INDArray x, INDArray z, int...
dimensions) { super(x, z, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java index 792547d7c..b29af5042 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java @@ -58,6 +58,11 @@ public class LastIndex extends BaseIndexAccumulation { this(x, condition, Nd4j.EPS_THRESHOLD, dimensions); } + public LastIndex(INDArray x, @NonNull Condition condition, boolean keepDim, int... dimensions) { + this(x, condition, Nd4j.EPS_THRESHOLD, dimensions); + this.keepDims = keepDim; + } + public LastIndex(INDArray x, @NonNull Condition condition, double eps, int... dimensions) { super(x,null, dimensions); this.condition = condition; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java index 20ff5918c..e3716bc24 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java @@ -76,6 +76,15 @@ public class BatchNorm extends DynamicCustomOp { addArgs(); } + public BatchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma, INDArray beta, double epsilon, int... 
axis){ + super(wrapFilterNull(input, mean, variance, gamma, beta), null); + this.jaxis = axis; + this.applyBeta = beta != null; + this.applyGamma = gamma != null; + this.epsilon = epsilon; + addArgs(); + } + public void addArgs() { addIArgument(ArrayUtil.fromBoolean(applyGamma)); addIArgument(ArrayUtil.fromBoolean(applyBeta)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java index c7aef3c62..152b93980 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.reduce; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -40,6 +41,12 @@ public class Moments extends DynamicCustomOp { private int[] axes; + public Moments(@NonNull INDArray input, int... 
axes){ + super(new INDArray[]{input}, null); + this.axes = axes; + addArgs(); + } + public Moments(SameDiff sameDiff, SDVariable input) { this(sameDiff, input, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java index 945cd505d..be33a458d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java @@ -47,6 +47,12 @@ public class NormalizeMoments extends DynamicCustomOp { addArgs(); } + public NormalizeMoments(INDArray counts, INDArray means, INDArray variances, double shift) { + super(null, new INDArray[]{counts, means, variances}, null); + this.shift = shift; + addArgs(); + } + public NormalizeMoments(INDArray counts, INDArray ssSum, INDArray ssSqSum, INDArray outMean, INDArray outVar) { super(null, new INDArray[]{counts, ssSum, ssSqSum}, new INDArray[]{outMean, outVar}, new ArrayList(), new ArrayList()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/ZeroFraction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/ZeroFraction.java index 18009a466..42ecc2f57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/ZeroFraction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/ZeroFraction.java @@ -17,11 +17,13 @@ package org.nd4j.linalg.api.ops.impl.reduce; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import 
org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Collections; @@ -36,10 +38,13 @@ import java.util.List; public class ZeroFraction extends DynamicCustomOp { public ZeroFraction(SameDiff sameDiff, SDVariable input) { - super(null, sameDiff, new SDVariable[] {input}, false); } + public ZeroFraction(@NonNull INDArray input){ + super(new INDArray[]{input}, null); + } + @Override public String opName() { return "zero_fraction"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java index 32c07ad96..f9e30be9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java @@ -45,6 +45,10 @@ public class PRelu extends DynamicCustomOp { addIArgument(sharedAxes); } + public PRelu(@NonNull INDArray x, @NonNull INDArray alpha, @NonNull int... sharedAxes) { + this(x, null, alpha, sharedAxes); + } + public PRelu(@NonNull INDArray x, INDArray z, @NonNull INDArray alpha, @NonNull int... 
sharedAxes) { super(new INDArray[]{x, alpha}, new INDArray[]{z}); this.sharedAxes = sharedAxes; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java index 6275ce210..f21a0d291 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import lombok.val; import org.apache.commons.lang3.NotImplementedException; import org.nd4j.autodiff.samediff.SDVariable; @@ -23,6 +24,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -41,6 +43,35 @@ public class ConfusionMatrix extends DynamicCustomOp { public ConfusionMatrix(){ } + public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, @NonNull DataType dataType){ + super(new INDArray[]{labels, predicted}, null); + this.outputType = dataType; + } + + public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, int numClasses){ + this(labels, predicted, numClasses, DEFAULT_DTYPE); + } + + public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights) { + this(labels, predicted, weights, null); + } + + public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights, Integer numClasses) { + this(labels, predicted, weights, numClasses, DEFAULT_DTYPE); + } + + 
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, Integer numClasses, @NonNull DataType dataType) { + this(labels, predicted, null, numClasses, dataType); + } + + public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights, Integer numClasses, @NonNull DataType dataType) { + super(wrapFilterNull(labels, predicted, weights), null); + this.outputType = dataType; + if(numClasses != null) { + addIArgument(numClasses); + } + } + public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){ super(null, sameDiff, new SDVariable[]{labels, pred}); this.outputType = dataType; @@ -57,7 +88,9 @@ public class ConfusionMatrix extends DynamicCustomOp { public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights){ super(null, sameDiff, new SDVariable[]{labels, pred, weights}); - addIArgument(numClasses); + if(numClasses != null) { + addIArgument(numClasses); + } } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java index f45f9aa87..3e94cb126 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java @@ -44,13 +44,16 @@ public class Cross extends DynamicCustomOp { public Cross() { } - public Cross(SameDiff sameDiff, SDVariable[] args) { super(null, sameDiff, args, false); } + public Cross(INDArray a, INDArray b){ + this(a,b,null); + } + public Cross(INDArray a, INDArray b, INDArray out){ - super(null, new INDArray[]{a,b}, out == null ? 
null : new INDArray[]{out}, null, (int[])null); + super(null, new INDArray[]{a,b}, wrapOrNull(out), null, (int[])null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java index b6d08784b..95947a94b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -44,8 +45,12 @@ public class Diag extends DynamicCustomOp { public Diag() { } - public Diag(INDArray[] inputs, INDArray[] outputs) { - super(null, inputs, outputs); + public Diag(@NonNull INDArray input) { + this(input, null); + } + + public Diag(@NonNull INDArray input, @NonNull INDArray output){ + super(null, new INDArray[]{input}, wrapOrNull(output)); } public Diag(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java index 6b1688602..1d2b93d9a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java @@ -51,6 +51,10 @@ public class DiagPart extends DynamicCustomOp { super(null, sameDiff, args, inPlace); } + public DiagPart(INDArray in){ + this(in, null); + } + public DiagPart(INDArray in, INDArray out){ super(null, in, out, null, null); } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java index 3472be2de..3a8bb8f15 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java @@ -16,11 +16,14 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.shade.guava.base.Preconditions; import java.util.Collections; import java.util.List; @@ -55,6 +58,23 @@ public class Eye extends DynamicCustomOp { public Eye() { } + public Eye(@NonNull INDArray rows){ + this(rows.getInt(0)); + Preconditions.checkArgument(rows.isScalar(), "Rows INDArray must be a scalar"); + } + + public Eye(@NonNull INDArray rows, @NonNull INDArray columns){ + this(rows.getInt(0), columns.getInt(0)); + Preconditions.checkArgument(rows.isScalar(), "Rows INDArray must be a scalar"); + Preconditions.checkArgument(columns.isScalar(), "Columns INDArray must be a scalar"); + } + + public Eye(int rows){ + this.numRows = rows; + this.numCols = rows; + addArgs(); + } + public Eye(SameDiff sameDiff, SDVariable numRows){ super(null, sameDiff, new SDVariable[] {numRows}, false); } @@ -66,10 +86,7 @@ public class Eye extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {numRows, numCols, batch_shape}, false); } public Eye(SameDiff sameDiff, int numRows) { - super(null, sameDiff, new SDVariable[] {}, false); - this.numRows = numRows; - this.numCols = numRows; - addArgs(); + this(sameDiff, numRows, numRows); } public 
Eye(SameDiff sameDiff, int numRows, int numCols) { @@ -77,13 +94,25 @@ public class Eye extends DynamicCustomOp { } public Eye(SameDiff sameDiff, int numRows, int numCols, DataType dataType) { - super(null, sameDiff, new SDVariable[] {}, false); + this(sameDiff, numRows, numCols, dataType, null); + } + + public Eye(int numRows, int numCols, DataType dataType, int[] batchDimension) { this.numRows = numRows; this.numCols = numCols; + this.batchDimension = batchDimension; this.dataType = dataType; addArgs(); } + public Eye(int numRows, int numCols) { + this(numRows, numCols, DEFAULT_DTYPE); + } + + public Eye(int numRows, int numCols, DataType dataType) { + this(numRows, numCols, dataType, null); + } + public Eye(SameDiff sameDiff, int numRows, int numCols, DataType dataType, int[] batchDimension) { super(null, sameDiff, new SDVariable[] {}, false); this.numRows = numRows; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java index 448ae1d16..b63052eb5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; @@ -36,6 +37,10 @@ import java.util.Map; @Slf4j public class MergeAvg extends DynamicCustomOp { + public MergeAvg(INDArray... inputs){ + super(inputs, null); + } + public MergeAvg(SameDiff sameDiff, SDVariable... 
inputs) { super(null, sameDiff, inputs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java index 11578b902..6c342200d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java @@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -40,6 +41,10 @@ public class MergeMax extends DynamicCustomOp { super(null, sameDiff, inputs); } + public MergeMax(INDArray... 
inputs){ + super(inputs, null); + } + public MergeMax(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/ReluLayer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/ReluLayer.java index 6a8f2965d..fcd220004 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/ReluLayer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/ReluLayer.java @@ -17,9 +17,11 @@ package org.nd4j.linalg.api.ops.impl.transforms; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB; import java.util.Collections; @@ -37,7 +39,10 @@ public class ReluLayer extends XwPlusB { public ReluLayer(SameDiff sameDiff, SDVariable input, SDVariable weights, SDVariable bias) { super(sameDiff, input, weights, bias); + } + public ReluLayer(@NonNull INDArray input, @NonNull INDArray weights, @NonNull INDArray bias){ + super(new INDArray[]{input, weights, bias}, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java index 4d6ba3e66..026930e4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java @@ -49,8 +49,12 @@ public class ClipByNorm extends DynamicCustomOp { addTArgument(clipValue); } + public 
ClipByNorm(INDArray in, double clipValue, int... dimensions){ + this(in, null, clipValue, dimensions); + } + public ClipByNorm(INDArray in, INDArray out, double clipValue, int... dimensions){ - super(null, new INDArray[]{in}, (out == null ? null : new INDArray[]{out}), Collections.singletonList(clipValue), dimensions); + super(null, new INDArray[]{in}, wrapOrNull(out), Collections.singletonList(clipValue), dimensions); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java index 11d3e9004..3927ba2bc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.clip; +import lombok.NonNull; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -38,11 +39,10 @@ public class ClipByValue extends DynamicCustomOp { private double clipValueMin; private double clipValueMax; - public ClipByValue(INDArray[] inputs, INDArray[] outputs, double clipValueMin, double clipValueMax, boolean inPlace) { - super(null, inputs, outputs); + public ClipByValue(@NonNull INDArray input, double clipValueMin, double clipValueMax) { + super(null, new INDArray[]{input}, null); this.clipValueMin = clipValueMin; this.clipValueMax = clipValueMax; - this.inplaceCall = inPlace; addTArgument(clipValueMin, clipValueMax); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java index 
8a782acf6..d6230e153 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java @@ -41,13 +41,22 @@ public class ATan2 extends BaseDynamicTransformOp { super(sameDiff, new SDVariable[] {y, x} ,false); } + /** + * Note that the order of x and y match {@link java.lang.Math#atan2(double, double)}, + * and are reversed when compared to OldATan2. + * See {@link Transforms#atan2(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ndarray.INDArray)} + */ + public ATan2(INDArray x, INDArray y) { + this(x,y,null); + } + /** * Note that the order of x and y match {@link java.lang.Math#atan2(double, double)}, * and are reversed when compared to OldATan2. * See {@link Transforms#atan2(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ndarray.INDArray)} */ public ATan2(INDArray x, INDArray y, INDArray z) { - super(new INDArray[]{x, y}, new INDArray[]{ z }); + super(new INDArray[]{x, y}, wrapOrNull(z)); } public ATan2() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java index cf72ea7be..d3a5c9676 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java @@ -17,10 +17,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import 
org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -49,6 +51,18 @@ public class DotProductAttention extends DynamicCustomOp { addIArgument(withWeights ? 1 : 0); } + public DotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values, INDArray mask, boolean scaled){ + this(queries, keys, values, mask, scaled, false); + } + + public DotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values, INDArray mask, boolean scaled, boolean withWeights){ + super(wrapFilterNull(queries, keys, values, mask), null); + this.scaled = scaled; + this.withWeights = withWeights; + addIArgument(scaled ? 1 : 0); + addIArgument(withWeights ? 1 : 0); + } + @Override public String opName() { return "dot_product_attention"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java index 08a9b0faf..83cad14a5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -40,8 +41,12 @@ public class IsNonDecreasing extends DynamicCustomOp { super(null, sameDiff, args, inPlace); } - public IsNonDecreasing(INDArray[] inputs, INDArray[] outputs) { - super(null, inputs, outputs); + public IsNonDecreasing(@NonNull INDArray input){ + this(input, null); + } + + public IsNonDecreasing(@NonNull 
INDArray input, INDArray output) { + super(null, new INDArray[]{input}, wrapOrNull(output)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java index 02a527cb8..55b866cad 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -38,8 +39,12 @@ public class IsStrictlyIncreasing extends DynamicCustomOp { super(null, sameDiff, args, inPlace); } - public IsStrictlyIncreasing( INDArray[] inputs, INDArray[] outputs) { - super(null, inputs, outputs); + public IsStrictlyIncreasing(@NonNull INDArray input){ + this(input, null); + } + + public IsStrictlyIncreasing(@NonNull INDArray input, INDArray output) { + super(null, new INDArray[]{input}, wrapOrNull(output)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java index 7c7c34fc5..0c4990bb2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java @@ -62,6 +62,10 @@ public class LayerNorm extends DynamicCustomOp { setDimensions(dimensions); } + public 
LayerNorm(@NonNull INDArray input, @NonNull INDArray gain, boolean channelsFirst, int... dimensions) { + this(input, gain, null, channelsFirst, dimensions); + } + public LayerNorm(INDArray input, INDArray gain, INDArray result, boolean channelsFirst, int... dimensions) { this(input, gain, null, result, channelsFirst, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java index 09a4823e0..86c9d9c0a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java @@ -52,6 +52,11 @@ public class LogSoftMax extends DynamicCustomOp { this(x, x); } + public LogSoftMax(INDArray x, int dimension) { + this(x, null); + this.dimension = dimension; + } + public LogSoftMax(SameDiff sameDiff, SDVariable i_v, int dimension) { this(sameDiff, i_v); this.dimension = dimension; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java index c079b0fc0..67ba9f343 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import 
org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -39,6 +41,10 @@ public class MatrixDeterminant extends DynamicCustomOp { // } + public MatrixDeterminant(@NonNull INDArray input){ + super(new INDArray[]{input}, null); + } + public MatrixDeterminant(SameDiff sameDiff, SDVariable in, boolean inPlace) { super(null, sameDiff, new SDVariable[]{in}, inPlace); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java index 0bbe7f25d..4ff0f942b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Collections; @@ -36,6 +38,10 @@ public class MatrixInverse extends DynamicCustomOp { // } + public MatrixInverse(@NonNull INDArray input){ + super(new INDArray[]{input}, null); + } + public MatrixInverse(SameDiff sameDiff, SDVariable in, boolean inPlace) { super(null, sameDiff, new SDVariable[]{in}, inPlace); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java index 
3d00afe5b..9bbf6c50f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -32,6 +34,10 @@ public class MatrixSetDiag extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{in, diag}, inPlace); } + public MatrixSetDiag(@NonNull INDArray in, @NonNull INDArray diag){ + super(new INDArray[]{in, diag}, null); + } + public MatrixSetDiag(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java index f55b21263..54167bd8b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java @@ -17,10 +17,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import 
java.util.Arrays; @@ -54,6 +56,22 @@ public class MultiHeadDotProductAttention extends DynamicCustomOp { addIArgument(withWeights ? 1 : 0); } + public MultiHeadDotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values, + @NonNull INDArray Wq, @NonNull INDArray Wk, @NonNull INDArray Wv, @NonNull INDArray Wo, + INDArray mask, boolean scaled) { + this(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false); + } + + public MultiHeadDotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values, + @NonNull INDArray Wq, @NonNull INDArray Wk, @NonNull INDArray Wv, @NonNull INDArray Wo, + INDArray mask, boolean scaled, boolean withWeights) { + super(wrapFilterNull(queries, keys, values, Wq, Wk, Wv, Wo, mask), null); + this.scaled = scaled; + this.withWeights = withWeights; + addIArgument(scaled ? 1 : 0); + addIArgument(withWeights ? 1 : 0); + } + @Override public String opName() { return "multi_head_dot_product_attention"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java index 07e72b6b3..df41438fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -39,6 +41,10 @@ public class Pow extends DynamicCustomOp { public Pow(){ } + public 
Pow(@NonNull INDArray x, @NonNull INDArray y){ + super(new INDArray[]{x,y}, null); + } + @Override public String opName(){ return "Pow"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java index bfa1c27c1..d8db2569c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -69,8 +70,12 @@ public class SoftMax extends BaseDynamicTransformOp { addIArgument(dimension); } + public SoftMax(@NonNull INDArray input, int dimension){ + this(input, null, dimension); + } + public SoftMax(INDArray input, INDArray result, int dimension){ - super(new INDArray[]{input}, new INDArray[]{result}); + super(new INDArray[]{input}, wrapOrNull(result)); this.dimension = dimension; addIArgument(dimension); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java index 140aef355..467b36a4e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java @@ -34,8 +34,12 @@ public class Standardize extends DynamicCustomOp { setDimensions(dimensions); } + public Standardize(INDArray 
input, int... dimensions){ + this(input, null, dimensions); + } + public Standardize(INDArray input, INDArray result, int... dimensions){ - super("standardize", new INDArray[]{input}, new INDArray[]{result}); + super("standardize", new INDArray[]{input},wrapOrNull(result)); setDimensions(dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java index e43918918..9d61de1c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -37,6 +39,10 @@ public class Trace extends DynamicCustomOp { super(null, sd, new SDVariable[]{in}); } + public Trace(@NonNull INDArray in){ + super(wrapOrNull(in), null); + } + public Trace(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java index 0c10159e3..563c4a7f6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java @@ -46,7 +46,14 @@ public 
class XwPlusB extends DynamicCustomOp { public XwPlusB(SameDiff sameDiff, SDVariable input, SDVariable weights, SDVariable bias) { super(null, sameDiff, new SDVariable[] {input, weights, bias}, false); + } + public XwPlusB(INDArray input, INDArray weights, INDArray bias) { + super(new INDArray[] {input, weights, bias}, null); + } + + public XwPlusB(INDArray[] inputs, INDArray output){ + super(inputs, wrapOrNull(output)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java index baaa87e1f..202f7e291 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SigmoidDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SigmoidDerivative.java index 9c4d478c7..b47d41462 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SigmoidDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SigmoidDerivative.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import 
org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -35,6 +36,10 @@ public class SigmoidDerivative extends DynamicCustomOp { super(sameDiff, new SDVariable[]{i_v1, i_v2}); } + public SigmoidDerivative(@NonNull INDArray x, @NonNull INDArray y) { + this(x, y, null); + } + public SigmoidDerivative(INDArray x, INDArray y, INDArray z) { super(null, new INDArray[]{x,y}, new INDArray[]{z}, null, (int[])null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java index dbbdb8dde..37a1d8632 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java @@ -42,6 +42,10 @@ public class SoftmaxBp extends DynamicCustomOp { addIArgument(dimension); } + public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, Integer dimension){ + this(input, grad, null, dimension); + } + public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Integer dimension){ super(new INDArray[]{input, grad}, wrapOrNull(output)); if(dimension != null) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java index c2d15df19..f64bfe902 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java 
@@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -41,6 +42,10 @@ public class MergeAddOp extends BaseDynamicTransformOp { super(sameDiff, args, inPlace); } + public MergeAddOp(@NonNull INDArray... inputs){ + this(inputs, null); + } + public MergeAddOp(INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java index e67c362fc..5b9faa005 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.List; @@ -47,6 +48,10 @@ public class RandomExponential extends DynamicCustomOp { addTArgument(lambda); } + public RandomExponential(double lambda, DataType datatype, long... 
shape){ + this(Nd4j.createFromArray(shape), Nd4j.createUninitialized(datatype, shape), lambda); + } + public RandomExponential(INDArray shape,INDArray out, double lambda){ super(null, new INDArray[]{shape}, new INDArray[]{out}, Collections.singletonList(lambda), (List)null); this.lambda = lambda; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java index dec04f11f..0ffc8e72e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; import org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.LinkedHashMap; @@ -49,6 +50,10 @@ public class BernoulliDistribution extends BaseRandomOp { super(); } + public BernoulliDistribution(double p, DataType datatype, long... 
shape){ + this(Nd4j.createUninitialized(datatype, shape), p); + } + /** * This op fills Z with bernoulli trial results, so 0, or 1, depending by common probability * @param z diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java index 69a5460f2..41bf909cc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java @@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.LinkedHashMap; @@ -46,6 +47,10 @@ public class BinomialDistribution extends BaseRandomOp { this.extraArgs = new Object[] {(double) this.trials, this.probability}; } + public BinomialDistribution(int trials, double probability, DataType dt, long[] shape){ + this(Nd4j.createUninitialized(dt, shape), trials, probability); + } + public BinomialDistribution() { super(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java index 42f6def0f..0bb41b655 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java @@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import 
org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.LinkedHashMap; @@ -50,6 +51,10 @@ public class GaussianDistribution extends BaseRandomOp { super(); } + public GaussianDistribution(double mean, double stddev, DataType datatype, long... shape){ + this(Nd4j.createUninitialized(datatype, shape), mean, stddev); + } + /** * This op fills Z with random values within stddev..mean..stddev boundaries * @param z diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java index 4a0b36b32..b42e311a7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java @@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.LinkedHashMap; @@ -50,6 +51,10 @@ public class LogNormalDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.mean, this.stddev}; } + public LogNormalDistribution(double mean, double stddev, DataType datatype, long... 
shape){ + this(Nd4j.createUninitialized(datatype, shape), mean, stddev); + } + /** * This op fills Z with random values within stddev..mean..stddev boundaries * @param z diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java index bd453fe0a..24e52a532 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java @@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.List; @@ -48,6 +49,10 @@ public class TruncatedNormalDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.mean, this.stddev}; } + public TruncatedNormalDistribution(double mean, double stddev, DataType datatype, long... 
shape){ + this(Nd4j.createUninitialized(datatype, shape), mean, stddev); + } + /** * This op fills Z with random values within stddev..mean..stddev boundaries * @param z diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java index 2b4adfc1a..408af9ce2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java @@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.List; @@ -46,6 +47,10 @@ public class UniformDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.from, this.to}; } + public UniformDistribution(double min, double max, DataType datatype, long... shape){ + this(Nd4j.createUninitialized(datatype, shape), min, max); + } + /** * This op fills Z with random values within from...to boundaries * @param z diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java new file mode 100644 index 000000000..f60726c36 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java @@ -0,0 +1,236 @@ +/* ***************************************************************************** + * Copyright (c) 2019 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.factory; + +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.util.Arrays; + +public class NDValidation { + + private NDValidation() { + } + + /** + * Validate that the operation is being applied on a numerical INDArray (not boolean or utf8). + * Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays + * + * @param opName Operation name to print in the exception + * @param v Variable to perform operation on + */ + public static void validateNumerical(String opName, INDArray v) { + if (v == null) + return; + if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-numerical data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on numerical INDArrays (not boolean or utf8). 
+ * Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays + * + * @param opName Operation name to print in the exception + * @param v Variable to perform operation on + */ + public static void validateNumerical(String opName, INDArray[] v) { + if (v == null) + return; + for (int i = 0; i < v.length; i++) { + if (v[i].dataType() == DataType.BOOL || v[i].dataType() == DataType.UTF8) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to input array " + i + " with non-numerical data type " + v[i].dataType()); + } + } + + /** + * Validate that the operation is being applied on a numerical INDArray (not boolean or utf8). + * Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays + * + * @param opName Operation name to print in the exception + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateNumerical(String opName, String inputName, INDArray v) { + if (v == null) + return; + if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an numerical type type;" + + " got array with non-integer data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on numerical INDArrays (not boolean or utf8). 
+ * Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays + * + * @param opName Operation name to print in the exception + * @param v Variable to perform operation on + */ + public static void validateNumerical(String opName, String inputName, INDArray[] v) { + if (v == null) + return; + for (int i = 0; i < v.length; i++) { + if (v[i].dataType() == DataType.BOOL || v[i].dataType() == DataType.UTF8) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to input \"" + inputName + "\" array " + i + " with non-numerical data type " + v[i].dataType()); + } + } + + /** + * Validate that the operation is being applied on numerical INDArrays (not boolean or utf8). + * Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays + * + * @param opName Operation name to print in the exception + * @param v1 Variable to validate datatype for (input to operation) + * @param v2 Variable to validate datatype for (input to operation) + */ + public static void validateNumerical(String opName, INDArray v1, INDArray v2) { + if (v1.dataType() == DataType.BOOL || v1.dataType() == DataType.UTF8 || v2.dataType() == DataType.BOOL || v2.dataType() == DataType.UTF8) + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on arrays if one or both variables" + + " are non-numerical: got " + v1.dataType() + " and " + v2.dataType()); + } + + /** + * Validate that the operation is being applied on an integer type INDArray + * + * @param opName Operation name to print in the exception + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateInteger(String opName, INDArray v) { + if (v == null) + return; + if (!v.dataType().isIntType()) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-integer data type " + v.dataType()); + } + + /** + * Validate that the operation is 
being applied on an integer type INDArray + * + * @param opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateInteger(String opName, String inputName, INDArray v) { + if (v == null) + return; + if (!v.dataType().isIntType()) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" + + " type; got array with non-integer data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on an floating point type INDArray + * + * @param opName Operation name to print in the exception + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateFloatingPoint(String opName, INDArray v) { + if (v == null) + return; + if (!v.dataType().isFPType()) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-floating point data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on a floating point type INDArray + * + * @param opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateFloatingPoint(String opName, String inputName, INDArray v) { + if (v == null) + return; + if (!v.dataType().isFPType()) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + + "\" must be an floating point type; got array with non-floating point data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on a boolean type INDArray + * + * @param opName Operation name to print in the exception + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateBool(String opName, INDArray v) { + if (v 
== null) + return; + if (v.dataType() != DataType.BOOL) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-boolean point data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on a boolean type INDArray + * + * @param opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateBool(String opName, String inputName, INDArray v) { + if (v == null) + return; + if (v.dataType() != DataType.BOOL) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + + "\" must be an boolean variable; got array with non-boolean data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on boolean INDArrays + * + * @param opName Operation name to print in the exception + * @param v1 Variable to validate datatype for (input to operation) + * @param v2 Variable to validate datatype for (input to operation) + */ + public static void validateBool(String opName, INDArray v1, INDArray v2) { + if (v1.dataType() != DataType.BOOL || v2.dataType() != DataType.BOOL) + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on array if one or both variables are non-boolean: " + + v1.dataType() + " and " + v2.dataType()); + } + + /** + * Validate that the operation is being applied on array with the exact same datatypes (which may optionally be + * restricted to numerical INDArrays only (not boolean or utf8)) + * + * @param opName Operation name to print in the exception + * @param numericalOnly If true, the variables must all be the same type, and must be numerical (not boolean/utf8) + * @param vars Variable to perform operation on + */ + public static void validateSameType(String opName, boolean numericalOnly, INDArray... 
vars) { + if (vars.length == 0) + return; + if (vars.length == 1) { + if (numericalOnly) { + validateNumerical(opName, vars[0]); + } + } else { + DataType first = vars[0].dataType(); + if (numericalOnly) + validateNumerical(opName, vars[0]); + for (int i = 1; i < vars.length; i++) { + if (first != vars[i].dataType()) { + DataType[] dtypes = new DataType[vars.length]; + for (int j = 0; j < vars.length; j++) { + dtypes[j] = vars[j].dataType(); + } + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" to arrays with different datatypes:" + + " Got arrays with datatypes " + Arrays.toString(dtypes)); + } + } + } + } + + public static boolean isSameType(INDArray x, INDArray y) { + return x.dataType() == y.dataType(); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 5e62dd198..2e2efadda 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -16,6 +16,10 @@ package org.nd4j.linalg.factory; +import org.nd4j.linalg.factory.ops.NDBitwise; +import org.nd4j.linalg.factory.ops.NDMath; +import org.nd4j.linalg.factory.ops.NDNN; +import org.nd4j.linalg.factory.ops.NDRandom; import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.shade.guava.primitives.Longs; import lombok.NonNull; @@ -114,6 +118,51 @@ import java.util.logging.Logger; */ public class Nd4j { + /** + * Bitwise namespace - operations related to bitwise manipulation of arrays + */ + public static final NDBitwise bitwise = new NDBitwise(); + /** + * Math namespace - general mathematical operations + */ + public static final NDMath math = new NDMath(); + /** + * Random namespace - (pseudo) random number generation methods + */ + public static final NDRandom random = new NDRandom(); + 
/** + * Neural network namespace - operations related to neural networks + */ + public static final NDNN nn = new NDNN(); + + /** + * Bitwise namespace - operations related to bitwise manipulation of arrays + */ + public static NDBitwise bitwise() { + return bitwise; + } + + /** + * Math namespace - general mathematical operations + */ + public static NDMath math() { + return math; + } + + /** + * Random namespace - (pseudo) random number generation methods + */ + public static NDRandom random() { + return random; + } + + /** + * Neural network namespace - operations related to neural networks + */ + public static NDNN nn() { + return nn; + } + private final static String DATA_BUFFER_OPS = "databufferfactory"; private final static String CONVOLUTION_OPS = "convops"; /**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/ @@ -2638,7 +2687,7 @@ public class Nd4j { INDArray ret; if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) { ret = Nd4j.create(x.dataType(), x.length(), x.length()); - Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret})); + Nd4j.getExecutioner().execAndReturn(new Diag(x, ret)); } else { ret = Nd4j.createUninitialized(x.dataType(), Math.min(x.size(0), x.size(1))); Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java new file mode 100644 index 000000000..f77d5c823 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java @@ -0,0 +1,211 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; + +public class NDBitwise { + public NDBitwise() { + } + + /** + * Bitwise AND operation. Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param x First input array (INT type) + * @param y Second input array (INT type) + * @return output Bitwise AND array (INT type) + */ + public INDArray and(INDArray x, INDArray y) { + NDValidation.validateInteger("and", "x", x); + NDValidation.validateInteger("and", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(x, y))[0]; + } + + /** + * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public INDArray bitRotl(INDArray x, INDArray shift) { + NDValidation.validateInteger("bitRotl", "x", x); + NDValidation.validateInteger("bitRotl", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(x, shift))[0]; + } + + /** + * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public INDArray bitRotr(INDArray x, INDArray shift) { + NDValidation.validateInteger("bitRotr", "x", x); + NDValidation.validateInteger("bitRotr", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(x, shift))[0]; + } + + /** + * Shift integer bits to the left, i.e. var << 4
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public INDArray bitShift(INDArray x, INDArray shift) { + NDValidation.validateInteger("bitShift", "x", x); + NDValidation.validateInteger("bitShift", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(x, shift))[0]; + } + + /** + * Shift integer bits to the right, i.e. var >> 4
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public INDArray bitShiftRight(INDArray x, INDArray shift) { + NDValidation.validateInteger("bitShiftRight", "x", x); + NDValidation.validateInteger("bitShiftRight", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(x, shift))[0]; + } + + /** + * Bitwise Hamming distance reduction over all elements of both input arrays.
+ * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * + * @param x First input array. (INT type) + * @param y Second input array. (INT type) + * @return output bitwise Hamming distance (INT type) + */ + public INDArray bitsHammingDistance(INDArray x, INDArray y) { + NDValidation.validateInteger("bitsHammingDistance", "x", x); + NDValidation.validateInteger("bitsHammingDistance", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(x, y))[0]; + } + + /** + * Bitwise left shift operation. Supports broadcasting.
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise shifted input x (INT type) + */ + public INDArray leftShift(INDArray x, INDArray y) { + NDValidation.validateInteger("leftShift", "x", x); + NDValidation.validateInteger("leftShift", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(x, y))[0]; + } + + /** + * Bitwise left cyclical shift operation. Supports broadcasting.
+ * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":
+ * {@code leftShiftCyclic(01110000, 2) -> 11000001}
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise cyclic shifted input x (INT type) + */ + public INDArray leftShiftCyclic(INDArray x, INDArray y) { + NDValidation.validateInteger("leftShiftCyclic", "x", x); + NDValidation.validateInteger("leftShiftCyclic", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(x, y))[0]; + } + + /** + * Bitwise OR operation. Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param x First input array (INT type) + * @param y First input array (INT type) + * @return output Bitwise OR array (INT type) + */ + public INDArray or(INDArray x, INDArray y) { + NDValidation.validateInteger("or", "x", x); + NDValidation.validateInteger("or", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(x, y))[0]; + } + + /** + * Bitwise right shift operation. Supports broadcasting.
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise shifted input x (INT type) + */ + public INDArray rightShift(INDArray x, INDArray y) { + NDValidation.validateInteger("rightShift", "x", x); + NDValidation.validateInteger("rightShift", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(x, y))[0]; + } + + /** + * Bitwise right cyclical shift operation. Supports broadcasting.
+ * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":
+ * {@code rightShiftCyclic(00001110, 2) -> 10000011}
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise cyclic shifted input x (INT type) + */ + public INDArray rightShiftCyclic(INDArray x, INDArray y) { + NDValidation.validateInteger("rightShiftCyclic", "x", x); + NDValidation.validateInteger("rightShiftCyclic", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(x, y))[0]; + } + + /** + * Bitwise XOR operation (exclusive OR). Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param x First input array (INT type) + * @param y First input array (INT type) + * @return output Bitwise XOR array (INT type) + */ + public INDArray xor(INDArray x, INDArray y) { + NDValidation.validateInteger("xor", "x", x); + NDValidation.validateInteger("xor", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(x, y))[0]; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java new file mode 100644 index 000000000..8e194fcd4 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -0,0 +1,1324 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.conditions.Condition; + +public class NDMath { + public NDMath() { + } + + /** + * Elementwise absolute value operation: out = abs(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray abs(INDArray x) { + NDValidation.validateNumerical("abs", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(x)); + } + + /** + * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray acos(INDArray x) { + NDValidation.validateNumerical("acos", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(x)); + } + + /** + * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray acosh(INDArray x) { + NDValidation.validateNumerical("acosh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(x)); + } + + /** + * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray amax(INDArray in, int... dimensions) { + NDValidation.validateNumerical("amax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(in, dimensions)); + } + + /** + * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray amean(INDArray in, int... dimensions) { + NDValidation.validateNumerical("amean", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(in, dimensions)); + } + + /** + * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray amin(INDArray in, int... dimensions) { + NDValidation.validateNumerical("amin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(in, dimensions)); + } + + /** + * Boolean AND operation: elementwise (x != 0) && (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public INDArray and(INDArray x, INDArray y) { + NDValidation.validateBool("and", "x", x); + NDValidation.validateBool("and", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(x, y)); + } + + /** + * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray asin(INDArray x) { + NDValidation.validateNumerical("asin", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(x)); + } + + /** + * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray asinh(INDArray x) { + NDValidation.validateNumerical("asinh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(x)); + } + + /** + * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray asum(INDArray in, int... dimensions) { + NDValidation.validateNumerical("asum", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(in, dimensions)); + } + + /** + * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray atan(INDArray x) { + NDValidation.validateNumerical("atan", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(x)); + } + + /** + * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
+ * Similar to atan(y/x) but signs of x and y are used to determine the quadrant of the result
+ * + * @param y Input Y variable (NUMERIC type) + * @param x Input X variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray atan2(INDArray y, INDArray x) { + NDValidation.validateNumerical("atan2", "y", y); + NDValidation.validateNumerical("atan2", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(y, x))[0]; + } + + /** + * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray atanh(INDArray x) { + NDValidation.validateNumerical("atanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(x)); + } + + /** + * Element-wise ceiling function: out = ceil(x).
+ * Rounds each value up to the nearest integer value (if not already an integer)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray ceil(INDArray x) { + NDValidation.validateNumerical("ceil", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(x)); + } + + /** + * Clipping by L2 norm, optionally along dimension(s)
+ * if l2Norm(x,dimension) < clipValue, then input is returned unmodified
+ * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according
+ * to the corresponding l2Norm along the specified dimensions
+ * + * @param x Input variable (NUMERIC type) + * @param clipValue Clipping value (maximum l2 norm) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray clipByNorm(INDArray x, double clipValue, int... dimensions) { + NDValidation.validateNumerical("clipByNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(x, clipValue, dimensions))[0]; + } + + /** + * Element-wise clipping function:
+ * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
+ * out[i] = clipValueMin if in[i] < clipValueMin
+ * out[i] = clipValueMax if in[i] > clipValueMax
+ * + * @param x Input variable (NUMERIC type) + * @param clipValueMin Minimum value for clipping + * @param clipValueMax Maximum value for clipping + * @return output Output variable (NUMERIC type) + */ + public INDArray clipByValue(INDArray x, double clipValueMin, double clipValueMax) { + NDValidation.validateNumerical("clipByValue", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(x, clipValueMin, clipValueMax))[0]; + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
+ * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
+ * [1, 0, 0]
+ * [0, 1, 1]
+ * [0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param dataType Data type + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public INDArray confusionMatrix(INDArray labels, INDArray pred, DataType dataType) { + NDValidation.validateNumerical("confusionMatrix", "labels", labels); + NDValidation.validateNumerical("confusionMatrix", "pred", pred); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(labels, pred, dataType))[0]; + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values.
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
+ * [1, 0, 0, 0]
+ * [0, 1, 1, 0]
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param numClasses Number of classes + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public INDArray confusionMatrix(INDArray labels, INDArray pred, int numClasses) { + NDValidation.validateNumerical("confusionMatrix", "labels", labels); + NDValidation.validateNumerical("confusionMatrix", "pred", pred); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(labels, pred, numClasses))[0]; + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3]
+ * [1, 0, 0]
+ * [0, 3, 2]
+ * [0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public INDArray confusionMatrix(INDArray labels, INDArray pred, INDArray weights) { + NDValidation.validateNumerical("confusionMatrix", "labels", labels); + NDValidation.validateNumerical("confusionMatrix", "pred", pred); + NDValidation.validateNumerical("confusionMatrix", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(labels, pred, weights))[0]; + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values.
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
+ * [1, 0, 0, 0]
+ * [0, 3, 2, 0]
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @param numClasses + * @return output Output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public INDArray confusionMatrix(INDArray labels, INDArray pred, INDArray weights, + int numClasses) { + NDValidation.validateNumerical("confusionMatrix", "labels", labels); + NDValidation.validateNumerical("confusionMatrix", "pred", pred); + NDValidation.validateNumerical("confusionMatrix", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(labels, pred, weights, numClasses))[0]; + } + + /** + * Elementwise cosine operation: out = cos(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray cos(INDArray x) { + NDValidation.validateNumerical("cos", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(x)); + } + + /** + * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray cosh(INDArray x) { + NDValidation.validateNumerical("cosh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(x)); + } + + /** + * Cosine distance reduction operation. The output contains the cosine distance for each
+ * tensor/subset along the specified dimensions:
+ * out = 1.0 - cosineSimilarity(x,y)
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray cosineDistance(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("cosineDistance", "x", x); + NDValidation.validateNumerical("cosineDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(x, y, dimensions)); + } + + /** + * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset
+ * along the specified dimensions:
+ * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2) )
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray cosineSimilarity(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("cosineSimilarity", "x", x); + NDValidation.validateNumerical("cosineSimilarity", "y", y); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(x, y, dimensions)); + } + + /** + * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray countNonZero(INDArray in, int... dimensions) { + NDValidation.validateNumerical("countNonZero", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(in, dimensions)); + } + + /** + * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray countZero(INDArray in, int... dimensions) { + NDValidation.validateNumerical("countZero", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(in, dimensions)); + } + + /** + * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta).
+ * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3
+ * + * @param a First input (NUMERIC type) + * @param b Second input (NUMERIC type) + * @return output Element-wise cross product (NUMERIC type) + */ + public INDArray cross(INDArray a, INDArray b) { + NDValidation.validateNumerical("cross", "a", a); + NDValidation.validateNumerical("cross", "b", b); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Cross(a, b))[0]; + } + + /** + * Element-wise cube function: out = x^3
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray cube(INDArray x) { + NDValidation.validateNumerical("cube", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(x)); + } + + /** + * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
+ * For example, if input = [1,2,3], then output is given by:
+ * [ 1, 0, 0]
+ * [ 0, 2, 0]
+ * [ 0, 0, 3]
+ *
+ * Higher input ranks are also supported: if input has shape [a,...,R-1] then output[i,...,k,i,...,k] = input[i,...,k].
+ * i.e., for input rank R, output has rank 2R
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray diag(INDArray x) { + NDValidation.validateNumerical("diag", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Diag(x))[0]; + } + + /** + * Extract the diagonal part from the input array.
+ * If input is
+ * [ 1, 0, 0]
+ * [ 0, 2, 0]
+ * [ 0, 0, 3]
+ * then output is [1, 2, 3].
+ * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k]
+ * + * @param x Input variable (NUMERIC type) + * @return output Diagonal part of the input (NUMERIC type) + */ + public INDArray diagPart(INDArray x) { + NDValidation.validateNumerical("diagPart", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.DiagPart(x))[0]; + } + + /** + * Entropy reduction: -sum(x * log(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray entropy(INDArray in, int... dimensions) { + NDValidation.validateNumerical("entropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(in, dimensions)); + } + + /** + * Element-wise Gaussian error function - out = erf(in)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray erf(INDArray x) { + NDValidation.validateNumerical("erf", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(x)); + } + + /** + * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray erfc(INDArray x) { + NDValidation.validateNumerical("erfc", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(x)); + } + + /** + * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each
+ * tensor/subset along the specified dimensions:
+ * out = sqrt( sum_i (x[i] - y[i])^2 )
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray euclideanDistance(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("euclideanDistance", "x", x); + NDValidation.validateNumerical("euclideanDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(x, y, dimensions)); + } + + /** + * Elementwise exponent function: out = exp(x) = 2.71828...^x
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray exp(INDArray x) { + NDValidation.validateNumerical("exp", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(x)); + } + + /** + * Elementwise exponent minus one function: out = exp(x) - 1.0 = 2.71828...^x - 1.0
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray expm1(INDArray x) { + NDValidation.validateNumerical("expm1", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(x)); + } + + /** + * Generate an identity matrix with the specified number of rows and columns.
+ * + * @param rows Number of rows + * @return output Identity matrix (NUMERIC type) + */ + public INDArray eye(int rows) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows))[0]; + } + + /** + * As per eye(String, int, int, DataType) but with the default datatype, Eye.DEFAULT_DTYPE
+ * + * @param rows Number of rows + * @param cols Number of columns + * @return output (NUMERIC type) + */ + public INDArray eye(int rows, int cols) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows, cols))[0]; + } + + /** + * Generate an identity matrix with the specified number of rows and columns
+ * Example:
+ *

+ * {@code INDArray eye = eye(3,2)
+ * eye:
+ * [ 1, 0]
+ * [ 0, 1]
+ * [ 0, 0]}
+ *

+ * + * @param rows Number of rows + * @param cols Number of columns + * @param dataType Data type + * @return output Identity matrix (NUMERIC type) + */ + public INDArray eye(int rows, int cols, DataType dataType) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows, cols, dataType))[0]; + } + + /** + * As per eye(int, int) but with the number of rows/columns specified as scalar INDArrays
+ * + * @param rows Number of rows (INT type) + * @param cols Number of columns (INT type) + * @return output Identity matrix (NUMERIC type) + */ + public INDArray eye(INDArray rows, INDArray cols) { + NDValidation.validateInteger("eye", "rows", rows); + NDValidation.validateInteger("eye", "cols", cols); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows, cols))[0]; + } + + /** + * As per eye(String, int) but with the number of rows specified as a scalar INDArray
+ * + * @param rows Number of rows (INT type) + * @return output Identity matrix (NUMERIC type) + */ + public INDArray eye(INDArray rows) { + NDValidation.validateInteger("eye", "rows", rows); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows))[0]; + } + + /** + * First index reduction operation.
+ * Returns a variable that contains the index of the first element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray firstIndex(INDArray in, Condition condition, int... dimensions) { + NDValidation.validateNumerical("firstIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(in, condition, dimensions)); + } + + /** + * First index reduction operation.
+ * Returns a variable that contains the index of the first element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray firstIndex(INDArray in, Condition condition, boolean keepDims, + int... dimensions) { + NDValidation.validateNumerical("firstIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(in, condition, keepDims, dimensions)); + } + + /** + * Element-wise floor function: out = floor(x).
+ * Rounds each value down to the nearest integer value (if not already an integer)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray floor(INDArray x) { + NDValidation.validateNumerical("floor", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(x)); + } + + /** + * Hamming distance reduction operation. The output contains the hamming distance for each
+ * tensor/subset along the specified dimensions:
+ * out = count( x[i] != y[i] )
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray hammingDistance(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("hammingDistance", "x", x); + NDValidation.validateNumerical("hammingDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(x, y, dimensions)); + } + + /** + * Index of the max absolute value: argmax(abs(in))
+ * see argmax(String, INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray iamax(INDArray in, int... dimensions) { + NDValidation.validateNumerical("iamax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, dimensions)); + } + + /** + * Index of the max absolute value: argmax(abs(in))
+ * see argmax(String, INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray iamax(INDArray in, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("iamax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, keepDims, dimensions)); + } + + /** + * Index of the min absolute value: argmin(abs(in))
+ * see argmin(String, INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray iamin(INDArray in, int... dimensions) { + NDValidation.validateNumerical("iamin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, dimensions)); + } + + /** + * Index of the min absolute value: argmin(abs(in))
+ * see argmin(String, INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray iamin(INDArray in, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("iamin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, keepDims, dimensions)); + } + + /** + * Is finite operation: elementwise isFinite(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray isFinite(INDArray x) { + NDValidation.validateNumerical("isFinite", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(x)); + } + + /** + * Is infinite operation: elementwise isInfinite(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray isInfinite(INDArray x) { + NDValidation.validateNumerical("isInfinite", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(x)); + } + + /** + * Is maximum operation: elementwise x == max(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray isMax(INDArray x) { + NDValidation.validateNumerical("isMax", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(x))[0]; + } + + /** + * Is Not a Number operation: elementwise isNaN(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray isNaN(INDArray x) { + NDValidation.validateNumerical("isNaN", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(x)); + } + + /** + * Is the array non decreasing?
+ * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared
+ * in 'c' (row major) order
+ * + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if non-decreasing, or 0 otherwise (NUMERIC type) + */ + public INDArray isNonDecreasing(INDArray x) { + NDValidation.validateNumerical("isNonDecreasing", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(x))[0]; + } + + /** + * Is the array strictly increasing?
+ * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared
+ * in 'c' (row major) order
+ * + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC type) + */ + public INDArray isStrictlyIncreasing(INDArray x) { + NDValidation.validateNumerical("isStrictlyIncreasing", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(x))[0]; + } + + /** + * Jaccard similarity reduction operation. The output contains the Jaccard distance for each
+ * tensor along the specified dimensions.
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray jaccardDistance(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("jaccardDistance", "x", x); + NDValidation.validateNumerical("jaccardDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(x, y, dimensions)); + } + + /** + * Last index reduction operation.
+ * Returns a variable that contains the index of the last element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray lastIndex(INDArray in, Condition condition, int... dimensions) { + NDValidation.validateNumerical("lastIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, condition, dimensions)); + } + + /** + * Last index reduction operation.
+ * Returns a variable that contains the index of the last element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray lastIndex(INDArray in, Condition condition, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("lastIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, condition, keepDims, dimensions)); + } + + /** + * Element-wise logarithm function (base e - natural logarithm): out = log(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray log(INDArray x) { + NDValidation.validateNumerical("log", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(x)); + } + + /** + * Element-wise logarithm function (with specified base): out = log_{base}(x)
+ * + * @param x Input variable (NUMERIC type) + * @param base Logarithm base (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray log(INDArray x, INDArray base) { + NDValidation.validateNumerical("log", "x", x); + NDValidation.validateNumerical("log", "base", base); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(x, base)); + } + + /** + * Elementwise natural logarithm function: out = log_e (1 + x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray log1p(INDArray x) { + NDValidation.validateNumerical("log1p", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(x)); + } + + /** + * Log entropy reduction: log(-sum(x * log(x)))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray logEntropy(INDArray in, int... dimensions) { + NDValidation.validateNumerical("logEntropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(in, dimensions)); + } + + /** + * Log-sum-exp reduction (optionally along dimension).
+ * Computes log(sum(exp(x)))<br>
+ * + * @param input Input variable (NUMERIC type) + * @param dimensions Optional dimensions to reduce along (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray logSumExp(INDArray input, int... dimensions) { + NDValidation.validateNumerical("logSumExp", "input", input); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(input, dimensions))[0]; + } + + /** + * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each
+ * tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i]-y[i])
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray manhattanDistance(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("manhattanDistance", "x", x); + NDValidation.validateNumerical("manhattanDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(x, y, dimensions)); + } + + /** + * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
+ * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each
+ * shape [m,m] sub-matrix.
+ * + * @param in Input (NUMERIC type) + * @return output Matrix determinant variable (NUMERIC type) + */ + public INDArray matrixDeterminant(INDArray in) { + NDValidation.validateNumerical("matrixDeterminant", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(in))[0]; + } + + /** + * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
+ * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each
+ * shape [m,m] sub-matrix.
+ * + * @param in Input (NUMERIC type) + * @return output Matrix inverse variable (NUMERIC type) + */ + public INDArray matrixInverse(INDArray in) { + NDValidation.validateNumerical("matrixInverse", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(in))[0]; + } + + /** + * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
+ * out = sum_i in[i]
+ * + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray mergeAdd(INDArray[] inputs) { + NDValidation.validateNumerical("mergeAdd", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(inputs))[0]; + } + + /** + * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation:
+ * out = mean_i in[i]
+ * + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray mergeAvg(INDArray[] inputs) { + NDValidation.validateNumerical("mergeAvg", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(inputs))[0]; + } + + /** + * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation:
+ * out = max_i in[i]
+ * + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray mergeMax(INDArray[] inputs) { + NDValidation.validateNumerical("mergeMax", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeMax(inputs))[0]; + } + + /** + * Calculate the mean and (population) variance for the input variable, for the specified axis
+ * + * @param input Input to calculate moments for (NUMERIC type) + * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) + * @return output Mean and variance variables (NUMERIC type) + */ + public INDArray moments(INDArray input, int... axes) { + NDValidation.validateNumerical("moments", "input", input); + Preconditions.checkArgument(axes.length >= 0, "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Moments(input, axes))[0]; + } + + /** + * Elementwise negative operation: out = -x
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray neg(INDArray x) { + NDValidation.validateNumerical("neg", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(x)); + } + + /** + * Calculate the mean and variance from the sufficient statistics
+ * + * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics (NUMERIC type) + * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC type) + * @param variances Variaance sufficient statistics: this is the squared sum of all data values (NUMERIC type) + * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) + * @return output Output variables: mean and population variance (NUMERIC type) + */ + public INDArray normalizeMoments(INDArray counts, INDArray means, INDArray variances, + double shift) { + NDValidation.validateNumerical("normalizeMoments", "counts", counts); + NDValidation.validateNumerical("normalizeMoments", "means", means); + NDValidation.validateNumerical("normalizeMoments", "variances", variances); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(counts, means, variances, shift))[0]; + } + + /** + * Boolean OR operation: elementwise (x != 0) || (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public INDArray or(INDArray x, INDArray y) { + NDValidation.validateBool("or", "x", x); + NDValidation.validateBool("or", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(x, y)); + } + + /** + * Element-wise power function: out = x^value
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray pow(INDArray x, double value) { + NDValidation.validateNumerical("pow", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.Pow(x, value)); + } + + /** + * Element-wise (broadcastable) power function: out = x[i]^y[i]
+ * + * @param x Input variable (NUMERIC type) + * @param y Power (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray pow(INDArray x, INDArray y) { + NDValidation.validateNumerical("pow", "x", x); + NDValidation.validateNumerical("pow", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(x, y))[0]; + } + + /** + * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray reciprocal(INDArray x) { + NDValidation.validateNumerical("reciprocal", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(x)); + } + + /** + * Element-wise round function: out = round(x).
+ * Rounds (up or down depending on value) to the nearest integer value.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray round(INDArray x) { + NDValidation.validateNumerical("round", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Round(x)); + } + + /** + * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rsqrt(INDArray x) { + NDValidation.validateNumerical("rsqrt", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(x)); + } + + /** + * Set the diagonal value to the specified values
+ * If input is
+ * [ a, b, c]
+ * [ d, e, f]
+ * [ g, h, i]
+ * and diag = [ 1, 2, 3] then output is
+ * [ 1, b, c]
+ * [ d, 2, f]
+ * [ g, h, 3]
+ * + * @param in Input variable (NUMERIC type) + * @param diag Diagonal (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray setDiag(INDArray in, INDArray diag) { + NDValidation.validateNumerical("setDiag", "in", in); + NDValidation.validateNumerical("setDiag", "diag", diag); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(in, diag))[0]; + } + + /** + * Shannon Entropy reduction: -sum(x * log2(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray shannonEntropy(INDArray in, int... dimensions) { + NDValidation.validateNumerical("shannonEntropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(in, dimensions)); + } + + /** + * Element-wise sign (signum) function:
+ * out = -1 if in < 0
+ * out = 0 if in = 0
+ * out = 1 if in > 0
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray sign(INDArray x) { + NDValidation.validateNumerical("sign", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(x)); + } + + /** + * Elementwise sine operation: out = sin(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray sin(INDArray x) { + NDValidation.validateNumerical("sin", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(x)); + } + + /** + * Elementwise sinh (hyperbolic sine) operation: out = sinh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray sinh(INDArray x) { + NDValidation.validateNumerical("sinh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(x)); + } + + /** + * Element-wise square root function: out = sqrt(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray sqrt(INDArray x) { + NDValidation.validateNumerical("sqrt", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(x)); + } + + /** + * Element-wise square function: out = x^2
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray square(INDArray x) { + NDValidation.validateNumerical("square", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Square(x)); + } + + /** + * Standardize input variable along given axis
+ *


+ * out = (x - mean) / stdev
+ *


+ * with mean and stdev being calculated along the given dimension.
+ *


+ * For example: given x as a mini batch of the shape [numExamples, exampleLength]:
+ *


    + *
  • use dimension 1 to use the statistics (mean, stdev) for each example

  • + *
  • use dimension 0 if you want to use the statistics for each column across all examples

  • + *
  • use dimensions 0,1 if you want to use the statistics across all columns and examples

  • + *

+ * + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray standardize(INDArray x, int... dimensions) { + NDValidation.validateNumerical("standardize", "x", x); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(x, dimensions))[0]; + } + + /** + * Elementwise step function:
+ * out(x) = 1 if x >= cutoff
+ * out(x) = 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray step(INDArray x, double value) { + NDValidation.validateNumerical("step", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.Step(x, value)); + } + + /** + * Elementwise tangent operation: out = tan(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray tan(INDArray x) { + NDValidation.validateNumerical("tan", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(x)); + } + + /** + * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray tanh(INDArray x) { + NDValidation.validateNumerical("tanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(x)); + } + + /** + * Matrix trace operation
+ * For rank 2 matrices, the output is a scalar with the trace - i.e., sum of the main diagonal.<br>
+ * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
+ * + * @param in Input variable (NUMERIC type) + * @return output Trace (NUMERIC type) + */ + public INDArray trace(INDArray in) { + NDValidation.validateNumerical("trace", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(in))[0]; + } + + /** + * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public INDArray xor(INDArray x, INDArray y) { + NDValidation.validateBool("xor", "x", x); + NDValidation.validateBool("xor", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(x, y)); + } + + /** + * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
+ * + * @param input Input variable (NUMERIC type) + * @return output Reduced array of rank 0 (scalar) (NUMERIC type) + */ + public INDArray zeroFraction(INDArray input) { + NDValidation.validateNumerical("zeroFraction", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(input))[0]; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java new file mode 100644 index 000000000..815f22e5b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java @@ -0,0 +1,522 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; + +public class NDNN { + public NDNN() { + } + + /** + * Neural network batch normalization operation.
+ * For details, see https://arxiv.org/abs/1502.03167
+ * + * @param input Input variable. (NUMERIC type) + * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param variance Variance value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) + * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations. + * For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC + * For 1d/RNN activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) + * @return output variable for batch normalization (NUMERIC type) + */ + public INDArray batchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma, + INDArray beta, double epsilon, int... axis) { + NDValidation.validateNumerical("batchNorm", "input", input); + NDValidation.validateNumerical("batchNorm", "mean", mean); + NDValidation.validateNumerical("batchNorm", "variance", variance); + NDValidation.validateNumerical("batchNorm", "gamma", gamma); + NDValidation.validateNumerical("batchNorm", "beta", beta); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(input, mean, variance, gamma, beta, epsilon, axis))[0]; + } + + /** + * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector
+ * + * @param input 4d input variable (NUMERIC type) + * @param bias 1d bias (NUMERIC type) + * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. + * Unused for 2d inputs + * @return output Output variable, after applying bias add operation (NUMERIC type) + */ + public INDArray biasAdd(INDArray input, INDArray bias, boolean nchw) { + NDValidation.validateNumerical("biasAdd", "input", input); + NDValidation.validateNumerical("biasAdd", "bias", bias); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(input, bias, nchw))[0]; + } + + /** + * This operation performs dot product attention on the given timeseries input with the given queries
+ * out = sum(similarity(k_i, q) * v_i)
+ *
+ * similarity(k, q) = softmax(k * q) where k * q is the dot product of k and q<br>
+ *
+ * Optionally with normalization step:
+ * similarity(k, q) = softmax(k * q / sqrt(size(q)))<br>
+ *
+ * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
+ *
+ * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
+ * be 3D but can have queryCount = 1
+ *
+ * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for
+ * both.
+ *
+ * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
+ * output rank will depend on the input rank.
+ * + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] + * or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] + * or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] + * or 4D array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount], + * (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + */ + public INDArray dotProductAttention(INDArray queries, INDArray keys, INDArray values, + INDArray mask, boolean scaled) { + NDValidation.validateNumerical("dotProductAttention", "queries", queries); + NDValidation.validateNumerical("dotProductAttention", "keys", keys); + NDValidation.validateNumerical("dotProductAttention", "values", values); + NDValidation.validateNumerical("dotProductAttention", "mask", mask); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(queries, keys, values, mask, scaled))[0]; + } + + /** + * Dropout operation
+ * + * @param input Input array (NUMERIC type) + * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) + * @return output Output (NUMERIC type) + */ + public INDArray dropout(INDArray input, double inputRetainProbability) { + NDValidation.validateNumerical("dropout", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.DropOut(input, inputRetainProbability)); + } + + /** + * Element-wise exponential linear unit (ELU) function:
+ * out = x if x > 0
+ * out = a * (exp(x) - 1) if x <= 0
+ * with constant a = 1.0
+ *


+ * See: https://arxiv.org/abs/1511.07289
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray elu(INDArray x) { + NDValidation.validateNumerical("elu", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(x))[0]; + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the sigmoid approximation
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray gelu(INDArray x) { + NDValidation.validateNumerical("gelu", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(x)); + } + + /** + * Element-wise hard sigmoid function:
+ * out[i] = 0 if in[i] <= -2.5
+ * out[i] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
+ * out[i] = 1 if in[i] >= 2.5
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray hardSigmoid(INDArray x) { + NDValidation.validateNumerical("hardSigmoid", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(x)); + } + + /** + * Element-wise hard tanh function:
+ * out[i] = -1 if in[i] <= -1
+ * out[i] = in[i] if -1 < in[i] < 1
+ * out[i] = 1 if in[i] >= 1
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray hardTanh(INDArray x) { + NDValidation.validateNumerical("hardTanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(x)); + } + + /** + * Derivative (dOut/dIn) of the element-wise hard Tanh function - hardTanh(INDArray)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray hardTanhDerivative(INDArray x) { + NDValidation.validateNumerical("hardTanhDerivative", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(x)); + } + + /** + * Apply Layer Normalization
+ *
+ * y = gain * standardize(x) + bias
+ * + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param bias Bias (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray layerNorm(INDArray input, INDArray gain, INDArray bias, boolean channelsFirst, + int... dimensions) { + NDValidation.validateNumerical("layerNorm", "input", input); + NDValidation.validateNumerical("layerNorm", "gain", gain); + NDValidation.validateNumerical("layerNorm", "bias", bias); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, bias, channelsFirst, dimensions))[0]; + } + + /** + * Apply Layer Normalization
+ *
+ * y = gain * standardize(x) + bias
+ * + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray layerNorm(INDArray input, INDArray gain, boolean channelsFirst, + int... dimensions) { + NDValidation.validateNumerical("layerNorm", "input", input); + NDValidation.validateNumerical("layerNorm", "gain", gain); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, channelsFirst, dimensions))[0]; + } + + /** + * Element-wise leaky ReLU function:
+ * out = x if x >= 0.0
+ * out = alpha * x if x < 0.0
+ * Alpha value is most commonly set to 0.01
+ * + * @param x Input variable (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray leakyRelu(INDArray x, INDArray alpha) { + NDValidation.validateNumerical("leakyRelu", "x", x); + NDValidation.validateNumerical("leakyRelu", "alpha", alpha); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(x, alpha)); + } + + /** + * Leaky ReLU derivative: dOut/dIn given input.
+ * + * @param x Input variable (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray leakyReluDerivative(INDArray x, INDArray alpha) { + NDValidation.validateNumerical("leakyReluDerivative", "x", x); + NDValidation.validateNumerical("leakyReluDerivative", "alpha", alpha); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(x, alpha)); + } + + /** + * Linear layer operation: out = mmul(in,w) + bias
+ * Note that bias array is optional
+ * + * @param input Input data (NUMERIC type) + * @param weights Weights variable, shape [nIn, nOut] (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray linear(INDArray input, INDArray weights, INDArray bias) { + NDValidation.validateNumerical("linear", "input", input); + NDValidation.validateNumerical("linear", "weights", weights); + NDValidation.validateNumerical("linear", "bias", bias); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(input, weights, bias))[0]; + } + + /** + * Element-wise sigmoid function: out[i] = log(sigmoid(in[i]))
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray logSigmoid(INDArray x) { + NDValidation.validateNumerical("logSigmoid", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(x)); + } + + /** + * Log softmax activation
+ * + * @param x (NUMERIC type) + * @return output (NUMERIC type) + */ + public INDArray logSoftmax(INDArray x) { + NDValidation.validateNumerical("logSoftmax", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x))[0]; + } + + /** + * Log softmax activation
+ * + * @param x Input (NUMERIC type) + * @param dimension Dimension along which to apply log softmax + * @return output Output - log(softmax(input)) (NUMERIC type) + */ + public INDArray logSoftmax(INDArray x, int dimension) { + NDValidation.validateNumerical("logSoftmax", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x, dimension))[0]; + } + + /** + * This performs multi-headed dot product attention on the given timeseries input
+ * out = concat(head_1, head_2, ..., head_n) * Wo
+ * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v)
+ *
+ * Optionally with normalization when calculating the attention for each head.
+ *
+ * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention")
+ *
+ * This makes use of dot_product_attention OP support for rank 4 inputs.
+ * see dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean)
+ * + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC type) + * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) + * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) + * @param Wv input value projection weights of shape [numHeads, projectedValues, featureValues] (NUMERIC type) + * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, outSize, queryCount] + * (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + */ + public INDArray multiHeadDotProductAttention(INDArray queries, INDArray keys, INDArray values, + INDArray Wq, INDArray Wk, INDArray Wv, INDArray Wo, INDArray mask, boolean scaled) { + NDValidation.validateNumerical("multiHeadDotProductAttention", "queries", queries); + NDValidation.validateNumerical("multiHeadDotProductAttention", "keys", keys); + NDValidation.validateNumerical("multiHeadDotProductAttention", "values", values); + NDValidation.validateNumerical("multiHeadDotProductAttention", "Wq", Wq); + NDValidation.validateNumerical("multiHeadDotProductAttention", "Wk", Wk); + NDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv); + NDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo); + NDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask); + return Nd4j.exec(new 
org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled))[0]; + } + + /** + * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
+ * out[i] = in[i] if in[i] >= 0
+ * out[i] = in[i] * alpha[i] otherwise
+ *
+ * sharedAxes allows you to share learnable parameters along axes.
+ * For example, if the input has shape [batchSize, channels, height, width]
+ * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
+ * alpha with shape [channels].
+ * + * @param input Input data (NUMERIC type) + * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. (NUMERIC type) + * @param sharedAxes Which axes to share cutoff parameters along. (Size: AtLeast(min=1)) + * @return output Output (NUMERIC type) + */ + public INDArray prelu(INDArray input, INDArray alpha, int... sharedAxes) { + NDValidation.validateNumerical("prelu", "input", input); + NDValidation.validateNumerical("prelu", "alpha", alpha); + Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.PRelu(input, alpha, sharedAxes))[0]; + } + + /** + * Element-wise rectified linear function with specified cutoff:
+ * out[i] = in[i] if in[i] >= cutoff
+ * out[i] = 0 otherwise
+ * + * @param x Input (NUMERIC type) + * @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0 + * @return output Output (NUMERIC type) + */ + public INDArray relu(INDArray x, double cutoff) { + NDValidation.validateNumerical("relu", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(x, cutoff)); + } + + /** + * Element-wise "rectified linear 6" function with specified cutoff:
+ * out[i] = min(max(in, cutoff), 6)
+ * + * @param x Input (NUMERIC type) + * @param cutoff Cutoff value for ReLU operation. Usually 0 + * @return output Output (NUMERIC type) + */ + public INDArray relu6(INDArray x, double cutoff) { + NDValidation.validateNumerical("relu6", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.Relu6(x, cutoff)); + } + + /** + * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
+ * Note that bias array is optional
+ * + * @param input Input data (NUMERIC type) + * @param weights Weights variable (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray reluLayer(INDArray input, INDArray weights, INDArray bias) { + NDValidation.validateNumerical("reluLayer", "input", input); + NDValidation.validateNumerical("reluLayer", "weights", weights); + NDValidation.validateNumerical("reluLayer", "bias", bias); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(input, weights, bias))[0]; + } + + /** + * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
+ *
+ * out[i] = scale * in[i] if in[i] > 0, or scale * alpha * (exp(in[i])-1) if in[i] <= 0
+ * Uses default scale and alpha values.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray selu(INDArray x) { + NDValidation.validateNumerical("selu", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(x)); + } + + /** + * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i]))
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray sigmoid(INDArray x) { + NDValidation.validateNumerical("sigmoid", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(x)); + } + + /** + * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut
+ * + * @param x Input Variable (NUMERIC type) + * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type) + * @return output Output (gradient at input of sigmoid) (NUMERIC type) + */ + public INDArray sigmoidDerivative(INDArray x, INDArray wrt) { + NDValidation.validateNumerical("sigmoidDerivative", "x", x); + NDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(x, wrt))[0]; + } + + /** + * Softmax activation, along the specified dimension
+ * + * @param x Input (NUMERIC type) + * @param dimension Dimension along which to apply softmax + * @return output Output variable (NUMERIC type) + */ + public INDArray softmax(INDArray x, int dimension) { + NDValidation.validateNumerical("softmax", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension))[0]; + } + + /** + * Softmax derivative function
+ * + * @param x Softmax input (NUMERIC type) + * @param wrt Gradient at output, dL/dx (NUMERIC type) + * @param dimension Softmax dimension + * @return output (NUMERIC type) + */ + public INDArray softmaxDerivative(INDArray x, INDArray wrt, int dimension) { + NDValidation.validateNumerical("softmaxDerivative", "x", x); + NDValidation.validateNumerical("softmaxDerivative", "wrt", wrt); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(x, wrt, dimension))[0]; + } + + /** + * Element-wise softplus function: out = log(exp(x) + 1)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray softplus(INDArray x) { + NDValidation.validateNumerical("softplus", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(x)); + } + + /** + * Element-wise softsign function: out = x / (abs(x) + 1)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray softsign(INDArray x) { + NDValidation.validateNumerical("softsign", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(x)); + } + + /** + * Element-wise derivative (dOut/dIn) of the softsign function softsign(INDArray)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public INDArray softsignDerivative(INDArray x) { + NDValidation.validateNumerical("softsignDerivative", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(x)); + } + + /** + * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
+ * See: https://arxiv.org/abs/1710.05941
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray swish(INDArray x) { + NDValidation.validateNumerical("swish", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(x)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java new file mode 100644 index 000000000..5737ced1f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java @@ -0,0 +1,138 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +public class NDRandom { + public NDRandom() { + } + + /** + * Generate a new random INDArray, where values are randomly sampled according to a Bernoulli distribution,
+ * with the specified probability. Array values will have value 1 with probability P and value 0 with probability
+ * 1-P.
+ * + * @param p Probability of value 1 + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public INDArray bernoulli(double p, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution(p, datatype, shape)); + } + + /** + * Generate a new random INDArray, where values are randomly sampled according to a Binomial distribution,
+ * with the specified number of trials and probability.
+ * + * @param nTrials Number of trials parameter for the binomial distribution + * @param p Probability of success for each trial + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public INDArray binomial(int nTrials, double p, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.BinomialDistribution(nTrials, p, datatype, shape)); + } + + /** + * Generate a new random INDArray, where values are randomly sampled according to a exponential distribution:
+ * P(x) = lambda * exp(-lambda * x)
+ * + * Inputs must satisfy the following constraints:
+ * Must be positive: lambda > 0
+ * + * @param lambda lambda parameter + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + */ + public INDArray[] exponential(double lambda, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + Preconditions.checkArgument(lambda > 0, "Must be positive"); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.custom.RandomExponential(lambda, datatype, shape)); + } + + /** + * Generate a new random INDArray, where values are randomly sampled according to a Log Normal distribution,
+ * i.e., {@code log(x) ~ N(mean, stdev)}
+ * + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public INDArray logNormal(double mean, double stddev, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution(mean, stddev, datatype, shape)); + } + + /** + * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,
+ * N(mean, stdev)
+ * + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public INDArray normal(double mean, double stddev, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.GaussianDistribution(mean, stddev, datatype, shape)); + } + + /** + * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,
+ * N(mean, stdev). However, any values more than 2 standard deviations from the mean are dropped and re-sampled
+ * + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public INDArray normalTruncated(double mean, double stddev, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution(mean, stddev, datatype, shape)); + } + + /** + * Generate a new random INDArray, where values are randomly sampled according to a uniform distribution,
+ * U(min,max)
+ * + * @param min Minimum value + * @param max Maximum value. + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public INDArray uniform(double min, double max, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(min, max, datatype, shape)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index ed3b5a7cb..e10ffcddb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -1620,11 +1620,11 @@ public class SameDiffTests extends BaseNd4jTest { switch (i) { case 0: t = sd.math().isNonDecreasing(in1); - Nd4j.exec(new IsNonDecreasing(new INDArray[]{ia}, new INDArray[]{expOut})); + Nd4j.exec(new IsNonDecreasing(ia, expOut)); break; case 1: t = sd.math().isStrictlyIncreasing(in1); - Nd4j.exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut})); + Nd4j.exec(new IsStrictlyIncreasing(ia, expOut)); break; case 2: t = sd.isNumericTensor(in1); @@ -1650,7 +1650,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray ia = Nd4j.randn(minibatch, nOut); INDArray expOut = Nd4j.create(DataType.BOOL, ia.shape()); - Nd4j.exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut})); + Nd4j.exec(new IsStrictlyIncreasing(ia, expOut)); System.out.println(expOut); } diff --git 
a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index 83145a048..6582d38db 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -31,6 +31,7 @@ import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.Listener; +import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.SameDiffOp; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java new file mode 100644 index 000000000..445f72342 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java @@ -0,0 +1,68 @@ +package org.nd4j.linalg.api; + +import org.junit.Test; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; + +public class TestNamespaces extends BaseNd4jTest { + + public TestNamespaces(Nd4jBackend backend) { + super(backend); + } + + @Test + public void testBitwiseSimple(){ + + INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT); + INDArray y = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT); + + INDArray and = Nd4j.bitwise.and(x, y); + INDArray or = Nd4j.bitwise.or(x, y); + INDArray xor = Nd4j.bitwise.xor(x, y); + + 
System.out.println(and); + System.out.println(or); + System.out.println(xor); + + } + + @Test + public void testMathSimple(){ + INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1); + INDArray abs = Nd4j.math.abs(x); + System.out.println(x); + System.out.println(abs); + + + INDArray c1 = Nd4j.createFromArray(0, 2, 1); + INDArray c2 = Nd4j.createFromArray(1, 2, 1); + + INDArray cm = Nd4j.math.confusionMatrix(c1, c2, 3); + System.out.println(cm); + } + + @Test + public void testRandomSimple(){ + INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10); + System.out.println(normal); + INDArray uniform = Nd4j.random.uniform(0, 1, DataType.FLOAT, 10); + System.out.println(uniform); + } + + @Test + public void testNeuralNetworkSimple(){ + INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10)); + System.out.println(out); + INDArray out2 = Nd4j.nn.softmax(Nd4j.random.normal(0, 1, DataType.FLOAT, 4, 5), 1); + System.out.println(out2); + } + + @Override + public char ordering() { + return 'c'; + } + +}