Add ND4J namespaces (#83)

* Add NDValidation

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add bitwise namespace

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Math namespace op constructor fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Constructor fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add Math namespace

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Update NDBitwise

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add random namespaces

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Update

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* NN namespace

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-30 18:39:32 +11:00 committed by GitHub
parent dc66a52bc7
commit 4fb9fa7748
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 2879 additions and 26 deletions

View File

@ -48,6 +48,10 @@ public class BiasAdd extends DynamicCustomOp {
this.nchw = nchw;
}
public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, boolean nchw){
this(input, bias, null, nchw);
}
public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
super(new INDArray[]{input, bias}, wrapOrNull(output));
bArguments.clear();

View File

@ -54,7 +54,12 @@ public class FirstIndex extends BaseIndexAccumulation {
public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) {
this(x, condition, false, dimension);
}
public FirstIndex(INDArray x, @NonNull Condition condition, boolean keepDims, int... dimension) {
this(x, condition, Nd4j.EPS_THRESHOLD, dimension);
this.keepDims = keepDims;
}
public FirstIndex(INDArray x, @NonNull Condition condition, double eps, int... dimension) {

View File

@ -38,7 +38,12 @@ public class IAMax extends BaseIndexAccumulation {
public IAMax() {}
public IAMax(INDArray x, int... dimensions) {
this(x, false, dimensions);
}
public IAMax(INDArray x, boolean keepDims, int... dimensions) {
this(x, null, dimensions);
this.keepDims = keepDims;
}
public IAMax(INDArray x, INDArray z, int... dimensions) {

View File

@ -41,6 +41,11 @@ public class IAMin extends BaseIndexAccumulation {
super(x, dimensions);
}
public IAMin(INDArray in, boolean keepDims, int... dimnesions){
super(in, null, dimnesions);
this.keepDims = keepDims;
}
public IAMin(INDArray x, INDArray z, int... dimensions) {
super(x, z, dimensions);
}

View File

@ -58,6 +58,11 @@ public class LastIndex extends BaseIndexAccumulation {
this(x, condition, Nd4j.EPS_THRESHOLD, dimensions);
}
public LastIndex(INDArray x, @NonNull Condition condition, boolean keepDim, int... dimensions) {
this(x, condition, Nd4j.EPS_THRESHOLD, dimensions);
this.keepDims = keepDim;
}
public LastIndex(INDArray x, @NonNull Condition condition, double eps, int... dimensions) {
super(x,null, dimensions);
this.condition = condition;

View File

@ -76,6 +76,15 @@ public class BatchNorm extends DynamicCustomOp {
addArgs();
}
public BatchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma, INDArray beta, double epsilon, int... axis){
super(wrapFilterNull(input, mean, variance, gamma, beta), null);
this.jaxis = axis;
this.applyBeta = beta != null;
this.applyGamma = gamma != null;
this.epsilon = epsilon;
addArgs();
}
public void addArgs() {
addIArgument(ArrayUtil.fromBoolean(applyGamma));
addIArgument(ArrayUtil.fromBoolean(applyBeta));

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.reduce;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -40,6 +41,12 @@ public class Moments extends DynamicCustomOp {
private int[] axes;
public Moments(@NonNull INDArray input, int... axes){
super(new INDArray[]{input}, null);
this.axes = axes;
addArgs();
}
public Moments(SameDiff sameDiff, SDVariable input) {
this(sameDiff, input, null);
}

View File

@ -47,6 +47,12 @@ public class NormalizeMoments extends DynamicCustomOp {
addArgs();
}
public NormalizeMoments(INDArray counts, INDArray means, INDArray variances, double shift) {
super(null, new INDArray[]{counts, means, variances}, null);
this.shift = shift;
addArgs();
}
public NormalizeMoments(INDArray counts, INDArray ssSum, INDArray ssSqSum, INDArray outMean, INDArray outVar) {
super(null, new INDArray[]{counts, ssSum, ssSqSum}, new INDArray[]{outMean, outVar},
new ArrayList<Double>(), new ArrayList<Integer>());

View File

@ -17,11 +17,13 @@
package org.nd4j.linalg.api.ops.impl.reduce;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
@ -36,10 +38,13 @@ import java.util.List;
public class ZeroFraction extends DynamicCustomOp {
public ZeroFraction(SameDiff sameDiff, SDVariable input) {
super(null, sameDiff, new SDVariable[] {input}, false);
}
public ZeroFraction(@NonNull INDArray input){
super(new INDArray[]{input}, null);
}
@Override
public String opName() {
return "zero_fraction";

View File

@ -45,6 +45,10 @@ public class PRelu extends DynamicCustomOp {
addIArgument(sharedAxes);
}
public PRelu(@NonNull INDArray x, @NonNull INDArray alpha, @NonNull int... sharedAxes) {
this(x, null, alpha, sharedAxes);
}
public PRelu(@NonNull INDArray x, INDArray z, @NonNull INDArray alpha, @NonNull int... sharedAxes) {
super(new INDArray[]{x, alpha}, new INDArray[]{z});
this.sharedAxes = sharedAxes;

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NonNull;
import lombok.val;
import org.apache.commons.lang3.NotImplementedException;
import org.nd4j.autodiff.samediff.SDVariable;
@ -23,6 +24,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -41,6 +43,35 @@ public class ConfusionMatrix extends DynamicCustomOp {
public ConfusionMatrix(){
}
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, @NonNull DataType dataType){
super(new INDArray[]{labels, predicted}, null);
this.outputType = dataType;
}
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, int numClasses){
this(labels, predicted, numClasses, DEFAULT_DTYPE);
}
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights) {
this(labels, predicted, weights, null);
}
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights, Integer numClasses) {
this(labels, predicted, weights, numClasses, DEFAULT_DTYPE);
}
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, Integer numClasses, @NonNull DataType dataType) {
this(labels, predicted, null, numClasses, dataType);
}
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights, Integer numClasses, @NonNull DataType dataType) {
super(wrapFilterNull(labels, predicted, weights), null);
this.outputType = dataType;
if(numClasses != null) {
addIArgument(numClasses);
}
}
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){
super(null, sameDiff, new SDVariable[]{labels, pred});
this.outputType = dataType;
@ -57,7 +88,9 @@ public class ConfusionMatrix extends DynamicCustomOp {
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights){
super(null, sameDiff, new SDVariable[]{labels, pred, weights});
addIArgument(numClasses);
if(numClasses != null) {
addIArgument(numClasses);
}
}
@Override

View File

@ -44,13 +44,16 @@ public class Cross extends DynamicCustomOp {
public Cross() {
}
public Cross(SameDiff sameDiff, SDVariable[] args) {
super(null, sameDiff, args, false);
}
public Cross(INDArray a, INDArray b){
this(a,b,null);
}
public Cross(INDArray a, INDArray b, INDArray out){
super(null, new INDArray[]{a,b}, out == null ? null : new INDArray[]{out}, null, (int[])null);
super(null, new INDArray[]{a,b}, wrapOrNull(out), null, (int[])null);
}
@Override

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NonNull;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -44,8 +45,12 @@ public class Diag extends DynamicCustomOp {
public Diag() {
}
public Diag(INDArray[] inputs, INDArray[] outputs) {
super(null, inputs, outputs);
public Diag(@NonNull INDArray input) {
this(input, null);
}
public Diag(@NonNull INDArray input, @NonNull INDArray output){
super(null, new INDArray[]{input}, wrapOrNull(output));
}
public Diag(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {

View File

@ -51,6 +51,10 @@ public class DiagPart extends DynamicCustomOp {
super(null, sameDiff, args, inPlace);
}
public DiagPart(INDArray in){
this(in, null);
}
public DiagPart(INDArray in, INDArray out){
super(null, in, out, null, null);
}

View File

@ -16,11 +16,14 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.shade.guava.base.Preconditions;
import java.util.Collections;
import java.util.List;
@ -55,6 +58,23 @@ public class Eye extends DynamicCustomOp {
public Eye() {
}
public Eye(@NonNull INDArray rows){
this(rows.getInt(0));
Preconditions.checkArgument(rows.isScalar(), "Rows INDArray must be a scalar");
}
public Eye(@NonNull INDArray rows, @NonNull INDArray columns){
this(rows.getInt(0), columns.getInt(0));
Preconditions.checkArgument(rows.isScalar(), "Rows INDArray must be a scalar");
Preconditions.checkArgument(columns.isScalar(), "Columns INDArray must be a scalar");
}
public Eye(int rows){
this.numRows = rows;
this.numCols = rows;
addArgs();
}
public Eye(SameDiff sameDiff, SDVariable numRows){
super(null, sameDiff, new SDVariable[] {numRows}, false);
}
@ -66,10 +86,7 @@ public class Eye extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[] {numRows, numCols, batch_shape}, false);
}
public Eye(SameDiff sameDiff, int numRows) {
super(null, sameDiff, new SDVariable[] {}, false);
this.numRows = numRows;
this.numCols = numRows;
addArgs();
this(sameDiff, numRows, numRows);
}
public Eye(SameDiff sameDiff, int numRows, int numCols) {
@ -77,13 +94,25 @@ public class Eye extends DynamicCustomOp {
}
public Eye(SameDiff sameDiff, int numRows, int numCols, DataType dataType) {
super(null, sameDiff, new SDVariable[] {}, false);
this(sameDiff, numRows, numCols, dataType, null);
}
public Eye(int numRows, int numCols, DataType dataType, int[] batchDimension) {
this.numRows = numRows;
this.numCols = numCols;
this.batchDimension = batchDimension;
this.dataType = dataType;
addArgs();
}
public Eye(int numRows, int numCols) {
this(numRows, numCols, DEFAULT_DTYPE);
}
public Eye(int numRows, int numCols, DataType dataType) {
this(numRows, numCols, dataType, null);
}
public Eye(SameDiff sameDiff, int numRows, int numCols, DataType dataType, int[] batchDimension) {
super(null, sameDiff, new SDVariable[] {}, false);
this.numRows = numRows;

View File

@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
@ -36,6 +37,10 @@ import java.util.Map;
@Slf4j
public class MergeAvg extends DynamicCustomOp {
public MergeAvg(INDArray... inputs){
super(inputs, null);
}
public MergeAvg(SameDiff sameDiff, SDVariable... inputs) {
super(null, sameDiff, inputs);
}

View File

@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -40,6 +41,10 @@ public class MergeMax extends DynamicCustomOp {
super(null, sameDiff, inputs);
}
public MergeMax(INDArray... inputs){
super(inputs, null);
}
public MergeMax(){ }
@Override

View File

@ -17,9 +17,11 @@
package org.nd4j.linalg.api.ops.impl.transforms;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB;
import java.util.Collections;
@ -37,7 +39,10 @@ public class ReluLayer extends XwPlusB {
public ReluLayer(SameDiff sameDiff, SDVariable input, SDVariable weights, SDVariable bias) {
super(sameDiff, input, weights, bias);
}
public ReluLayer(@NonNull INDArray input, @NonNull INDArray weights, @NonNull INDArray bias){
super(new INDArray[]{input, weights, bias}, null);
}
@Override

View File

@ -49,8 +49,12 @@ public class ClipByNorm extends DynamicCustomOp {
addTArgument(clipValue);
}
public ClipByNorm(INDArray in, double clipValue, int... dimensions){
this(in, null, clipValue, dimensions);
}
public ClipByNorm(INDArray in, INDArray out, double clipValue, int... dimensions){
super(null, new INDArray[]{in}, (out == null ? null : new INDArray[]{out}), Collections.singletonList(clipValue), dimensions);
super(null, new INDArray[]{in}, wrapOrNull(out), Collections.singletonList(clipValue), dimensions);
}
@Override

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.clip;
import lombok.NonNull;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -38,11 +39,10 @@ public class ClipByValue extends DynamicCustomOp {
private double clipValueMin;
private double clipValueMax;
public ClipByValue(INDArray[] inputs, INDArray[] outputs, double clipValueMin, double clipValueMax, boolean inPlace) {
super(null, inputs, outputs);
public ClipByValue(@NonNull INDArray input, double clipValueMin, double clipValueMax) {
super(null, new INDArray[]{input}, null);
this.clipValueMin = clipValueMin;
this.clipValueMax = clipValueMax;
this.inplaceCall = inPlace;
addTArgument(clipValueMin, clipValueMax);
}

View File

@ -41,13 +41,22 @@ public class ATan2 extends BaseDynamicTransformOp {
super(sameDiff, new SDVariable[] {y, x} ,false);
}
/**
* Note that the order of x and y match {@link java.lang.Math#atan2(double, double)},
* and are reversed when compared to OldATan2.
* See {@link Transforms#atan2(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ndarray.INDArray)}
*/
public ATan2(INDArray x, INDArray y) {
this(x,y,null);
}
/**
* Note that the order of x and y match {@link java.lang.Math#atan2(double, double)},
* and are reversed when compared to OldATan2.
* See {@link Transforms#atan2(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ndarray.INDArray)}
*/
public ATan2(INDArray x, INDArray y, INDArray z) {
super(new INDArray[]{x, y}, new INDArray[]{ z });
super(new INDArray[]{x, y}, wrapOrNull(z));
}
public ATan2() {}

View File

@ -17,10 +17,12 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
@ -49,6 +51,18 @@ public class DotProductAttention extends DynamicCustomOp {
addIArgument(withWeights ? 1 : 0);
}
public DotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values, INDArray mask, boolean scaled){
this(queries, keys, values, mask, scaled, false);
}
public DotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values, INDArray mask, boolean scaled, boolean withWeights){
super(wrapFilterNull(queries, keys, values, mask), null);
this.scaled = scaled;
this.withWeights = withWeights;
addIArgument(scaled ? 1 : 0);
addIArgument(withWeights ? 1 : 0);
}
@Override
public String opName() {
return "dot_product_attention";

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -40,8 +41,12 @@ public class IsNonDecreasing extends DynamicCustomOp {
super(null, sameDiff, args, inPlace);
}
public IsNonDecreasing(INDArray[] inputs, INDArray[] outputs) {
super(null, inputs, outputs);
public IsNonDecreasing(@NonNull INDArray input){
this(input, null);
}
public IsNonDecreasing(@NonNull INDArray input, INDArray output) {
super(null, new INDArray[]{input}, wrapOrNull(output));
}

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -38,8 +39,12 @@ public class IsStrictlyIncreasing extends DynamicCustomOp {
super(null, sameDiff, args, inPlace);
}
public IsStrictlyIncreasing( INDArray[] inputs, INDArray[] outputs) {
super(null, inputs, outputs);
public IsStrictlyIncreasing(@NonNull INDArray input){
this(input, null);
}
public IsStrictlyIncreasing(@NonNull INDArray input, INDArray output) {
super(null, new INDArray[]{input}, wrapOrNull(output));
}

View File

@ -62,6 +62,10 @@ public class LayerNorm extends DynamicCustomOp {
setDimensions(dimensions);
}
public LayerNorm(@NonNull INDArray input, @NonNull INDArray gain, boolean channelsFirst, int... dimensions) {
this(input, gain, null, channelsFirst, dimensions);
}
public LayerNorm(INDArray input, INDArray gain, INDArray result, boolean channelsFirst, int... dimensions) {
this(input, gain, null, result, channelsFirst, dimensions);
}

View File

@ -52,6 +52,11 @@ public class LogSoftMax extends DynamicCustomOp {
this(x, x);
}
public LogSoftMax(INDArray x, int dimension) {
this(x, null);
this.dimension = dimension;
}
public LogSoftMax(SameDiff sameDiff, SDVariable i_v, int dimension) {
this(sameDiff, i_v);
this.dimension = dimension;

View File

@ -16,10 +16,12 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
@ -39,6 +41,10 @@ public class MatrixDeterminant extends DynamicCustomOp {
//
}
public MatrixDeterminant(@NonNull INDArray input){
super(new INDArray[]{input}, null);
}
public MatrixDeterminant(SameDiff sameDiff, SDVariable in, boolean inPlace) {
super(null, sameDiff, new SDVariable[]{in}, inPlace);
}

View File

@ -16,10 +16,12 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
@ -36,6 +38,10 @@ public class MatrixInverse extends DynamicCustomOp {
//
}
public MatrixInverse(@NonNull INDArray input){
super(new INDArray[]{input}, null);
}
public MatrixInverse(SameDiff sameDiff, SDVariable in, boolean inPlace) {
super(null, sameDiff, new SDVariable[]{in}, inPlace);
}

View File

@ -16,10 +16,12 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
@ -32,6 +34,10 @@ public class MatrixSetDiag extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{in, diag}, inPlace);
}
public MatrixSetDiag(@NonNull INDArray in, @NonNull INDArray diag){
super(new INDArray[]{in, diag}, null);
}
public MatrixSetDiag(){ }
@Override

View File

@ -17,10 +17,12 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
@ -54,6 +56,22 @@ public class MultiHeadDotProductAttention extends DynamicCustomOp {
addIArgument(withWeights ? 1 : 0);
}
public MultiHeadDotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values,
@NonNull INDArray Wq, @NonNull INDArray Wk, @NonNull INDArray Wv, @NonNull INDArray Wo,
INDArray mask, boolean scaled) {
this(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false);
}
public MultiHeadDotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values,
@NonNull INDArray Wq, @NonNull INDArray Wk, @NonNull INDArray Wv, @NonNull INDArray Wo,
INDArray mask, boolean scaled, boolean withWeights) {
super(wrapFilterNull(queries, keys, values, Wq, Wk, Wv, Wo, mask), null);
this.scaled = scaled;
this.withWeights = withWeights;
addIArgument(scaled ? 1 : 0);
addIArgument(withWeights ? 1 : 0);
}
@Override
public String opName() {
return "multi_head_dot_product_attention";

View File

@ -16,10 +16,12 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
@ -39,6 +41,10 @@ public class Pow extends DynamicCustomOp {
public Pow(){ }
public Pow(@NonNull INDArray x, @NonNull INDArray y){
super(new INDArray[]{x,y}, null);
}
@Override
public String opName(){
return "Pow";

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -69,8 +70,12 @@ public class SoftMax extends BaseDynamicTransformOp {
addIArgument(dimension);
}
public SoftMax(@NonNull INDArray input, int dimension){
this(input, null, dimension);
}
public SoftMax(INDArray input, INDArray result, int dimension){
super(new INDArray[]{input}, new INDArray[]{result});
super(new INDArray[]{input}, wrapOrNull(result));
this.dimension = dimension;
addIArgument(dimension);
}

View File

@ -34,8 +34,12 @@ public class Standardize extends DynamicCustomOp {
setDimensions(dimensions);
}
public Standardize(INDArray input, int... dimensions){
this(input, null, dimensions);
}
public Standardize(INDArray input, INDArray result, int... dimensions){
super("standardize", new INDArray[]{input}, new INDArray[]{result});
super("standardize", new INDArray[]{input},wrapOrNull(result));
setDimensions(dimensions);
}

View File

@ -16,10 +16,12 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
@ -37,6 +39,10 @@ public class Trace extends DynamicCustomOp {
super(null, sd, new SDVariable[]{in});
}
public Trace(@NonNull INDArray in){
super(wrapOrNull(in), null);
}
public Trace(){ }
@Override

View File

@ -46,7 +46,14 @@ public class XwPlusB extends DynamicCustomOp {
public XwPlusB(SameDiff sameDiff, SDVariable input, SDVariable weights, SDVariable bias) {
super(null, sameDiff, new SDVariable[] {input, weights, bias}, false);
}
public XwPlusB(INDArray input, INDArray weights, INDArray bias) {
super(new INDArray[] {input, weights, bias}, null);
}
public XwPlusB(INDArray[] inputs, INDArray output){
super(inputs, wrapOrNull(output));
}
@Override

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
@ -35,6 +36,10 @@ public class SigmoidDerivative extends DynamicCustomOp {
super(sameDiff, new SDVariable[]{i_v1, i_v2});
}
public SigmoidDerivative(@NonNull INDArray x, @NonNull INDArray y) {
this(x, y, null);
}
public SigmoidDerivative(INDArray x, INDArray y, INDArray z) {
super(null, new INDArray[]{x,y}, new INDArray[]{z}, null, (int[])null);
}

View File

@ -42,6 +42,10 @@ public class SoftmaxBp extends DynamicCustomOp {
addIArgument(dimension);
}
public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, Integer dimension){
this(input, grad, null, dimension);
}
public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Integer dimension){
super(new INDArray[]{input, grad}, wrapOrNull(output));
if(dimension != null)

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -41,6 +42,10 @@ public class MergeAddOp extends BaseDynamicTransformOp {
super(sameDiff, args, inPlace);
}
public MergeAddOp(@NonNull INDArray... inputs){
this(inputs, null);
}
public MergeAddOp(INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs);
}

View File

@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.List;
@ -47,6 +48,10 @@ public class RandomExponential extends DynamicCustomOp {
addTArgument(lambda);
}
public RandomExponential(double lambda, DataType datatype, long... shape){
this(Nd4j.createFromArray(shape), Nd4j.createUninitialized(datatype, shape), lambda);
}
public RandomExponential(INDArray shape,INDArray out, double lambda){
super(null, new INDArray[]{shape}, new INDArray[]{out}, Collections.singletonList(lambda), (List<Integer>)null);
this.lambda = lambda;

View File

@ -25,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.LinkedHashMap;
@ -49,6 +50,10 @@ public class BernoulliDistribution extends BaseRandomOp {
super();
}
public BernoulliDistribution(double p, DataType datatype, long... shape){
this(Nd4j.createUninitialized(datatype, shape), p);
}
/**
* This op fills Z with bernoulli trial results, so 0, or 1, depending by common probability
* @param z

View File

@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.LinkedHashMap;
@ -46,6 +47,10 @@ public class BinomialDistribution extends BaseRandomOp {
this.extraArgs = new Object[] {(double) this.trials, this.probability};
}
public BinomialDistribution(int trials, double probability, DataType dt, long[] shape){
this(Nd4j.createUninitialized(dt, shape), trials, probability);
}
public BinomialDistribution() {
super();
}

View File

@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.LinkedHashMap;
@ -50,6 +51,10 @@ public class GaussianDistribution extends BaseRandomOp {
super();
}
public GaussianDistribution(double mean, double stddev, DataType datatype, long... shape){
this(Nd4j.createUninitialized(datatype, shape), mean, stddev);
}
/**
* This op fills Z with random values within stddev..mean..stddev boundaries
* @param z

View File

@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.LinkedHashMap;
@ -50,6 +51,10 @@ public class LogNormalDistribution extends BaseRandomOp {
this.extraArgs = new Object[] {this.mean, this.stddev};
}
public LogNormalDistribution(double mean, double stddev, DataType datatype, long... shape){
this(Nd4j.createUninitialized(datatype, shape), mean, stddev);
}
/**
* This op fills Z with random values within stddev..mean..stddev boundaries
* @param z

View File

@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.List;
@ -48,6 +49,10 @@ public class TruncatedNormalDistribution extends BaseRandomOp {
this.extraArgs = new Object[] {this.mean, this.stddev};
}
public TruncatedNormalDistribution(double mean, double stddev, DataType datatype, long... shape){
this(Nd4j.createUninitialized(datatype, shape), mean, stddev);
}
/**
* This op fills Z with random values within stddev..mean..stddev boundaries
* @param z

View File

@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.List;
@ -46,6 +47,10 @@ public class UniformDistribution extends BaseRandomOp {
this.extraArgs = new Object[] {this.from, this.to};
}
public UniformDistribution(double min, double max, DataType datatype, long... shape){
this(Nd4j.createUninitialized(datatype, shape), min, max);
}
/**
* This op fills Z with random values within from...to boundaries
* @param z

View File

@ -0,0 +1,236 @@
/* *****************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.factory;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
public class NDValidation {
private NDValidation() {
}
/**
* Validate that the operation is being applied on a numerical INDArray (not boolean or utf8).
* Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays
*
* @param opName Operation name to print in the exception
* @param v Variable to perform operation on
*/
public static void validateNumerical(String opName, INDArray v) {
if (v == null)
return;
if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8)
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-numerical data type " + v.dataType());
}
/**
* Validate that the operation is being applied on numerical INDArrays (not boolean or utf8).
* Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays
*
* @param opName Operation name to print in the exception
* @param v Variable to perform operation on
*/
public static void validateNumerical(String opName, INDArray[] v) {
if (v == null)
return;
for (int i = 0; i < v.length; i++) {
if (v[i].dataType() == DataType.BOOL || v[i].dataType() == DataType.UTF8)
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to input array " + i + " with non-numerical data type " + v[i].dataType());
}
}
/**
* Validate that the operation is being applied on a numerical INDArray (not boolean or utf8).
* Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays
*
* @param opName Operation name to print in the exception
* @param v Variable to validate datatype for (input to operation)
*/
public static void validateNumerical(String opName, String inputName, INDArray v) {
if (v == null)
return;
if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8)
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an numerical type type;" +
" got array with non-integer data type " + v.dataType());
}
/**
* Validate that the operation is being applied on numerical INDArrays (not boolean or utf8).
* Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays
*
* @param opName Operation name to print in the exception
* @param v Variable to perform operation on
*/
public static void validateNumerical(String opName, String inputName, INDArray[] v) {
if (v == null)
return;
for (int i = 0; i < v.length; i++) {
if (v[i].dataType() == DataType.BOOL || v[i].dataType() == DataType.UTF8)
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to input \"" + inputName + "\" array " + i + " with non-numerical data type " + v[i].dataType());
}
}
/**
* Validate that the operation is being applied on numerical INDArrays (not boolean or utf8).
* Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays
*
* @param opName Operation name to print in the exception
* @param v1 Variable to validate datatype for (input to operation)
* @param v2 Variable to validate datatype for (input to operation)
*/
public static void validateNumerical(String opName, INDArray v1, INDArray v2) {
if (v1.dataType() == DataType.BOOL || v1.dataType() == DataType.UTF8 || v2.dataType() == DataType.BOOL || v2.dataType() == DataType.UTF8)
throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on arrays if one or both variables" +
" are non-numerical: got " + v1.dataType() + " and " + v2.dataType());
}
/**
* Validate that the operation is being applied on an integer type INDArray
*
* @param opName Operation name to print in the exception
* @param v Variable to validate datatype for (input to operation)
*/
public static void validateInteger(String opName, INDArray v) {
if (v == null)
return;
if (!v.dataType().isIntType())
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-integer data type " + v.dataType());
}
/**
* Validate that the operation is being applied on an integer type INDArray
*
* @param opName Operation name to print in the exception
* @param inputName Name of the input to the op to validate
* @param v Variable to validate datatype for (input to operation)
*/
public static void validateInteger(String opName, String inputName, INDArray v) {
if (v == null)
return;
if (!v.dataType().isIntType())
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" +
" type; got array with non-integer data type " + v.dataType());
}
/**
* Validate that the operation is being applied on an floating point type INDArray
*
* @param opName Operation name to print in the exception
* @param v Variable to validate datatype for (input to operation)
*/
public static void validateFloatingPoint(String opName, INDArray v) {
if (v == null)
return;
if (!v.dataType().isFPType())
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-floating point data type " + v.dataType());
}
/**
* Validate that the operation is being applied on a floating point type INDArray
*
* @param opName Operation name to print in the exception
* @param inputName Name of the input to the op to validate
* @param v Variable to validate datatype for (input to operation)
*/
public static void validateFloatingPoint(String opName, String inputName, INDArray v) {
if (v == null)
return;
if (!v.dataType().isFPType())
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName +
"\" must be an floating point type; got array with non-floating point data type " + v.dataType());
}
/**
* Validate that the operation is being applied on a boolean type INDArray
*
* @param opName Operation name to print in the exception
* @param v Variable to validate datatype for (input to operation)
*/
public static void validateBool(String opName, INDArray v) {
if (v == null)
return;
if (v.dataType() != DataType.BOOL)
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-boolean point data type " + v.dataType());
}
/**
* Validate that the operation is being applied on a boolean type INDArray
*
* @param opName Operation name to print in the exception
* @param inputName Name of the input to the op to validate
* @param v Variable to validate datatype for (input to operation)
*/
public static void validateBool(String opName, String inputName, INDArray v) {
if (v == null)
return;
if (v.dataType() != DataType.BOOL)
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName +
"\" must be an boolean variable; got array with non-boolean data type " + v.dataType());
}
/**
* Validate that the operation is being applied on boolean INDArrays
*
* @param opName Operation name to print in the exception
* @param v1 Variable to validate datatype for (input to operation)
* @param v2 Variable to validate datatype for (input to operation)
*/
public static void validateBool(String opName, INDArray v1, INDArray v2) {
if (v1.dataType() != DataType.BOOL || v2.dataType() != DataType.BOOL)
throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on array if one or both variables are non-boolean: "
+ v1.dataType() + " and " + v2.dataType());
}
/**
* Validate that the operation is being applied on array with the exact same datatypes (which may optionally be
* restricted to numerical INDArrays only (not boolean or utf8))
*
* @param opName Operation name to print in the exception
* @param numericalOnly If true, the variables must all be the same type, and must be numerical (not boolean/utf8)
* @param vars Variable to perform operation on
*/
public static void validateSameType(String opName, boolean numericalOnly, INDArray... vars) {
if (vars.length == 0)
return;
if (vars.length == 1) {
if (numericalOnly) {
validateNumerical(opName, vars[0]);
}
} else {
DataType first = vars[0].dataType();
if (numericalOnly)
validateNumerical(opName, vars[0]);
for (int i = 1; i < vars.length; i++) {
if (first != vars[i].dataType()) {
DataType[] dtypes = new DataType[vars.length];
for (int j = 0; j < vars.length; j++) {
dtypes[j] = vars[j].dataType();
}
throw new IllegalStateException("Cannot perform operation \"" + opName + "\" to arrays with different datatypes:" +
" Got arrays with datatypes " + Arrays.toString(dtypes));
}
}
}
}
public static boolean isSameType(INDArray x, INDArray y) {
return x.dataType() == y.dataType();
}
}

View File

@ -16,6 +16,10 @@
package org.nd4j.linalg.factory;
import org.nd4j.linalg.factory.ops.NDBitwise;
import org.nd4j.linalg.factory.ops.NDMath;
import org.nd4j.linalg.factory.ops.NDNN;
import org.nd4j.linalg.factory.ops.NDRandom;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.guava.primitives.Longs;
import lombok.NonNull;
@ -114,6 +118,51 @@ import java.util.logging.Logger;
*/
public class Nd4j {
/**
* Bitwise namespace - operations related to bitwise manipulation of arrays
*/
public static final NDBitwise bitwise = new NDBitwise();
/**
* Math namespace - general mathematical operations
*/
public static final NDMath math = new NDMath();
/**
* Random namespace - (pseudo) random number generation methods
*/
public static final NDRandom random = new NDRandom();
/**
* Neural network namespace - operations related to neural networks
*/
public static final NDNN nn = new NDNN();
/**
* Bitwise namespace - operations related to bitwise manipulation of arrays
*/
public static NDBitwise bitwise() {
return bitwise;
}
/**
* Math namespace - general mathematical operations
*/
public static NDMath math() {
return math;
}
/**
* Random namespace - (pseudo) random number generation methods
*/
public static NDRandom random() {
return random;
}
/**
* Neural network namespace - operations related to neural networks
*/
public static NDNN nn() {
return nn;
}
private final static String DATA_BUFFER_OPS = "databufferfactory";
private final static String CONVOLUTION_OPS = "convops";
/**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/
@ -2638,7 +2687,7 @@ public class Nd4j {
INDArray ret;
if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) {
ret = Nd4j.create(x.dataType(), x.length(), x.length());
Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret}));
Nd4j.getExecutioner().execAndReturn(new Diag(x, ret));
} else {
ret = Nd4j.createUninitialized(x.dataType(), Math.min(x.size(0), x.size(1)));
Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret));

View File

@ -0,0 +1,211 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;
public class NDBitwise {
public NDBitwise() {
}
/**
* Bitwise AND operation. Supports broadcasting.<br>
*
* Inputs must satisfy the following constraints: <br>
* Must be same types: isSameType(x, y)<br>
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
*
* @param x First input array (INT type)
* @param y Second input array (INT type)
* @return output Bitwise AND array (INT type)
*/
public INDArray and(INDArray x, INDArray y) {
NDValidation.validateInteger("and", "x", x);
NDValidation.validateInteger("and", "y", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(x, y))[0];
}
/**
* Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)<br>
*
* @param x Input 1 (INT type)
* @param shift Number of bits to shift. (INT type)
* @return output SDVariable with shifted bits (INT type)
*/
public INDArray bitRotl(INDArray x, INDArray shift) {
NDValidation.validateInteger("bitRotl", "x", x);
NDValidation.validateInteger("bitRotl", "shift", shift);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(x, shift))[0];
}
/**
* Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)<br>
*
* @param x Input 1 (INT type)
* @param shift Number of bits to shift. (INT type)
* @return output SDVariable with shifted bits (INT type)
*/
public INDArray bitRotr(INDArray x, INDArray shift) {
NDValidation.validateInteger("bitRotr", "x", x);
NDValidation.validateInteger("bitRotr", "shift", shift);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(x, shift))[0];
}
/**
* Shift integer bits to the left, i.e. var << 4<br>
*
* @param x Input 1 (INT type)
* @param shift Number of bits to shift. (INT type)
* @return output SDVariable with shifted bits (INT type)
*/
public INDArray bitShift(INDArray x, INDArray shift) {
NDValidation.validateInteger("bitShift", "x", x);
NDValidation.validateInteger("bitShift", "shift", shift);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(x, shift))[0];
}
/**
* Shift integer bits to the right, i.e. var >> 4<br>
*
* @param x Input 1 (INT type)
* @param shift Number of bits to shift. (INT type)
* @return output SDVariable with shifted bits (INT type)
*/
public INDArray bitShiftRight(INDArray x, INDArray shift) {
NDValidation.validateInteger("bitShiftRight", "x", x);
NDValidation.validateInteger("bitShiftRight", "shift", shift);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(x, shift))[0];
}
/**
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)<br>
*
* Inputs must satisfy the following constraints: <br>
* Must be same types: isSameType(x, y)<br>
*
* @param x First input array. (INT type)
* @param y Second input array. (INT type)
* @return output bitwise Hamming distance (INT type)
*/
public INDArray bitsHammingDistance(INDArray x, INDArray y) {
NDValidation.validateInteger("bitsHammingDistance", "x", x);
NDValidation.validateInteger("bitsHammingDistance", "y", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(x, y))[0];
}
/**
* Bitwise left shift operation. Supports broadcasting.<br>
*
* @param x Input to be bit shifted (INT type)
* @param y Amount to shift elements of x array (INT type)
* @return output Bitwise shifted input x (INT type)
*/
public INDArray leftShift(INDArray x, INDArray y) {
NDValidation.validateInteger("leftShift", "x", x);
NDValidation.validateInteger("leftShift", "y", y);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(x, y))[0];
}
/**
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
*
* @param x Input to be bit shifted (INT type)
* @param y Amount to shift elements of x array (INT type)
* @return output Bitwise cyclic shifted input x (INT type)
*/
public INDArray leftShiftCyclic(INDArray x, INDArray y) {
NDValidation.validateInteger("leftShiftCyclic", "x", x);
NDValidation.validateInteger("leftShiftCyclic", "y", y);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(x, y))[0];
}
/**
* Bitwise OR operation. Supports broadcasting.<br>
*
* Inputs must satisfy the following constraints: <br>
* Must be same types: isSameType(x, y)<br>
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
*
* @param x First input array (INT type)
* @param y First input array (INT type)
* @return output Bitwise OR array (INT type)
*/
public INDArray or(INDArray x, INDArray y) {
NDValidation.validateInteger("or", "x", x);
NDValidation.validateInteger("or", "y", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(x, y))[0];
}
/**
* Bitwise right shift operation. Supports broadcasting. <br>
*
* @param x Input to be bit shifted (INT type)
* @param y Amount to shift elements of x array (INT type)
* @return output Bitwise shifted input x (INT type)
*/
public INDArray rightShift(INDArray x, INDArray y) {
NDValidation.validateInteger("rightShift", "x", x);
NDValidation.validateInteger("rightShift", "y", y);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(x, y))[0];
}
/**
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
*
* @param x Input to be bit shifted (INT type)
* @param y Amount to shift elements of x array (INT type)
* @return output Bitwise cyclic shifted input x (INT type)
*/
public INDArray rightShiftCyclic(INDArray x, INDArray y) {
NDValidation.validateInteger("rightShiftCyclic", "x", x);
NDValidation.validateInteger("rightShiftCyclic", "y", y);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(x, y))[0];
}
/**
* Bitwise XOR operation (exclusive OR). Supports broadcasting.<br>
*
* Inputs must satisfy the following constraints: <br>
* Must be same types: isSameType(x, y)<br>
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
*
* @param x First input array (INT type)
* @param y First input array (INT type)
* @return output Bitwise XOR array (INT type)
*/
public INDArray xor(INDArray x, INDArray y) {
NDValidation.validateInteger("xor", "x", x);
NDValidation.validateInteger("xor", "y", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(x, y))[0];
}
}

View File

@ -0,0 +1,522 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;
public class NDNN {
public NDNN() {
}
/**
* Neural network batch normalization operation.<br>
* For details, see <a href="https://arxiv.org/abs/1502.03167">https://arxiv.org/abs/1502.03167</a><br>
*
* @param input Input variable. (NUMERIC type)
* @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type)
* @param variance Variance value. For 1d axis, this should match input.size(axis) (NUMERIC type)
* @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type)
* @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type)
* @param epsilon Epsilon constant for numerical stability (to avoid division by 0)
* @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations.
* For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC
* For 1d/RNN activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1))
* @return output variable for batch normalization (NUMERIC type)
*/
public INDArray batchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma,
INDArray beta, double epsilon, int... axis) {
NDValidation.validateNumerical("batchNorm", "input", input);
NDValidation.validateNumerical("batchNorm", "mean", mean);
NDValidation.validateNumerical("batchNorm", "variance", variance);
NDValidation.validateNumerical("batchNorm", "gamma", gamma);
NDValidation.validateNumerical("batchNorm", "beta", beta);
Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(input, mean, variance, gamma, beta, epsilon, axis))[0];
}
/**
* Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector<br>
*
* @param input 4d input variable (NUMERIC type)
* @param bias 1d bias (NUMERIC type)
* @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels].
* Unused for 2d inputs
* @return output Output variable, after applying bias add operation (NUMERIC type)
*/
public INDArray biasAdd(INDArray input, INDArray bias, boolean nchw) {
NDValidation.validateNumerical("biasAdd", "input", input);
NDValidation.validateNumerical("biasAdd", "bias", bias);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(input, bias, nchw))[0];
}
/**
* This operation performs dot product attention on the given timeseries input with the given queries<br>
* out = sum(similarity(k_i, q) * v_i)<br>
* <br>
* similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q<br>
* <br>
* Optionally with normalization step:<br>
* similarity(k, q) = softmax(k * q / sqrt(size(q))<br>
* <br>
* See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)<br>
* <br>
* Note: This supports multiple queries at once, if only one query is available the queries vector still has to<br>
* be 3D but can have queryCount = 1<br>
* <br>
* Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for<br>
* both.<br>
* <br>
* Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The<br>
* output rank will depend on the input rank.<br>
*
* @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount]
* or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type)
* @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps]
* or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type)
* @param values input 3D array "values" of shape [batchSize, featureValues, timesteps]
* or 4D array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type)
* @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type)
* @param scaled normalization, false -> do not apply normalization, true -> apply normalization
* @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount],
* (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC type)
*/
public INDArray dotProductAttention(INDArray queries, INDArray keys, INDArray values,
INDArray mask, boolean scaled) {
NDValidation.validateNumerical("dotProductAttention", "queries", queries);
NDValidation.validateNumerical("dotProductAttention", "keys", keys);
NDValidation.validateNumerical("dotProductAttention", "values", values);
NDValidation.validateNumerical("dotProductAttention", "mask", mask);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(queries, keys, values, mask, scaled))[0];
}
/**
* Dropout operation<br>
*
* @param input Input array (NUMERIC type)
* @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p)
* @return output Output (NUMERIC type)
*/
public INDArray dropout(INDArray input, double inputRetainProbability) {
NDValidation.validateNumerical("dropout", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.DropOut(input, inputRetainProbability));
}
/**
* Element-wise exponential linear unit (ELU) function:<br>
* out = x if x > 0<br>
* out = a * (exp(x) - 1) if x <= 0<br>
* with constant a = 1.0<br>
* <p><br>
* See: <a href="https://arxiv.org/abs/1511.07289">https://arxiv.org/abs/1511.07289</a><br>
*
* @param x Input variable (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray elu(INDArray x) {
NDValidation.validateNumerical("elu", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(x))[0];
}
/**
* GELU activation function - Gaussian Error Linear Units<br>
* For more details, see <i>Gaussian Error Linear Units (GELUs)</i> - <a href="https://arxiv.org/abs/1606.08415">https://arxiv.org/abs/1606.08415</a><br>
* This method uses the sigmoid approximation<br>
*
* @param x Input variable (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray gelu(INDArray x) {
NDValidation.validateNumerical("gelu", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(x));
}
/**
* Element-wise hard sigmoid function:<br>
* out[i] = 0 if in[i] <= -2.5<br>
* out[1] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5<br>
* out[i] = 1 if in[i] >= 2.5<br>
*
* @param x Input variable (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray hardSigmoid(INDArray x) {
NDValidation.validateNumerical("hardSigmoid", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(x));
}
/**
* Element-wise hard tanh function:<br>
* out[i] = -1 if in[i] <= -1<br>
* out[1] = in[i] if -1 < in[i] < 1<br>
* out[i] = 1 if in[i] >= 1<br>
*
* @param x Input variable (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray hardTanh(INDArray x) {
NDValidation.validateNumerical("hardTanh", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(x));
}
/**
* Derivative (dOut/dIn) of the element-wise hard Tanh function - hardTanh(INDArray)<br>
*
* @param x Input variable (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray hardTanhDerivative(INDArray x) {
NDValidation.validateNumerical("hardTanhDerivative", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(x));
}
/**
* Apply Layer Normalization<br>
* <br>
* y = gain * standardize(x) + bias<br>
*
* @param input Input variable (NUMERIC type)
* @param gain Gain (NUMERIC type)
* @param bias Bias (NUMERIC type)
* @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data
* @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1))
* @return output Output variable (NUMERIC type)
*/
public INDArray layerNorm(INDArray input, INDArray gain, INDArray bias, boolean channelsFirst,
int... dimensions) {
NDValidation.validateNumerical("layerNorm", "input", input);
NDValidation.validateNumerical("layerNorm", "gain", gain);
NDValidation.validateNumerical("layerNorm", "bias", bias);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, bias, channelsFirst, dimensions))[0];
}
/**
* Apply Layer Normalization<br>
* <br>
* y = gain * standardize(x) + bias<br>
*
* @param input Input variable (NUMERIC type)
* @param gain Gain (NUMERIC type)
* @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data
* @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1))
* @return output Output variable (NUMERIC type)
*/
public INDArray layerNorm(INDArray input, INDArray gain, boolean channelsFirst,
int... dimensions) {
NDValidation.validateNumerical("layerNorm", "input", input);
NDValidation.validateNumerical("layerNorm", "gain", gain);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, channelsFirst, dimensions))[0];
}
/**
* Element-wise leaky ReLU function:<br>
* out = x if x >= 0.0<br>
* out = alpha * x if x < cutoff<br>
* Alpha value is most commonly set to 0.01<br>
*
* @param x Input variable (NUMERIC type)
* @param alpha Cutoff - commonly 0.01 (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray leakyRelu(INDArray x, INDArray alpha) {
NDValidation.validateNumerical("leakyRelu", "x", x);
NDValidation.validateNumerical("leakyRelu", "alpha", alpha);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(x, alpha));
}
/**
* Leaky ReLU derivative: dOut/dIn given input.<br>
*
* @param x Input variable (NUMERIC type)
* @param alpha Cutoff - commonly 0.01 (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray leakyReluDerivative(INDArray x, INDArray alpha) {
NDValidation.validateNumerical("leakyReluDerivative", "x", x);
NDValidation.validateNumerical("leakyReluDerivative", "alpha", alpha);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(x, alpha));
}
/**
* Linear layer operation: out = mmul(in,w) + bias<br>
* Note that bias array is optional<br>
*
* @param input Input data (NUMERIC type)
* @param weights Weights variable, shape [nIn, nOut] (NUMERIC type)
* @param bias Optional bias variable (may be null) (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray linear(INDArray input, INDArray weights, INDArray bias) {
NDValidation.validateNumerical("linear", "input", input);
NDValidation.validateNumerical("linear", "weights", weights);
NDValidation.validateNumerical("linear", "bias", bias);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(input, weights, bias))[0];
}
/**
* Element-wise sigmoid function: out[i] = log(sigmoid(in[i]))<br>
*
* @param x Input variable (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray logSigmoid(INDArray x) {
NDValidation.validateNumerical("logSigmoid", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(x));
}
/**
* Log softmax activation<br>
*
* @param x (NUMERIC type)
* @return output (NUMERIC type)
*/
public INDArray logSoftmax(INDArray x) {
NDValidation.validateNumerical("logSoftmax", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x))[0];
}
/**
* Log softmax activation<br>
*
* @param x Input (NUMERIC type)
* @param dimension Dimension along which to apply log softmax
* @return output Output - log(softmax(input)) (NUMERIC type)
*/
public INDArray logSoftmax(INDArray x, int dimension) {
NDValidation.validateNumerical("logSoftmax", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x, dimension))[0];
}
/**
* This performs multi-headed dot product attention on the given timeseries input<br>
* out = concat(head_1, head_2, ..., head_n) * Wo<br>
* head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v)<br>
* <br>
* Optionally with normalization when calculating the attention for each head.<br>
* <br>
* See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention")<br>
* <br>
* This makes use of dot_product_attention OP support for rank 4 inputs.<br>
* see dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)<br>
*
* @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC type)
* @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC type)
* @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC type)
* @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type)
* @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type)
* @param Wv input value projection weights of shape [numHeads, projectedValues, featureValues] (NUMERIC type)
* @param Wo output projection weights of shape [numHeads * projectedValues, outSize] (NUMERIC type)
* @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type)
* @param scaled normalization, false -> do not apply normalization, true -> apply normalization
* @return output Attention result arrays of shape [batchSize, outSize, queryCount]
* (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type)
*/
public INDArray multiHeadDotProductAttention(INDArray queries, INDArray keys, INDArray values,
INDArray Wq, INDArray Wk, INDArray Wv, INDArray Wo, INDArray mask, boolean scaled) {
NDValidation.validateNumerical("multiHeadDotProductAttention", "queries", queries);
NDValidation.validateNumerical("multiHeadDotProductAttention", "keys", keys);
NDValidation.validateNumerical("multiHeadDotProductAttention", "values", values);
NDValidation.validateNumerical("multiHeadDotProductAttention", "Wq", Wq);
NDValidation.validateNumerical("multiHeadDotProductAttention", "Wk", Wk);
NDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv);
NDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo);
NDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled))[0];
}
/**
* PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:<br>
* out[i] = in[i] if in[i] >= 0<br>
* out[i] = in[i] * alpha[i] otherwise<br>
* <br>
* sharedAxes allows you to share learnable parameters along axes.<br>
* For example, if the input has shape [batchSize, channels, height, width]<br>
* and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an<br>
* alpha with shape [channels].<br>
*
* @param input Input data (NUMERIC type)
* @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. (NUMERIC type)
* @param sharedAxes Which axes to share cutoff parameters along. (Size: AtLeast(min=1))
* @return output Output (NUMERIC type)
*/
public INDArray prelu(INDArray input, INDArray alpha, int... sharedAxes) {
NDValidation.validateNumerical("prelu", "input", input);
NDValidation.validateNumerical("prelu", "alpha", alpha);
Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.PRelu(input, alpha, sharedAxes))[0];
}
/**
* Element-wise rectified linear function with specified cutoff:<br>
* out[i] = in[i] if in[i] >= cutoff<br>
* out[i] = 0 otherwise<br>
*
* @param x Input (NUMERIC type)
* @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0
* @return output Output (NUMERIC type)
*/
public INDArray relu(INDArray x, double cutoff) {
NDValidation.validateNumerical("relu", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(x, cutoff));
}
/**
* Element-wise "rectified linear 6" function with specified cutoff:<br>
* out[i] = min(max(in, cutoff), 6)<br>
*
* @param x Input (NUMERIC type)
* @param cutoff Cutoff value for ReLU operation. Usually 0
* @return output Output (NUMERIC type)
*/
public INDArray relu6(INDArray x, double cutoff) {
NDValidation.validateNumerical("relu6", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.Relu6(x, cutoff));
}
/**
* ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)<br>
* Note that bias array is optional<br>
*
* @param input Input data (NUMERIC type)
* @param weights Weights variable (NUMERIC type)
* @param bias Optional bias variable (may be null) (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray reluLayer(INDArray input, INDArray weights, INDArray bias) {
NDValidation.validateNumerical("reluLayer", "input", input);
NDValidation.validateNumerical("reluLayer", "weights", weights);
NDValidation.validateNumerical("reluLayer", "bias", bias);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(input, weights, bias))[0];
}
/**
* Element-wise SeLU function - Scaled exponential Lineal Unit: see <a href="https://arxiv.org/abs/1706.02515">Self-Normalizing Neural Networks</a><br>
* <br>
* out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0<br>
* Uses default scale and alpha values.<br>
*
* @param x Input variable (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray selu(INDArray x) {
NDValidation.validateNumerical("selu", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(x));
}
/**
* Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i]))<br>
*
* @param x Input variable (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray sigmoid(INDArray x) {
NDValidation.validateNumerical("sigmoid", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(x));
}
/**
* Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut<br>
*
* @param x Input Variable (NUMERIC type)
* @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type)
* @return output Output (gradient at input of sigmoid) (NUMERIC type)
*/
public INDArray sigmoidDerivative(INDArray x, INDArray wrt) {
NDValidation.validateNumerical("sigmoidDerivative", "x", x);
NDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(x, wrt))[0];
}
/**
* Softmax activation, along the specified dimension<br>
*
* @param x Input (NUMERIC type)
* @param dimension Dimension along which to apply softmax
* @return output Output variable (NUMERIC type)
*/
public INDArray softmax(INDArray x, int dimension) {
NDValidation.validateNumerical("softmax", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension))[0];
}
/**
* Softmax derivative function<br>
*
* @param x Softmax input (NUMERIC type)
* @param wrt Gradient at output, dL/dx (NUMERIC type)
* @param dimension Softmax dimension
* @return output (NUMERIC type)
*/
public INDArray softmaxDerivative(INDArray x, INDArray wrt, int dimension) {
NDValidation.validateNumerical("softmaxDerivative", "x", x);
NDValidation.validateNumerical("softmaxDerivative", "wrt", wrt);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(x, wrt, dimension))[0];
}
/**
* Element-wise softplus function: out = log(exp(x) + 1)<br>
*
* @param x Input variable (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray softplus(INDArray x) {
NDValidation.validateNumerical("softplus", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(x));
}
/**
* Element-wise softsign function: out = x / (abs(x) + 1)<br>
*
* @param x Input variable (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray softsign(INDArray x) {
NDValidation.validateNumerical("softsign", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(x));
}
/**
* Element-wise derivative (dOut/dIn) of the softsign function softsign(INDArray)<br>
*
* @param x Input variable (NUMERIC type)
* @return output Output (NUMERIC type)
*/
public INDArray softsignDerivative(INDArray x) {
NDValidation.validateNumerical("softsignDerivative", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(x));
}
/**
* Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0<br>
* See: <a href="https://arxiv.org/abs/1710.05941">https://arxiv.org/abs/1710.05941</a><br>
*
* @param x Input variable (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray swish(INDArray x) {
NDValidation.validateNumerical("swish", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(x));
}
}

View File

@ -0,0 +1,138 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class NDRandom {
public NDRandom() {
}
/**
* Generate a new random INDArray, where values are randomly sampled according to a Bernoulli distribution,<br>
* with the specified probability. Array values will have value 1 with probability P and value 0 with probability<br>
* 1-P.<br>
*
* @param p Probability of value 1
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
*/
public INDArray bernoulli(double p, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution(p, datatype, shape));
}
/**
* Generate a new random INDArray, where values are randomly sampled according to a Binomial distribution,<br>
* with the specified number of trials and probability.<br>
*
* @param nTrials Number of trials parameter for the binomial distribution
* @param p Probability of success for each trial
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
*/
public INDArray binomial(int nTrials, double p, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.BinomialDistribution(nTrials, p, datatype, shape));
}
/**
* Generate a new random INDArray, where values are randomly sampled according to a exponential distribution:<br>
* P(x) = lambda * exp(-lambda * x)<br>
*
* Inputs must satisfy the following constraints: <br>
* Must be positive: lambda > 0<br>
*
* @param lambda lambda parameter
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
*/
public INDArray[] exponential(double lambda, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
Preconditions.checkArgument(lambda > 0, "Must be positive");
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.custom.RandomExponential(lambda, datatype, shape));
}
/**
* Generate a new random INDArray, where values are randomly sampled according to a Log Normal distribution,<br>
* i.e., {@code log(x) ~ N(mean, stdev)}<br>
*
* @param mean Mean value for the random array
* @param stddev Standard deviation for the random array
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
*/
public INDArray logNormal(double mean, double stddev, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution(mean, stddev, datatype, shape));
}
/**
* Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,<br>
* N(mean, stdev)<br>
*
* @param mean Mean value for the random array
* @param stddev Standard deviation for the random array
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
*/
public INDArray normal(double mean, double stddev, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.GaussianDistribution(mean, stddev, datatype, shape));
}
/**
* Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,<br>
* N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled<br>
*
* @param mean Mean value for the random array
* @param stddev Standard deviation for the random array
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
*/
public INDArray normalTruncated(double mean, double stddev, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution(mean, stddev, datatype, shape));
}
/**
* Generate a new random INDArray, where values are randomly sampled according to a uniform distribution,<br>
* U(min,max)<br>
*
* @param min Minimum value
* @param max Maximum value.
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
*/
public INDArray uniform(double min, double max, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(min, max, datatype, shape));
}
}

View File

@ -1620,11 +1620,11 @@ public class SameDiffTests extends BaseNd4jTest {
switch (i) {
case 0:
t = sd.math().isNonDecreasing(in1);
Nd4j.exec(new IsNonDecreasing(new INDArray[]{ia}, new INDArray[]{expOut}));
Nd4j.exec(new IsNonDecreasing(ia, expOut));
break;
case 1:
t = sd.math().isStrictlyIncreasing(in1);
Nd4j.exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut}));
Nd4j.exec(new IsStrictlyIncreasing(ia, expOut));
break;
case 2:
t = sd.isNumericTensor(in1);
@ -1650,7 +1650,7 @@ public class SameDiffTests extends BaseNd4jTest {
INDArray ia = Nd4j.randn(minibatch, nOut);
INDArray expOut = Nd4j.create(DataType.BOOL, ia.shape());
Nd4j.exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut}));
Nd4j.exec(new IsStrictlyIncreasing(ia, expOut));
System.out.println(expOut);
}

View File

@ -31,6 +31,7 @@ import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;

View File

@ -0,0 +1,68 @@
package org.nd4j.linalg.api;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
public class TestNamespaces extends BaseNd4jTest {
public TestNamespaces(Nd4jBackend backend) {
super(backend);
}
@Test
public void testBitwiseSimple(){
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
INDArray y = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
INDArray and = Nd4j.bitwise.and(x, y);
INDArray or = Nd4j.bitwise.or(x, y);
INDArray xor = Nd4j.bitwise.xor(x, y);
System.out.println(and);
System.out.println(or);
System.out.println(xor);
}
@Test
public void testMathSimple(){
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1);
INDArray abs = Nd4j.math.abs(x);
System.out.println(x);
System.out.println(abs);
INDArray c1 = Nd4j.createFromArray(0, 2, 1);
INDArray c2 = Nd4j.createFromArray(1, 2, 1);
INDArray cm = Nd4j.math.confusionMatrix(c1, c2, 3);
System.out.println(cm);
}
@Test
public void testRandomSimple(){
INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10);
System.out.println(normal);
INDArray uniform = Nd4j.random.uniform(0, 1, DataType.FLOAT, 10);
System.out.println(uniform);
}
@Test
public void testNeuralNetworkSimple(){
INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10));
System.out.println(out);
INDArray out2 = Nd4j.nn.softmax(Nd4j.random.normal(0, 1, DataType.FLOAT, 4, 5), 1);
System.out.println(out2);
}
@Override
public char ordering() {
return 'c';
}
}