Add ND4J namespaces (#83)
* Add NDValidation Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add bitwise namespace Signed-off-by: AlexDBlack <blacka101@gmail.com> * Math namespace op constructor fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Constructor fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add Math namespace Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update NDBitwise Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add random namespaces Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update Signed-off-by: AlexDBlack <blacka101@gmail.com> * NN namespace Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
dc66a52bc7
commit
4fb9fa7748
|
@ -48,6 +48,10 @@ public class BiasAdd extends DynamicCustomOp {
|
||||||
this.nchw = nchw;
|
this.nchw = nchw;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, boolean nchw){
|
||||||
|
this(input, bias, null, nchw);
|
||||||
|
}
|
||||||
|
|
||||||
public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
|
public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
|
||||||
super(new INDArray[]{input, bias}, wrapOrNull(output));
|
super(new INDArray[]{input, bias}, wrapOrNull(output));
|
||||||
bArguments.clear();
|
bArguments.clear();
|
||||||
|
|
|
@ -54,7 +54,12 @@ public class FirstIndex extends BaseIndexAccumulation {
|
||||||
|
|
||||||
|
|
||||||
public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) {
|
public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) {
|
||||||
|
this(x, condition, false, dimension);
|
||||||
|
}
|
||||||
|
|
||||||
|
public FirstIndex(INDArray x, @NonNull Condition condition, boolean keepDims, int... dimension) {
|
||||||
this(x, condition, Nd4j.EPS_THRESHOLD, dimension);
|
this(x, condition, Nd4j.EPS_THRESHOLD, dimension);
|
||||||
|
this.keepDims = keepDims;
|
||||||
}
|
}
|
||||||
|
|
||||||
public FirstIndex(INDArray x, @NonNull Condition condition, double eps, int... dimension) {
|
public FirstIndex(INDArray x, @NonNull Condition condition, double eps, int... dimension) {
|
||||||
|
|
|
@ -38,7 +38,12 @@ public class IAMax extends BaseIndexAccumulation {
|
||||||
public IAMax() {}
|
public IAMax() {}
|
||||||
|
|
||||||
public IAMax(INDArray x, int... dimensions) {
|
public IAMax(INDArray x, int... dimensions) {
|
||||||
|
this(x, false, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public IAMax(INDArray x, boolean keepDims, int... dimensions) {
|
||||||
this(x, null, dimensions);
|
this(x, null, dimensions);
|
||||||
|
this.keepDims = keepDims;
|
||||||
}
|
}
|
||||||
|
|
||||||
public IAMax(INDArray x, INDArray z, int... dimensions) {
|
public IAMax(INDArray x, INDArray z, int... dimensions) {
|
||||||
|
|
|
@ -41,6 +41,11 @@ public class IAMin extends BaseIndexAccumulation {
|
||||||
super(x, dimensions);
|
super(x, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public IAMin(INDArray in, boolean keepDims, int... dimnesions){
|
||||||
|
super(in, null, dimnesions);
|
||||||
|
this.keepDims = keepDims;
|
||||||
|
}
|
||||||
|
|
||||||
public IAMin(INDArray x, INDArray z, int... dimensions) {
|
public IAMin(INDArray x, INDArray z, int... dimensions) {
|
||||||
super(x, z, dimensions);
|
super(x, z, dimensions);
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,6 +58,11 @@ public class LastIndex extends BaseIndexAccumulation {
|
||||||
this(x, condition, Nd4j.EPS_THRESHOLD, dimensions);
|
this(x, condition, Nd4j.EPS_THRESHOLD, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public LastIndex(INDArray x, @NonNull Condition condition, boolean keepDim, int... dimensions) {
|
||||||
|
this(x, condition, Nd4j.EPS_THRESHOLD, dimensions);
|
||||||
|
this.keepDims = keepDim;
|
||||||
|
}
|
||||||
|
|
||||||
public LastIndex(INDArray x, @NonNull Condition condition, double eps, int... dimensions) {
|
public LastIndex(INDArray x, @NonNull Condition condition, double eps, int... dimensions) {
|
||||||
super(x,null, dimensions);
|
super(x,null, dimensions);
|
||||||
this.condition = condition;
|
this.condition = condition;
|
||||||
|
|
|
@ -76,6 +76,15 @@ public class BatchNorm extends DynamicCustomOp {
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public BatchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma, INDArray beta, double epsilon, int... axis){
|
||||||
|
super(wrapFilterNull(input, mean, variance, gamma, beta), null);
|
||||||
|
this.jaxis = axis;
|
||||||
|
this.applyBeta = beta != null;
|
||||||
|
this.applyGamma = gamma != null;
|
||||||
|
this.epsilon = epsilon;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
public void addArgs() {
|
public void addArgs() {
|
||||||
addIArgument(ArrayUtil.fromBoolean(applyGamma));
|
addIArgument(ArrayUtil.fromBoolean(applyGamma));
|
||||||
addIArgument(ArrayUtil.fromBoolean(applyBeta));
|
addIArgument(ArrayUtil.fromBoolean(applyBeta));
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.reduce;
|
package org.nd4j.linalg.api.ops.impl.reduce;
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -40,6 +41,12 @@ public class Moments extends DynamicCustomOp {
|
||||||
|
|
||||||
private int[] axes;
|
private int[] axes;
|
||||||
|
|
||||||
|
public Moments(@NonNull INDArray input, int... axes){
|
||||||
|
super(new INDArray[]{input}, null);
|
||||||
|
this.axes = axes;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
public Moments(SameDiff sameDiff, SDVariable input) {
|
public Moments(SameDiff sameDiff, SDVariable input) {
|
||||||
this(sameDiff, input, null);
|
this(sameDiff, input, null);
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,6 +47,12 @@ public class NormalizeMoments extends DynamicCustomOp {
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public NormalizeMoments(INDArray counts, INDArray means, INDArray variances, double shift) {
|
||||||
|
super(null, new INDArray[]{counts, means, variances}, null);
|
||||||
|
this.shift = shift;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
public NormalizeMoments(INDArray counts, INDArray ssSum, INDArray ssSqSum, INDArray outMean, INDArray outVar) {
|
public NormalizeMoments(INDArray counts, INDArray ssSum, INDArray ssSqSum, INDArray outMean, INDArray outVar) {
|
||||||
super(null, new INDArray[]{counts, ssSum, ssSqSum}, new INDArray[]{outMean, outVar},
|
super(null, new INDArray[]{counts, ssSum, ssSqSum}, new INDArray[]{outMean, outVar},
|
||||||
new ArrayList<Double>(), new ArrayList<Integer>());
|
new ArrayList<Double>(), new ArrayList<Integer>());
|
||||||
|
|
|
@ -17,11 +17,13 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.reduce;
|
package org.nd4j.linalg.api.ops.impl.reduce;
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
@ -36,10 +38,13 @@ import java.util.List;
|
||||||
public class ZeroFraction extends DynamicCustomOp {
|
public class ZeroFraction extends DynamicCustomOp {
|
||||||
|
|
||||||
public ZeroFraction(SameDiff sameDiff, SDVariable input) {
|
public ZeroFraction(SameDiff sameDiff, SDVariable input) {
|
||||||
|
|
||||||
super(null, sameDiff, new SDVariable[] {input}, false);
|
super(null, sameDiff, new SDVariable[] {input}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ZeroFraction(@NonNull INDArray input){
|
||||||
|
super(new INDArray[]{input}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "zero_fraction";
|
return "zero_fraction";
|
||||||
|
|
|
@ -45,6 +45,10 @@ public class PRelu extends DynamicCustomOp {
|
||||||
addIArgument(sharedAxes);
|
addIArgument(sharedAxes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public PRelu(@NonNull INDArray x, @NonNull INDArray alpha, @NonNull int... sharedAxes) {
|
||||||
|
this(x, null, alpha, sharedAxes);
|
||||||
|
}
|
||||||
|
|
||||||
public PRelu(@NonNull INDArray x, INDArray z, @NonNull INDArray alpha, @NonNull int... sharedAxes) {
|
public PRelu(@NonNull INDArray x, INDArray z, @NonNull INDArray alpha, @NonNull int... sharedAxes) {
|
||||||
super(new INDArray[]{x, alpha}, new INDArray[]{z});
|
super(new INDArray[]{x, alpha}, new INDArray[]{z});
|
||||||
this.sharedAxes = sharedAxes;
|
this.sharedAxes = sharedAxes;
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.commons.lang3.NotImplementedException;
|
import org.apache.commons.lang3.NotImplementedException;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -23,6 +24,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -41,6 +43,35 @@ public class ConfusionMatrix extends DynamicCustomOp {
|
||||||
public ConfusionMatrix(){
|
public ConfusionMatrix(){
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, @NonNull DataType dataType){
|
||||||
|
super(new INDArray[]{labels, predicted}, null);
|
||||||
|
this.outputType = dataType;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, int numClasses){
|
||||||
|
this(labels, predicted, numClasses, DEFAULT_DTYPE);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights) {
|
||||||
|
this(labels, predicted, weights, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights, Integer numClasses) {
|
||||||
|
this(labels, predicted, weights, numClasses, DEFAULT_DTYPE);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, Integer numClasses, @NonNull DataType dataType) {
|
||||||
|
this(labels, predicted, null, numClasses, dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights, Integer numClasses, @NonNull DataType dataType) {
|
||||||
|
super(wrapFilterNull(labels, predicted, weights), null);
|
||||||
|
this.outputType = dataType;
|
||||||
|
if(numClasses != null) {
|
||||||
|
addIArgument(numClasses);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){
|
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){
|
||||||
super(null, sameDiff, new SDVariable[]{labels, pred});
|
super(null, sameDiff, new SDVariable[]{labels, pred});
|
||||||
this.outputType = dataType;
|
this.outputType = dataType;
|
||||||
|
@ -57,8 +88,10 @@ public class ConfusionMatrix extends DynamicCustomOp {
|
||||||
|
|
||||||
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights){
|
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights){
|
||||||
super(null, sameDiff, new SDVariable[]{labels, pred, weights});
|
super(null, sameDiff, new SDVariable[]{labels, pred, weights});
|
||||||
|
if(numClasses != null) {
|
||||||
addIArgument(numClasses);
|
addIArgument(numClasses);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
|
|
|
@ -44,13 +44,16 @@ public class Cross extends DynamicCustomOp {
|
||||||
public Cross() {
|
public Cross() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public Cross(SameDiff sameDiff, SDVariable[] args) {
|
public Cross(SameDiff sameDiff, SDVariable[] args) {
|
||||||
super(null, sameDiff, args, false);
|
super(null, sameDiff, args, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Cross(INDArray a, INDArray b){
|
||||||
|
this(a,b,null);
|
||||||
|
}
|
||||||
|
|
||||||
public Cross(INDArray a, INDArray b, INDArray out){
|
public Cross(INDArray a, INDArray b, INDArray out){
|
||||||
super(null, new INDArray[]{a,b}, out == null ? null : new INDArray[]{out}, null, (int[])null);
|
super(null, new INDArray[]{a,b}, wrapOrNull(out), null, (int[])null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -44,8 +45,12 @@ public class Diag extends DynamicCustomOp {
|
||||||
public Diag() {
|
public Diag() {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Diag(INDArray[] inputs, INDArray[] outputs) {
|
public Diag(@NonNull INDArray input) {
|
||||||
super(null, inputs, outputs);
|
this(input, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Diag(@NonNull INDArray input, @NonNull INDArray output){
|
||||||
|
super(null, new INDArray[]{input}, wrapOrNull(output));
|
||||||
}
|
}
|
||||||
|
|
||||||
public Diag(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
|
public Diag(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
|
||||||
|
|
|
@ -51,6 +51,10 @@ public class DiagPart extends DynamicCustomOp {
|
||||||
super(null, sameDiff, args, inPlace);
|
super(null, sameDiff, args, inPlace);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public DiagPart(INDArray in){
|
||||||
|
this(in, null);
|
||||||
|
}
|
||||||
|
|
||||||
public DiagPart(INDArray in, INDArray out){
|
public DiagPart(INDArray in, INDArray out){
|
||||||
super(null, in, out, null, null);
|
super(null, in, out, null, null);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,11 +16,14 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
|
import org.nd4j.shade.guava.base.Preconditions;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -55,6 +58,23 @@ public class Eye extends DynamicCustomOp {
|
||||||
public Eye() {
|
public Eye() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Eye(@NonNull INDArray rows){
|
||||||
|
this(rows.getInt(0));
|
||||||
|
Preconditions.checkArgument(rows.isScalar(), "Rows INDArray must be a scalar");
|
||||||
|
}
|
||||||
|
|
||||||
|
public Eye(@NonNull INDArray rows, @NonNull INDArray columns){
|
||||||
|
this(rows.getInt(0), columns.getInt(0));
|
||||||
|
Preconditions.checkArgument(rows.isScalar(), "Rows INDArray must be a scalar");
|
||||||
|
Preconditions.checkArgument(columns.isScalar(), "Columns INDArray must be a scalar");
|
||||||
|
}
|
||||||
|
|
||||||
|
public Eye(int rows){
|
||||||
|
this.numRows = rows;
|
||||||
|
this.numCols = rows;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
public Eye(SameDiff sameDiff, SDVariable numRows){
|
public Eye(SameDiff sameDiff, SDVariable numRows){
|
||||||
super(null, sameDiff, new SDVariable[] {numRows}, false);
|
super(null, sameDiff, new SDVariable[] {numRows}, false);
|
||||||
}
|
}
|
||||||
|
@ -66,10 +86,7 @@ public class Eye extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[] {numRows, numCols, batch_shape}, false);
|
super(null, sameDiff, new SDVariable[] {numRows, numCols, batch_shape}, false);
|
||||||
}
|
}
|
||||||
public Eye(SameDiff sameDiff, int numRows) {
|
public Eye(SameDiff sameDiff, int numRows) {
|
||||||
super(null, sameDiff, new SDVariable[] {}, false);
|
this(sameDiff, numRows, numRows);
|
||||||
this.numRows = numRows;
|
|
||||||
this.numCols = numRows;
|
|
||||||
addArgs();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Eye(SameDiff sameDiff, int numRows, int numCols) {
|
public Eye(SameDiff sameDiff, int numRows, int numCols) {
|
||||||
|
@ -77,13 +94,25 @@ public class Eye extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Eye(SameDiff sameDiff, int numRows, int numCols, DataType dataType) {
|
public Eye(SameDiff sameDiff, int numRows, int numCols, DataType dataType) {
|
||||||
super(null, sameDiff, new SDVariable[] {}, false);
|
this(sameDiff, numRows, numCols, dataType, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Eye(int numRows, int numCols, DataType dataType, int[] batchDimension) {
|
||||||
this.numRows = numRows;
|
this.numRows = numRows;
|
||||||
this.numCols = numCols;
|
this.numCols = numCols;
|
||||||
|
this.batchDimension = batchDimension;
|
||||||
this.dataType = dataType;
|
this.dataType = dataType;
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Eye(int numRows, int numCols) {
|
||||||
|
this(numRows, numCols, DEFAULT_DTYPE);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Eye(int numRows, int numCols, DataType dataType) {
|
||||||
|
this(numRows, numCols, dataType, null);
|
||||||
|
}
|
||||||
|
|
||||||
public Eye(SameDiff sameDiff, int numRows, int numCols, DataType dataType, int[] batchDimension) {
|
public Eye(SameDiff sameDiff, int numRows, int numCols, DataType dataType, int[] batchDimension) {
|
||||||
super(null, sameDiff, new SDVariable[] {}, false);
|
super(null, sameDiff, new SDVariable[] {}, false);
|
||||||
this.numRows = numRows;
|
this.numRows = numRows;
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
@ -36,6 +37,10 @@ import java.util.Map;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class MergeAvg extends DynamicCustomOp {
|
public class MergeAvg extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public MergeAvg(INDArray... inputs){
|
||||||
|
super(inputs, null);
|
||||||
|
}
|
||||||
|
|
||||||
public MergeAvg(SameDiff sameDiff, SDVariable... inputs) {
|
public MergeAvg(SameDiff sameDiff, SDVariable... inputs) {
|
||||||
super(null, sameDiff, inputs);
|
super(null, sameDiff, inputs);
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -40,6 +41,10 @@ public class MergeMax extends DynamicCustomOp {
|
||||||
super(null, sameDiff, inputs);
|
super(null, sameDiff, inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MergeMax(INDArray... inputs){
|
||||||
|
super(inputs, null);
|
||||||
|
}
|
||||||
|
|
||||||
public MergeMax(){ }
|
public MergeMax(){ }
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -17,9 +17,11 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms;
|
package org.nd4j.linalg.api.ops.impl.transforms;
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB;
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
@ -37,7 +39,10 @@ public class ReluLayer extends XwPlusB {
|
||||||
|
|
||||||
public ReluLayer(SameDiff sameDiff, SDVariable input, SDVariable weights, SDVariable bias) {
|
public ReluLayer(SameDiff sameDiff, SDVariable input, SDVariable weights, SDVariable bias) {
|
||||||
super(sameDiff, input, weights, bias);
|
super(sameDiff, input, weights, bias);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ReluLayer(@NonNull INDArray input, @NonNull INDArray weights, @NonNull INDArray bias){
|
||||||
|
super(new INDArray[]{input, weights, bias}, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -49,8 +49,12 @@ public class ClipByNorm extends DynamicCustomOp {
|
||||||
addTArgument(clipValue);
|
addTArgument(clipValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ClipByNorm(INDArray in, double clipValue, int... dimensions){
|
||||||
|
this(in, null, clipValue, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
public ClipByNorm(INDArray in, INDArray out, double clipValue, int... dimensions){
|
public ClipByNorm(INDArray in, INDArray out, double clipValue, int... dimensions){
|
||||||
super(null, new INDArray[]{in}, (out == null ? null : new INDArray[]{out}), Collections.singletonList(clipValue), dimensions);
|
super(null, new INDArray[]{in}, wrapOrNull(out), Collections.singletonList(clipValue), dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.clip;
|
package org.nd4j.linalg.api.ops.impl.transforms.clip;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -38,11 +39,10 @@ public class ClipByValue extends DynamicCustomOp {
|
||||||
private double clipValueMin;
|
private double clipValueMin;
|
||||||
private double clipValueMax;
|
private double clipValueMax;
|
||||||
|
|
||||||
public ClipByValue(INDArray[] inputs, INDArray[] outputs, double clipValueMin, double clipValueMax, boolean inPlace) {
|
public ClipByValue(@NonNull INDArray input, double clipValueMin, double clipValueMax) {
|
||||||
super(null, inputs, outputs);
|
super(null, new INDArray[]{input}, null);
|
||||||
this.clipValueMin = clipValueMin;
|
this.clipValueMin = clipValueMin;
|
||||||
this.clipValueMax = clipValueMax;
|
this.clipValueMax = clipValueMax;
|
||||||
this.inplaceCall = inPlace;
|
|
||||||
addTArgument(clipValueMin, clipValueMax);
|
addTArgument(clipValueMin, clipValueMax);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -41,13 +41,22 @@ public class ATan2 extends BaseDynamicTransformOp {
|
||||||
super(sameDiff, new SDVariable[] {y, x} ,false);
|
super(sameDiff, new SDVariable[] {y, x} ,false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Note that the order of x and y match {@link java.lang.Math#atan2(double, double)},
|
||||||
|
* and are reversed when compared to OldATan2.
|
||||||
|
* See {@link Transforms#atan2(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ndarray.INDArray)}
|
||||||
|
*/
|
||||||
|
public ATan2(INDArray x, INDArray y) {
|
||||||
|
this(x,y,null);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Note that the order of x and y match {@link java.lang.Math#atan2(double, double)},
|
* Note that the order of x and y match {@link java.lang.Math#atan2(double, double)},
|
||||||
* and are reversed when compared to OldATan2.
|
* and are reversed when compared to OldATan2.
|
||||||
* See {@link Transforms#atan2(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ndarray.INDArray)}
|
* See {@link Transforms#atan2(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ndarray.INDArray)}
|
||||||
*/
|
*/
|
||||||
public ATan2(INDArray x, INDArray y, INDArray z) {
|
public ATan2(INDArray x, INDArray y, INDArray z) {
|
||||||
super(new INDArray[]{x, y}, new INDArray[]{ z });
|
super(new INDArray[]{x, y}, wrapOrNull(z));
|
||||||
}
|
}
|
||||||
|
|
||||||
public ATan2() {}
|
public ATan2() {}
|
||||||
|
|
|
@ -17,10 +17,12 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -49,6 +51,18 @@ public class DotProductAttention extends DynamicCustomOp {
|
||||||
addIArgument(withWeights ? 1 : 0);
|
addIArgument(withWeights ? 1 : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public DotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values, INDArray mask, boolean scaled){
|
||||||
|
this(queries, keys, values, mask, scaled, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public DotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values, INDArray mask, boolean scaled, boolean withWeights){
|
||||||
|
super(wrapFilterNull(queries, keys, values, mask), null);
|
||||||
|
this.scaled = scaled;
|
||||||
|
this.withWeights = withWeights;
|
||||||
|
addIArgument(scaled ? 1 : 0);
|
||||||
|
addIArgument(withWeights ? 1 : 0);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "dot_product_attention";
|
return "dot_product_attention";
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -40,8 +41,12 @@ public class IsNonDecreasing extends DynamicCustomOp {
|
||||||
super(null, sameDiff, args, inPlace);
|
super(null, sameDiff, args, inPlace);
|
||||||
}
|
}
|
||||||
|
|
||||||
public IsNonDecreasing(INDArray[] inputs, INDArray[] outputs) {
|
public IsNonDecreasing(@NonNull INDArray input){
|
||||||
super(null, inputs, outputs);
|
this(input, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public IsNonDecreasing(@NonNull INDArray input, INDArray output) {
|
||||||
|
super(null, new INDArray[]{input}, wrapOrNull(output));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -38,8 +39,12 @@ public class IsStrictlyIncreasing extends DynamicCustomOp {
|
||||||
super(null, sameDiff, args, inPlace);
|
super(null, sameDiff, args, inPlace);
|
||||||
}
|
}
|
||||||
|
|
||||||
public IsStrictlyIncreasing( INDArray[] inputs, INDArray[] outputs) {
|
public IsStrictlyIncreasing(@NonNull INDArray input){
|
||||||
super(null, inputs, outputs);
|
this(input, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public IsStrictlyIncreasing(@NonNull INDArray input, INDArray output) {
|
||||||
|
super(null, new INDArray[]{input}, wrapOrNull(output));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -62,6 +62,10 @@ public class LayerNorm extends DynamicCustomOp {
|
||||||
setDimensions(dimensions);
|
setDimensions(dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public LayerNorm(@NonNull INDArray input, @NonNull INDArray gain, boolean channelsFirst, int... dimensions) {
|
||||||
|
this(input, gain, null, channelsFirst, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
public LayerNorm(INDArray input, INDArray gain, INDArray result, boolean channelsFirst, int... dimensions) {
|
public LayerNorm(INDArray input, INDArray gain, INDArray result, boolean channelsFirst, int... dimensions) {
|
||||||
this(input, gain, null, result, channelsFirst, dimensions);
|
this(input, gain, null, result, channelsFirst, dimensions);
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,6 +52,11 @@ public class LogSoftMax extends DynamicCustomOp {
|
||||||
this(x, x);
|
this(x, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public LogSoftMax(INDArray x, int dimension) {
|
||||||
|
this(x, null);
|
||||||
|
this.dimension = dimension;
|
||||||
|
}
|
||||||
|
|
||||||
public LogSoftMax(SameDiff sameDiff, SDVariable i_v, int dimension) {
|
public LogSoftMax(SameDiff sameDiff, SDVariable i_v, int dimension) {
|
||||||
this(sameDiff, i_v);
|
this(sameDiff, i_v);
|
||||||
this.dimension = dimension;
|
this.dimension = dimension;
|
||||||
|
|
|
@ -16,10 +16,12 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -39,6 +41,10 @@ public class MatrixDeterminant extends DynamicCustomOp {
|
||||||
//
|
//
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MatrixDeterminant(@NonNull INDArray input){
|
||||||
|
super(new INDArray[]{input}, null);
|
||||||
|
}
|
||||||
|
|
||||||
public MatrixDeterminant(SameDiff sameDiff, SDVariable in, boolean inPlace) {
|
public MatrixDeterminant(SameDiff sameDiff, SDVariable in, boolean inPlace) {
|
||||||
super(null, sameDiff, new SDVariable[]{in}, inPlace);
|
super(null, sameDiff, new SDVariable[]{in}, inPlace);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,10 +16,12 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
@ -36,6 +38,10 @@ public class MatrixInverse extends DynamicCustomOp {
|
||||||
//
|
//
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MatrixInverse(@NonNull INDArray input){
|
||||||
|
super(new INDArray[]{input}, null);
|
||||||
|
}
|
||||||
|
|
||||||
public MatrixInverse(SameDiff sameDiff, SDVariable in, boolean inPlace) {
|
public MatrixInverse(SameDiff sameDiff, SDVariable in, boolean inPlace) {
|
||||||
super(null, sameDiff, new SDVariable[]{in}, inPlace);
|
super(null, sameDiff, new SDVariable[]{in}, inPlace);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,10 +16,12 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -32,6 +34,10 @@ public class MatrixSetDiag extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[]{in, diag}, inPlace);
|
super(null, sameDiff, new SDVariable[]{in, diag}, inPlace);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MatrixSetDiag(@NonNull INDArray in, @NonNull INDArray diag){
|
||||||
|
super(new INDArray[]{in, diag}, null);
|
||||||
|
}
|
||||||
|
|
||||||
public MatrixSetDiag(){ }
|
public MatrixSetDiag(){ }
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -17,10 +17,12 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -54,6 +56,22 @@ public class MultiHeadDotProductAttention extends DynamicCustomOp {
|
||||||
addIArgument(withWeights ? 1 : 0);
|
addIArgument(withWeights ? 1 : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MultiHeadDotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values,
|
||||||
|
@NonNull INDArray Wq, @NonNull INDArray Wk, @NonNull INDArray Wv, @NonNull INDArray Wo,
|
||||||
|
INDArray mask, boolean scaled) {
|
||||||
|
this(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public MultiHeadDotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values,
|
||||||
|
@NonNull INDArray Wq, @NonNull INDArray Wk, @NonNull INDArray Wv, @NonNull INDArray Wo,
|
||||||
|
INDArray mask, boolean scaled, boolean withWeights) {
|
||||||
|
super(wrapFilterNull(queries, keys, values, Wq, Wk, Wv, Wo, mask), null);
|
||||||
|
this.scaled = scaled;
|
||||||
|
this.withWeights = withWeights;
|
||||||
|
addIArgument(scaled ? 1 : 0);
|
||||||
|
addIArgument(withWeights ? 1 : 0);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "multi_head_dot_product_attention";
|
return "multi_head_dot_product_attention";
|
||||||
|
|
|
@ -16,10 +16,12 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -39,6 +41,10 @@ public class Pow extends DynamicCustomOp {
|
||||||
|
|
||||||
public Pow(){ }
|
public Pow(){ }
|
||||||
|
|
||||||
|
public Pow(@NonNull INDArray x, @NonNull INDArray y){
|
||||||
|
super(new INDArray[]{x,y}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName(){
|
public String opName(){
|
||||||
return "Pow";
|
return "Pow";
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -69,8 +70,12 @@ public class SoftMax extends BaseDynamicTransformOp {
|
||||||
addIArgument(dimension);
|
addIArgument(dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SoftMax(@NonNull INDArray input, int dimension){
|
||||||
|
this(input, null, dimension);
|
||||||
|
}
|
||||||
|
|
||||||
public SoftMax(INDArray input, INDArray result, int dimension){
|
public SoftMax(INDArray input, INDArray result, int dimension){
|
||||||
super(new INDArray[]{input}, new INDArray[]{result});
|
super(new INDArray[]{input}, wrapOrNull(result));
|
||||||
this.dimension = dimension;
|
this.dimension = dimension;
|
||||||
addIArgument(dimension);
|
addIArgument(dimension);
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,8 +34,12 @@ public class Standardize extends DynamicCustomOp {
|
||||||
setDimensions(dimensions);
|
setDimensions(dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Standardize(INDArray input, int... dimensions){
|
||||||
|
this(input, null, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
public Standardize(INDArray input, INDArray result, int... dimensions){
|
public Standardize(INDArray input, INDArray result, int... dimensions){
|
||||||
super("standardize", new INDArray[]{input}, new INDArray[]{result});
|
super("standardize", new INDArray[]{input},wrapOrNull(result));
|
||||||
setDimensions(dimensions);
|
setDimensions(dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,10 +16,12 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -37,6 +39,10 @@ public class Trace extends DynamicCustomOp {
|
||||||
super(null, sd, new SDVariable[]{in});
|
super(null, sd, new SDVariable[]{in});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Trace(@NonNull INDArray in){
|
||||||
|
super(wrapOrNull(in), null);
|
||||||
|
}
|
||||||
|
|
||||||
public Trace(){ }
|
public Trace(){ }
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -46,7 +46,14 @@ public class XwPlusB extends DynamicCustomOp {
|
||||||
|
|
||||||
public XwPlusB(SameDiff sameDiff, SDVariable input, SDVariable weights, SDVariable bias) {
|
public XwPlusB(SameDiff sameDiff, SDVariable input, SDVariable weights, SDVariable bias) {
|
||||||
super(null, sameDiff, new SDVariable[] {input, weights, bias}, false);
|
super(null, sameDiff, new SDVariable[] {input, weights, bias}, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public XwPlusB(INDArray input, INDArray weights, INDArray bias) {
|
||||||
|
super(new INDArray[] {input, weights, bias}, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public XwPlusB(INDArray[] inputs, INDArray output){
|
||||||
|
super(inputs, wrapOrNull(output));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
|
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
|
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
@ -35,6 +36,10 @@ public class SigmoidDerivative extends DynamicCustomOp {
|
||||||
super(sameDiff, new SDVariable[]{i_v1, i_v2});
|
super(sameDiff, new SDVariable[]{i_v1, i_v2});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SigmoidDerivative(@NonNull INDArray x, @NonNull INDArray y) {
|
||||||
|
this(x, y, null);
|
||||||
|
}
|
||||||
|
|
||||||
public SigmoidDerivative(INDArray x, INDArray y, INDArray z) {
|
public SigmoidDerivative(INDArray x, INDArray y, INDArray z) {
|
||||||
super(null, new INDArray[]{x,y}, new INDArray[]{z}, null, (int[])null);
|
super(null, new INDArray[]{x,y}, new INDArray[]{z}, null, (int[])null);
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,6 +42,10 @@ public class SoftmaxBp extends DynamicCustomOp {
|
||||||
addIArgument(dimension);
|
addIArgument(dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, Integer dimension){
|
||||||
|
this(input, grad, null, dimension);
|
||||||
|
}
|
||||||
|
|
||||||
public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Integer dimension){
|
public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Integer dimension){
|
||||||
super(new INDArray[]{input, grad}, wrapOrNull(output));
|
super(new INDArray[]{input, grad}, wrapOrNull(output));
|
||||||
if(dimension != null)
|
if(dimension != null)
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
|
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -41,6 +42,10 @@ public class MergeAddOp extends BaseDynamicTransformOp {
|
||||||
super(sameDiff, args, inPlace);
|
super(sameDiff, args, inPlace);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MergeAddOp(@NonNull INDArray... inputs){
|
||||||
|
this(inputs, null);
|
||||||
|
}
|
||||||
|
|
||||||
public MergeAddOp(INDArray[] inputs, INDArray[] outputs) {
|
public MergeAddOp(INDArray[] inputs, INDArray[] outputs) {
|
||||||
super(inputs, outputs);
|
super(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -47,6 +48,10 @@ public class RandomExponential extends DynamicCustomOp {
|
||||||
addTArgument(lambda);
|
addTArgument(lambda);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public RandomExponential(double lambda, DataType datatype, long... shape){
|
||||||
|
this(Nd4j.createFromArray(shape), Nd4j.createUninitialized(datatype, shape), lambda);
|
||||||
|
}
|
||||||
|
|
||||||
public RandomExponential(INDArray shape,INDArray out, double lambda){
|
public RandomExponential(INDArray shape,INDArray out, double lambda){
|
||||||
super(null, new INDArray[]{shape}, new INDArray[]{out}, Collections.singletonList(lambda), (List<Integer>)null);
|
super(null, new INDArray[]{shape}, new INDArray[]{out}, Collections.singletonList(lambda), (List<Integer>)null);
|
||||||
this.lambda = lambda;
|
this.lambda = lambda;
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
|
@ -49,6 +50,10 @@ public class BernoulliDistribution extends BaseRandomOp {
|
||||||
super();
|
super();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public BernoulliDistribution(double p, DataType datatype, long... shape){
|
||||||
|
this(Nd4j.createUninitialized(datatype, shape), p);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This op fills Z with bernoulli trial results, so 0, or 1, depending by common probability
|
* This op fills Z with bernoulli trial results, so 0, or 1, depending by common probability
|
||||||
* @param z
|
* @param z
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
|
@ -46,6 +47,10 @@ public class BinomialDistribution extends BaseRandomOp {
|
||||||
this.extraArgs = new Object[] {(double) this.trials, this.probability};
|
this.extraArgs = new Object[] {(double) this.trials, this.probability};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public BinomialDistribution(int trials, double probability, DataType dt, long[] shape){
|
||||||
|
this(Nd4j.createUninitialized(dt, shape), trials, probability);
|
||||||
|
}
|
||||||
|
|
||||||
public BinomialDistribution() {
|
public BinomialDistribution() {
|
||||||
super();
|
super();
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
|
@ -50,6 +51,10 @@ public class GaussianDistribution extends BaseRandomOp {
|
||||||
super();
|
super();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public GaussianDistribution(double mean, double stddev, DataType datatype, long... shape){
|
||||||
|
this(Nd4j.createUninitialized(datatype, shape), mean, stddev);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This op fills Z with random values within stddev..mean..stddev boundaries
|
* This op fills Z with random values within stddev..mean..stddev boundaries
|
||||||
* @param z
|
* @param z
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
|
@ -50,6 +51,10 @@ public class LogNormalDistribution extends BaseRandomOp {
|
||||||
this.extraArgs = new Object[] {this.mean, this.stddev};
|
this.extraArgs = new Object[] {this.mean, this.stddev};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public LogNormalDistribution(double mean, double stddev, DataType datatype, long... shape){
|
||||||
|
this(Nd4j.createUninitialized(datatype, shape), mean, stddev);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This op fills Z with random values within stddev..mean..stddev boundaries
|
* This op fills Z with random values within stddev..mean..stddev boundaries
|
||||||
* @param z
|
* @param z
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -48,6 +49,10 @@ public class TruncatedNormalDistribution extends BaseRandomOp {
|
||||||
this.extraArgs = new Object[] {this.mean, this.stddev};
|
this.extraArgs = new Object[] {this.mean, this.stddev};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public TruncatedNormalDistribution(double mean, double stddev, DataType datatype, long... shape){
|
||||||
|
this(Nd4j.createUninitialized(datatype, shape), mean, stddev);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This op fills Z with random values within stddev..mean..stddev boundaries
|
* This op fills Z with random values within stddev..mean..stddev boundaries
|
||||||
* @param z
|
* @param z
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -46,6 +47,10 @@ public class UniformDistribution extends BaseRandomOp {
|
||||||
this.extraArgs = new Object[] {this.from, this.to};
|
this.extraArgs = new Object[] {this.from, this.to};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public UniformDistribution(double min, double max, DataType datatype, long... shape){
|
||||||
|
this(Nd4j.createUninitialized(datatype, shape), min, max);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This op fills Z with random values within from...to boundaries
|
* This op fills Z with random values within from...to boundaries
|
||||||
* @param z
|
* @param z
|
||||||
|
|
|
@ -0,0 +1,236 @@
|
||||||
|
/* *****************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.factory;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
public class NDValidation {
|
||||||
|
|
||||||
|
private NDValidation() {
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on a numerical INDArray (not boolean or utf8).
|
||||||
|
* Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param v Variable to perform operation on
|
||||||
|
*/
|
||||||
|
public static void validateNumerical(String opName, INDArray v) {
|
||||||
|
if (v == null)
|
||||||
|
return;
|
||||||
|
if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8)
|
||||||
|
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-numerical data type " + v.dataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on numerical INDArrays (not boolean or utf8).
|
||||||
|
* Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param v Variable to perform operation on
|
||||||
|
*/
|
||||||
|
public static void validateNumerical(String opName, INDArray[] v) {
|
||||||
|
if (v == null)
|
||||||
|
return;
|
||||||
|
for (int i = 0; i < v.length; i++) {
|
||||||
|
if (v[i].dataType() == DataType.BOOL || v[i].dataType() == DataType.UTF8)
|
||||||
|
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to input array " + i + " with non-numerical data type " + v[i].dataType());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on a numerical INDArray (not boolean or utf8).
|
||||||
|
* Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param v Variable to validate datatype for (input to operation)
|
||||||
|
*/
|
||||||
|
public static void validateNumerical(String opName, String inputName, INDArray v) {
|
||||||
|
if (v == null)
|
||||||
|
return;
|
||||||
|
if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8)
|
||||||
|
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an numerical type type;" +
|
||||||
|
" got array with non-integer data type " + v.dataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on numerical INDArrays (not boolean or utf8).
|
||||||
|
* Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param v Variable to perform operation on
|
||||||
|
*/
|
||||||
|
public static void validateNumerical(String opName, String inputName, INDArray[] v) {
|
||||||
|
if (v == null)
|
||||||
|
return;
|
||||||
|
for (int i = 0; i < v.length; i++) {
|
||||||
|
if (v[i].dataType() == DataType.BOOL || v[i].dataType() == DataType.UTF8)
|
||||||
|
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to input \"" + inputName + "\" array " + i + " with non-numerical data type " + v[i].dataType());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on numerical INDArrays (not boolean or utf8).
|
||||||
|
* Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param v1 Variable to validate datatype for (input to operation)
|
||||||
|
* @param v2 Variable to validate datatype for (input to operation)
|
||||||
|
*/
|
||||||
|
public static void validateNumerical(String opName, INDArray v1, INDArray v2) {
|
||||||
|
if (v1.dataType() == DataType.BOOL || v1.dataType() == DataType.UTF8 || v2.dataType() == DataType.BOOL || v2.dataType() == DataType.UTF8)
|
||||||
|
throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on arrays if one or both variables" +
|
||||||
|
" are non-numerical: got " + v1.dataType() + " and " + v2.dataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on an integer type INDArray
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param v Variable to validate datatype for (input to operation)
|
||||||
|
*/
|
||||||
|
public static void validateInteger(String opName, INDArray v) {
|
||||||
|
if (v == null)
|
||||||
|
return;
|
||||||
|
if (!v.dataType().isIntType())
|
||||||
|
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-integer data type " + v.dataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on an integer type INDArray
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param inputName Name of the input to the op to validate
|
||||||
|
* @param v Variable to validate datatype for (input to operation)
|
||||||
|
*/
|
||||||
|
public static void validateInteger(String opName, String inputName, INDArray v) {
|
||||||
|
if (v == null)
|
||||||
|
return;
|
||||||
|
if (!v.dataType().isIntType())
|
||||||
|
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" +
|
||||||
|
" type; got array with non-integer data type " + v.dataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on an floating point type INDArray
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param v Variable to validate datatype for (input to operation)
|
||||||
|
*/
|
||||||
|
public static void validateFloatingPoint(String opName, INDArray v) {
|
||||||
|
if (v == null)
|
||||||
|
return;
|
||||||
|
if (!v.dataType().isFPType())
|
||||||
|
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-floating point data type " + v.dataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on a floating point type INDArray
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param inputName Name of the input to the op to validate
|
||||||
|
* @param v Variable to validate datatype for (input to operation)
|
||||||
|
*/
|
||||||
|
public static void validateFloatingPoint(String opName, String inputName, INDArray v) {
|
||||||
|
if (v == null)
|
||||||
|
return;
|
||||||
|
if (!v.dataType().isFPType())
|
||||||
|
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName +
|
||||||
|
"\" must be an floating point type; got array with non-floating point data type " + v.dataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on a boolean type INDArray
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param v Variable to validate datatype for (input to operation)
|
||||||
|
*/
|
||||||
|
public static void validateBool(String opName, INDArray v) {
|
||||||
|
if (v == null)
|
||||||
|
return;
|
||||||
|
if (v.dataType() != DataType.BOOL)
|
||||||
|
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-boolean point data type " + v.dataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on a boolean type INDArray
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param inputName Name of the input to the op to validate
|
||||||
|
* @param v Variable to validate datatype for (input to operation)
|
||||||
|
*/
|
||||||
|
public static void validateBool(String opName, String inputName, INDArray v) {
|
||||||
|
if (v == null)
|
||||||
|
return;
|
||||||
|
if (v.dataType() != DataType.BOOL)
|
||||||
|
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName +
|
||||||
|
"\" must be an boolean variable; got array with non-boolean data type " + v.dataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on boolean INDArrays
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param v1 Variable to validate datatype for (input to operation)
|
||||||
|
* @param v2 Variable to validate datatype for (input to operation)
|
||||||
|
*/
|
||||||
|
public static void validateBool(String opName, INDArray v1, INDArray v2) {
|
||||||
|
if (v1.dataType() != DataType.BOOL || v2.dataType() != DataType.BOOL)
|
||||||
|
throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on array if one or both variables are non-boolean: "
|
||||||
|
+ v1.dataType() + " and " + v2.dataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate that the operation is being applied on array with the exact same datatypes (which may optionally be
|
||||||
|
* restricted to numerical INDArrays only (not boolean or utf8))
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param numericalOnly If true, the variables must all be the same type, and must be numerical (not boolean/utf8)
|
||||||
|
* @param vars Variable to perform operation on
|
||||||
|
*/
|
||||||
|
public static void validateSameType(String opName, boolean numericalOnly, INDArray... vars) {
|
||||||
|
if (vars.length == 0)
|
||||||
|
return;
|
||||||
|
if (vars.length == 1) {
|
||||||
|
if (numericalOnly) {
|
||||||
|
validateNumerical(opName, vars[0]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
DataType first = vars[0].dataType();
|
||||||
|
if (numericalOnly)
|
||||||
|
validateNumerical(opName, vars[0]);
|
||||||
|
for (int i = 1; i < vars.length; i++) {
|
||||||
|
if (first != vars[i].dataType()) {
|
||||||
|
DataType[] dtypes = new DataType[vars.length];
|
||||||
|
for (int j = 0; j < vars.length; j++) {
|
||||||
|
dtypes[j] = vars[j].dataType();
|
||||||
|
}
|
||||||
|
throw new IllegalStateException("Cannot perform operation \"" + opName + "\" to arrays with different datatypes:" +
|
||||||
|
" Got arrays with datatypes " + Arrays.toString(dtypes));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean isSameType(INDArray x, INDArray y) {
|
||||||
|
return x.dataType() == y.dataType();
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,6 +16,10 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.factory;
|
package org.nd4j.linalg.factory;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.factory.ops.NDBitwise;
|
||||||
|
import org.nd4j.linalg.factory.ops.NDMath;
|
||||||
|
import org.nd4j.linalg.factory.ops.NDNN;
|
||||||
|
import org.nd4j.linalg.factory.ops.NDRandom;
|
||||||
import org.nd4j.shade.guava.primitives.Ints;
|
import org.nd4j.shade.guava.primitives.Ints;
|
||||||
import org.nd4j.shade.guava.primitives.Longs;
|
import org.nd4j.shade.guava.primitives.Longs;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
|
@ -114,6 +118,51 @@ import java.util.logging.Logger;
|
||||||
*/
|
*/
|
||||||
public class Nd4j {
|
public class Nd4j {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise namespace - operations related to bitwise manipulation of arrays
|
||||||
|
*/
|
||||||
|
public static final NDBitwise bitwise = new NDBitwise();
|
||||||
|
/**
|
||||||
|
* Math namespace - general mathematical operations
|
||||||
|
*/
|
||||||
|
public static final NDMath math = new NDMath();
|
||||||
|
/**
|
||||||
|
* Random namespace - (pseudo) random number generation methods
|
||||||
|
*/
|
||||||
|
public static final NDRandom random = new NDRandom();
|
||||||
|
/**
|
||||||
|
* Neural network namespace - operations related to neural networks
|
||||||
|
*/
|
||||||
|
public static final NDNN nn = new NDNN();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise namespace - operations related to bitwise manipulation of arrays
|
||||||
|
*/
|
||||||
|
public static NDBitwise bitwise() {
|
||||||
|
return bitwise;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Math namespace - general mathematical operations
|
||||||
|
*/
|
||||||
|
public static NDMath math() {
|
||||||
|
return math;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Random namespace - (pseudo) random number generation methods
|
||||||
|
*/
|
||||||
|
public static NDRandom random() {
|
||||||
|
return random;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Neural network namespace - operations related to neural networks
|
||||||
|
*/
|
||||||
|
public static NDNN nn() {
|
||||||
|
return nn;
|
||||||
|
}
|
||||||
|
|
||||||
private final static String DATA_BUFFER_OPS = "databufferfactory";
|
private final static String DATA_BUFFER_OPS = "databufferfactory";
|
||||||
private final static String CONVOLUTION_OPS = "convops";
|
private final static String CONVOLUTION_OPS = "convops";
|
||||||
/**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/
|
/**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/
|
||||||
|
@ -2638,7 +2687,7 @@ public class Nd4j {
|
||||||
INDArray ret;
|
INDArray ret;
|
||||||
if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) {
|
if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) {
|
||||||
ret = Nd4j.create(x.dataType(), x.length(), x.length());
|
ret = Nd4j.create(x.dataType(), x.length(), x.length());
|
||||||
Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret}));
|
Nd4j.getExecutioner().execAndReturn(new Diag(x, ret));
|
||||||
} else {
|
} else {
|
||||||
ret = Nd4j.createUninitialized(x.dataType(), Math.min(x.size(0), x.size(1)));
|
ret = Nd4j.createUninitialized(x.dataType(), Math.min(x.size(0), x.size(1)));
|
||||||
Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret));
|
Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret));
|
||||||
|
|
|
@ -0,0 +1,211 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
|
|
||||||
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.NDValidation;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
public class NDBitwise {
|
||||||
|
public NDBitwise() {
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise AND operation. Supports broadcasting.<br>
|
||||||
|
*
|
||||||
|
* Inputs must satisfy the following constraints: <br>
|
||||||
|
* Must be same types: isSameType(x, y)<br>
|
||||||
|
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||||
|
*
|
||||||
|
* @param x First input array (INT type)
|
||||||
|
* @param y Second input array (INT type)
|
||||||
|
* @return output Bitwise AND array (INT type)
|
||||||
|
*/
|
||||||
|
public INDArray and(INDArray x, INDArray y) {
|
||||||
|
NDValidation.validateInteger("and", "x", x);
|
||||||
|
NDValidation.validateInteger("and", "y", y);
|
||||||
|
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(x, y))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)<br>
|
||||||
|
*
|
||||||
|
* @param x Input 1 (INT type)
|
||||||
|
* @param shift Number of bits to shift. (INT type)
|
||||||
|
* @return output SDVariable with shifted bits (INT type)
|
||||||
|
*/
|
||||||
|
public INDArray bitRotl(INDArray x, INDArray shift) {
|
||||||
|
NDValidation.validateInteger("bitRotl", "x", x);
|
||||||
|
NDValidation.validateInteger("bitRotl", "shift", shift);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(x, shift))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)<br>
|
||||||
|
*
|
||||||
|
* @param x Input 1 (INT type)
|
||||||
|
* @param shift Number of bits to shift. (INT type)
|
||||||
|
* @return output SDVariable with shifted bits (INT type)
|
||||||
|
*/
|
||||||
|
public INDArray bitRotr(INDArray x, INDArray shift) {
|
||||||
|
NDValidation.validateInteger("bitRotr", "x", x);
|
||||||
|
NDValidation.validateInteger("bitRotr", "shift", shift);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(x, shift))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Shift integer bits to the left, i.e. var << 4<br>
|
||||||
|
*
|
||||||
|
* @param x Input 1 (INT type)
|
||||||
|
* @param shift Number of bits to shift. (INT type)
|
||||||
|
* @return output SDVariable with shifted bits (INT type)
|
||||||
|
*/
|
||||||
|
public INDArray bitShift(INDArray x, INDArray shift) {
|
||||||
|
NDValidation.validateInteger("bitShift", "x", x);
|
||||||
|
NDValidation.validateInteger("bitShift", "shift", shift);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(x, shift))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Shift integer bits to the right, i.e. var >> 4<br>
|
||||||
|
*
|
||||||
|
* @param x Input 1 (INT type)
|
||||||
|
* @param shift Number of bits to shift. (INT type)
|
||||||
|
* @return output SDVariable with shifted bits (INT type)
|
||||||
|
*/
|
||||||
|
public INDArray bitShiftRight(INDArray x, INDArray shift) {
|
||||||
|
NDValidation.validateInteger("bitShiftRight", "x", x);
|
||||||
|
NDValidation.validateInteger("bitShiftRight", "shift", shift);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(x, shift))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
|
||||||
|
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)<br>
|
||||||
|
*
|
||||||
|
* Inputs must satisfy the following constraints: <br>
|
||||||
|
* Must be same types: isSameType(x, y)<br>
|
||||||
|
*
|
||||||
|
* @param x First input array. (INT type)
|
||||||
|
* @param y Second input array. (INT type)
|
||||||
|
* @return output bitwise Hamming distance (INT type)
|
||||||
|
*/
|
||||||
|
public INDArray bitsHammingDistance(INDArray x, INDArray y) {
|
||||||
|
NDValidation.validateInteger("bitsHammingDistance", "x", x);
|
||||||
|
NDValidation.validateInteger("bitsHammingDistance", "y", y);
|
||||||
|
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(x, y))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise left shift operation. Supports broadcasting.<br>
|
||||||
|
*
|
||||||
|
* @param x Input to be bit shifted (INT type)
|
||||||
|
* @param y Amount to shift elements of x array (INT type)
|
||||||
|
* @return output Bitwise shifted input x (INT type)
|
||||||
|
*/
|
||||||
|
public INDArray leftShift(INDArray x, INDArray y) {
|
||||||
|
NDValidation.validateInteger("leftShift", "x", x);
|
||||||
|
NDValidation.validateInteger("leftShift", "y", y);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(x, y))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
||||||
|
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||||
|
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
||||||
|
*
|
||||||
|
* @param x Input to be bit shifted (INT type)
|
||||||
|
* @param y Amount to shift elements of x array (INT type)
|
||||||
|
* @return output Bitwise cyclic shifted input x (INT type)
|
||||||
|
*/
|
||||||
|
public INDArray leftShiftCyclic(INDArray x, INDArray y) {
|
||||||
|
NDValidation.validateInteger("leftShiftCyclic", "x", x);
|
||||||
|
NDValidation.validateInteger("leftShiftCyclic", "y", y);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(x, y))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise OR operation. Supports broadcasting.<br>
|
||||||
|
*
|
||||||
|
* Inputs must satisfy the following constraints: <br>
|
||||||
|
* Must be same types: isSameType(x, y)<br>
|
||||||
|
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||||
|
*
|
||||||
|
* @param x First input array (INT type)
|
||||||
|
* @param y First input array (INT type)
|
||||||
|
* @return output Bitwise OR array (INT type)
|
||||||
|
*/
|
||||||
|
public INDArray or(INDArray x, INDArray y) {
|
||||||
|
NDValidation.validateInteger("or", "x", x);
|
||||||
|
NDValidation.validateInteger("or", "y", y);
|
||||||
|
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(x, y))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise right shift operation. Supports broadcasting. <br>
|
||||||
|
*
|
||||||
|
* @param x Input to be bit shifted (INT type)
|
||||||
|
* @param y Amount to shift elements of x array (INT type)
|
||||||
|
* @return output Bitwise shifted input x (INT type)
|
||||||
|
*/
|
||||||
|
public INDArray rightShift(INDArray x, INDArray y) {
|
||||||
|
NDValidation.validateInteger("rightShift", "x", x);
|
||||||
|
NDValidation.validateInteger("rightShift", "y", y);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(x, y))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
||||||
|
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||||
|
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
||||||
|
*
|
||||||
|
* @param x Input to be bit shifted (INT type)
|
||||||
|
* @param y Amount to shift elements of x array (INT type)
|
||||||
|
* @return output Bitwise cyclic shifted input x (INT type)
|
||||||
|
*/
|
||||||
|
public INDArray rightShiftCyclic(INDArray x, INDArray y) {
|
||||||
|
NDValidation.validateInteger("rightShiftCyclic", "x", x);
|
||||||
|
NDValidation.validateInteger("rightShiftCyclic", "y", y);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(x, y))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise XOR operation (exclusive OR). Supports broadcasting.<br>
|
||||||
|
*
|
||||||
|
* Inputs must satisfy the following constraints: <br>
|
||||||
|
* Must be same types: isSameType(x, y)<br>
|
||||||
|
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||||
|
*
|
||||||
|
* @param x First input array (INT type)
|
||||||
|
* @param y First input array (INT type)
|
||||||
|
* @return output Bitwise XOR array (INT type)
|
||||||
|
*/
|
||||||
|
public INDArray xor(INDArray x, INDArray y) {
|
||||||
|
NDValidation.validateInteger("xor", "x", x);
|
||||||
|
NDValidation.validateInteger("xor", "y", y);
|
||||||
|
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(x, y))[0];
|
||||||
|
}
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,522 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
|
|
||||||
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.NDValidation;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
public class NDNN {
|
||||||
|
public NDNN() {
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Neural network batch normalization operation.<br>
|
||||||
|
* For details, see <a href="https://arxiv.org/abs/1502.03167">https://arxiv.org/abs/1502.03167</a><br>
|
||||||
|
*
|
||||||
|
* @param input Input variable. (NUMERIC type)
|
||||||
|
* @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type)
|
||||||
|
* @param variance Variance value. For 1d axis, this should match input.size(axis) (NUMERIC type)
|
||||||
|
* @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type)
|
||||||
|
* @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type)
|
||||||
|
* @param epsilon Epsilon constant for numerical stability (to avoid division by 0)
|
||||||
|
* @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations.
|
||||||
|
* For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC
|
||||||
|
* For 1d/RNN activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1))
|
||||||
|
* @return output variable for batch normalization (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray batchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma,
|
||||||
|
INDArray beta, double epsilon, int... axis) {
|
||||||
|
NDValidation.validateNumerical("batchNorm", "input", input);
|
||||||
|
NDValidation.validateNumerical("batchNorm", "mean", mean);
|
||||||
|
NDValidation.validateNumerical("batchNorm", "variance", variance);
|
||||||
|
NDValidation.validateNumerical("batchNorm", "gamma", gamma);
|
||||||
|
NDValidation.validateNumerical("batchNorm", "beta", beta);
|
||||||
|
Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(input, mean, variance, gamma, beta, epsilon, axis))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector<br>
|
||||||
|
*
|
||||||
|
* @param input 4d input variable (NUMERIC type)
|
||||||
|
* @param bias 1d bias (NUMERIC type)
|
||||||
|
* @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels].
|
||||||
|
* Unused for 2d inputs
|
||||||
|
* @return output Output variable, after applying bias add operation (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray biasAdd(INDArray input, INDArray bias, boolean nchw) {
|
||||||
|
NDValidation.validateNumerical("biasAdd", "input", input);
|
||||||
|
NDValidation.validateNumerical("biasAdd", "bias", bias);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(input, bias, nchw))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation performs dot product attention on the given timeseries input with the given queries<br>
|
||||||
|
* out = sum(similarity(k_i, q) * v_i)<br>
|
||||||
|
* <br>
|
||||||
|
* similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q<br>
|
||||||
|
* <br>
|
||||||
|
* Optionally with normalization step:<br>
|
||||||
|
* similarity(k, q) = softmax(k * q / sqrt(size(q))<br>
|
||||||
|
* <br>
|
||||||
|
* See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)<br>
|
||||||
|
* <br>
|
||||||
|
* Note: This supports multiple queries at once, if only one query is available the queries vector still has to<br>
|
||||||
|
* be 3D but can have queryCount = 1<br>
|
||||||
|
* <br>
|
||||||
|
* Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for<br>
|
||||||
|
* both.<br>
|
||||||
|
* <br>
|
||||||
|
* Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The<br>
|
||||||
|
* output rank will depend on the input rank.<br>
|
||||||
|
*
|
||||||
|
* @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount]
|
||||||
|
* or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type)
|
||||||
|
* @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps]
|
||||||
|
* or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type)
|
||||||
|
* @param values input 3D array "values" of shape [batchSize, featureValues, timesteps]
|
||||||
|
* or 4D array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type)
|
||||||
|
* @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type)
|
||||||
|
* @param scaled normalization, false -> do not apply normalization, true -> apply normalization
|
||||||
|
* @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount],
|
||||||
|
* (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray dotProductAttention(INDArray queries, INDArray keys, INDArray values,
|
||||||
|
INDArray mask, boolean scaled) {
|
||||||
|
NDValidation.validateNumerical("dotProductAttention", "queries", queries);
|
||||||
|
NDValidation.validateNumerical("dotProductAttention", "keys", keys);
|
||||||
|
NDValidation.validateNumerical("dotProductAttention", "values", values);
|
||||||
|
NDValidation.validateNumerical("dotProductAttention", "mask", mask);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(queries, keys, values, mask, scaled))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dropout operation<br>
|
||||||
|
*
|
||||||
|
* @param input Input array (NUMERIC type)
|
||||||
|
* @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p)
|
||||||
|
* @return output Output (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray dropout(INDArray input, double inputRetainProbability) {
|
||||||
|
NDValidation.validateNumerical("dropout", "input", input);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.DropOut(input, inputRetainProbability));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise exponential linear unit (ELU) function:<br>
|
||||||
|
* out = x if x > 0<br>
|
||||||
|
* out = a * (exp(x) - 1) if x <= 0<br>
|
||||||
|
* with constant a = 1.0<br>
|
||||||
|
* <p><br>
|
||||||
|
* See: <a href="https://arxiv.org/abs/1511.07289">https://arxiv.org/abs/1511.07289</a><br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray elu(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("elu", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(x))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* GELU activation function - Gaussian Error Linear Units<br>
|
||||||
|
* For more details, see <i>Gaussian Error Linear Units (GELUs)</i> - <a href="https://arxiv.org/abs/1606.08415">https://arxiv.org/abs/1606.08415</a><br>
|
||||||
|
* This method uses the sigmoid approximation<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray gelu(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("gelu", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise hard sigmoid function:<br>
|
||||||
|
* out[i] = 0 if in[i] <= -2.5<br>
|
||||||
|
* out[1] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5<br>
|
||||||
|
* out[i] = 1 if in[i] >= 2.5<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray hardSigmoid(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("hardSigmoid", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise hard tanh function:<br>
|
||||||
|
* out[i] = -1 if in[i] <= -1<br>
|
||||||
|
* out[1] = in[i] if -1 < in[i] < 1<br>
|
||||||
|
* out[i] = 1 if in[i] >= 1<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray hardTanh(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("hardTanh", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Derivative (dOut/dIn) of the element-wise hard Tanh function - hardTanh(INDArray)<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray hardTanhDerivative(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("hardTanhDerivative", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Apply Layer Normalization<br>
|
||||||
|
* <br>
|
||||||
|
* y = gain * standardize(x) + bias<br>
|
||||||
|
*
|
||||||
|
* @param input Input variable (NUMERIC type)
|
||||||
|
* @param gain Gain (NUMERIC type)
|
||||||
|
* @param bias Bias (NUMERIC type)
|
||||||
|
* @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data
|
||||||
|
* @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1))
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray layerNorm(INDArray input, INDArray gain, INDArray bias, boolean channelsFirst,
|
||||||
|
int... dimensions) {
|
||||||
|
NDValidation.validateNumerical("layerNorm", "input", input);
|
||||||
|
NDValidation.validateNumerical("layerNorm", "gain", gain);
|
||||||
|
NDValidation.validateNumerical("layerNorm", "bias", bias);
|
||||||
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, bias, channelsFirst, dimensions))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Apply Layer Normalization<br>
|
||||||
|
* <br>
|
||||||
|
* y = gain * standardize(x) + bias<br>
|
||||||
|
*
|
||||||
|
* @param input Input variable (NUMERIC type)
|
||||||
|
* @param gain Gain (NUMERIC type)
|
||||||
|
* @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data
|
||||||
|
* @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1))
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray layerNorm(INDArray input, INDArray gain, boolean channelsFirst,
|
||||||
|
int... dimensions) {
|
||||||
|
NDValidation.validateNumerical("layerNorm", "input", input);
|
||||||
|
NDValidation.validateNumerical("layerNorm", "gain", gain);
|
||||||
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, channelsFirst, dimensions))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise leaky ReLU function:<br>
|
||||||
|
* out = x if x >= 0.0<br>
|
||||||
|
* out = alpha * x if x < cutoff<br>
|
||||||
|
* Alpha value is most commonly set to 0.01<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @param alpha Cutoff - commonly 0.01 (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray leakyRelu(INDArray x, INDArray alpha) {
|
||||||
|
NDValidation.validateNumerical("leakyRelu", "x", x);
|
||||||
|
NDValidation.validateNumerical("leakyRelu", "alpha", alpha);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(x, alpha));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Leaky ReLU derivative: dOut/dIn given input.<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @param alpha Cutoff - commonly 0.01 (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray leakyReluDerivative(INDArray x, INDArray alpha) {
|
||||||
|
NDValidation.validateNumerical("leakyReluDerivative", "x", x);
|
||||||
|
NDValidation.validateNumerical("leakyReluDerivative", "alpha", alpha);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(x, alpha));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Linear layer operation: out = mmul(in,w) + bias<br>
|
||||||
|
* Note that bias array is optional<br>
|
||||||
|
*
|
||||||
|
* @param input Input data (NUMERIC type)
|
||||||
|
* @param weights Weights variable, shape [nIn, nOut] (NUMERIC type)
|
||||||
|
* @param bias Optional bias variable (may be null) (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray linear(INDArray input, INDArray weights, INDArray bias) {
|
||||||
|
NDValidation.validateNumerical("linear", "input", input);
|
||||||
|
NDValidation.validateNumerical("linear", "weights", weights);
|
||||||
|
NDValidation.validateNumerical("linear", "bias", bias);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(input, weights, bias))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise sigmoid function: out[i] = log(sigmoid(in[i]))<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray logSigmoid(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("logSigmoid", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Log softmax activation<br>
|
||||||
|
*
|
||||||
|
* @param x (NUMERIC type)
|
||||||
|
* @return output (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray logSoftmax(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("logSoftmax", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Log softmax activation<br>
|
||||||
|
*
|
||||||
|
* @param x Input (NUMERIC type)
|
||||||
|
* @param dimension Dimension along which to apply log softmax
|
||||||
|
* @return output Output - log(softmax(input)) (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray logSoftmax(INDArray x, int dimension) {
|
||||||
|
NDValidation.validateNumerical("logSoftmax", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x, dimension))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This performs multi-headed dot product attention on the given timeseries input<br>
|
||||||
|
* out = concat(head_1, head_2, ..., head_n) * Wo<br>
|
||||||
|
* head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v)<br>
|
||||||
|
* <br>
|
||||||
|
* Optionally with normalization when calculating the attention for each head.<br>
|
||||||
|
* <br>
|
||||||
|
* See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention")<br>
|
||||||
|
* <br>
|
||||||
|
* This makes use of dot_product_attention OP support for rank 4 inputs.<br>
|
||||||
|
* see dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)<br>
|
||||||
|
*
|
||||||
|
* @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC type)
|
||||||
|
* @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC type)
|
||||||
|
* @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC type)
|
||||||
|
* @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type)
|
||||||
|
* @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type)
|
||||||
|
* @param Wv input value projection weights of shape [numHeads, projectedValues, featureValues] (NUMERIC type)
|
||||||
|
* @param Wo output projection weights of shape [numHeads * projectedValues, outSize] (NUMERIC type)
|
||||||
|
* @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type)
|
||||||
|
* @param scaled normalization, false -> do not apply normalization, true -> apply normalization
|
||||||
|
* @return output Attention result arrays of shape [batchSize, outSize, queryCount]
|
||||||
|
* (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray multiHeadDotProductAttention(INDArray queries, INDArray keys, INDArray values,
|
||||||
|
INDArray Wq, INDArray Wk, INDArray Wv, INDArray Wo, INDArray mask, boolean scaled) {
|
||||||
|
NDValidation.validateNumerical("multiHeadDotProductAttention", "queries", queries);
|
||||||
|
NDValidation.validateNumerical("multiHeadDotProductAttention", "keys", keys);
|
||||||
|
NDValidation.validateNumerical("multiHeadDotProductAttention", "values", values);
|
||||||
|
NDValidation.validateNumerical("multiHeadDotProductAttention", "Wq", Wq);
|
||||||
|
NDValidation.validateNumerical("multiHeadDotProductAttention", "Wk", Wk);
|
||||||
|
NDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv);
|
||||||
|
NDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo);
|
||||||
|
NDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:<br>
|
||||||
|
* out[i] = in[i] if in[i] >= 0<br>
|
||||||
|
* out[i] = in[i] * alpha[i] otherwise<br>
|
||||||
|
* <br>
|
||||||
|
* sharedAxes allows you to share learnable parameters along axes.<br>
|
||||||
|
* For example, if the input has shape [batchSize, channels, height, width]<br>
|
||||||
|
* and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an<br>
|
||||||
|
* alpha with shape [channels].<br>
|
||||||
|
*
|
||||||
|
* @param input Input data (NUMERIC type)
|
||||||
|
* @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. (NUMERIC type)
|
||||||
|
* @param sharedAxes Which axes to share cutoff parameters along. (Size: AtLeast(min=1))
|
||||||
|
* @return output Output (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray prelu(INDArray input, INDArray alpha, int... sharedAxes) {
|
||||||
|
NDValidation.validateNumerical("prelu", "input", input);
|
||||||
|
NDValidation.validateNumerical("prelu", "alpha", alpha);
|
||||||
|
Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.PRelu(input, alpha, sharedAxes))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise rectified linear function with specified cutoff:<br>
|
||||||
|
* out[i] = in[i] if in[i] >= cutoff<br>
|
||||||
|
* out[i] = 0 otherwise<br>
|
||||||
|
*
|
||||||
|
* @param x Input (NUMERIC type)
|
||||||
|
* @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0
|
||||||
|
* @return output Output (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray relu(INDArray x, double cutoff) {
|
||||||
|
NDValidation.validateNumerical("relu", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(x, cutoff));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise "rectified linear 6" function with specified cutoff:<br>
|
||||||
|
* out[i] = min(max(in, cutoff), 6)<br>
|
||||||
|
*
|
||||||
|
* @param x Input (NUMERIC type)
|
||||||
|
* @param cutoff Cutoff value for ReLU operation. Usually 0
|
||||||
|
* @return output Output (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray relu6(INDArray x, double cutoff) {
|
||||||
|
NDValidation.validateNumerical("relu6", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.Relu6(x, cutoff));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)<br>
|
||||||
|
* Note that bias array is optional<br>
|
||||||
|
*
|
||||||
|
* @param input Input data (NUMERIC type)
|
||||||
|
* @param weights Weights variable (NUMERIC type)
|
||||||
|
* @param bias Optional bias variable (may be null) (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray reluLayer(INDArray input, INDArray weights, INDArray bias) {
|
||||||
|
NDValidation.validateNumerical("reluLayer", "input", input);
|
||||||
|
NDValidation.validateNumerical("reluLayer", "weights", weights);
|
||||||
|
NDValidation.validateNumerical("reluLayer", "bias", bias);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(input, weights, bias))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise SeLU function - Scaled exponential Lineal Unit: see <a href="https://arxiv.org/abs/1706.02515">Self-Normalizing Neural Networks</a><br>
|
||||||
|
* <br>
|
||||||
|
* out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0<br>
|
||||||
|
* Uses default scale and alpha values.<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray selu(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("selu", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i]))<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray sigmoid(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("sigmoid", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut<br>
|
||||||
|
*
|
||||||
|
* @param x Input Variable (NUMERIC type)
|
||||||
|
* @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type)
|
||||||
|
* @return output Output (gradient at input of sigmoid) (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray sigmoidDerivative(INDArray x, INDArray wrt) {
|
||||||
|
NDValidation.validateNumerical("sigmoidDerivative", "x", x);
|
||||||
|
NDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(x, wrt))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Softmax activation, along the specified dimension<br>
|
||||||
|
*
|
||||||
|
* @param x Input (NUMERIC type)
|
||||||
|
* @param dimension Dimension along which to apply softmax
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray softmax(INDArray x, int dimension) {
|
||||||
|
NDValidation.validateNumerical("softmax", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Softmax derivative function<br>
|
||||||
|
*
|
||||||
|
* @param x Softmax input (NUMERIC type)
|
||||||
|
* @param wrt Gradient at output, dL/dx (NUMERIC type)
|
||||||
|
* @param dimension Softmax dimension
|
||||||
|
* @return output (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray softmaxDerivative(INDArray x, INDArray wrt, int dimension) {
|
||||||
|
NDValidation.validateNumerical("softmaxDerivative", "x", x);
|
||||||
|
NDValidation.validateNumerical("softmaxDerivative", "wrt", wrt);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(x, wrt, dimension))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise softplus function: out = log(exp(x) + 1)<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray softplus(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("softplus", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise softsign function: out = x / (abs(x) + 1)<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray softsign(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("softsign", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise derivative (dOut/dIn) of the softsign function softsign(INDArray)<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray softsignDerivative(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("softsignDerivative", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0<br>
|
||||||
|
* See: <a href="https://arxiv.org/abs/1710.05941">https://arxiv.org/abs/1710.05941</a><br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray swish(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("swish", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(x));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,138 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
|
|
||||||
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
public class NDRandom {
|
||||||
|
public NDRandom() {
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a new random INDArray, where values are randomly sampled according to a Bernoulli distribution,<br>
|
||||||
|
* with the specified probability. Array values will have value 1 with probability P and value 0 with probability<br>
|
||||||
|
* 1-P.<br>
|
||||||
|
*
|
||||||
|
* @param p Probability of value 1
|
||||||
|
* @param datatype Data type of the output variable
|
||||||
|
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||||
|
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray bernoulli(double p, DataType datatype, long... shape) {
|
||||||
|
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution(p, datatype, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a new random INDArray, where values are randomly sampled according to a Binomial distribution,<br>
|
||||||
|
* with the specified number of trials and probability.<br>
|
||||||
|
*
|
||||||
|
* @param nTrials Number of trials parameter for the binomial distribution
|
||||||
|
* @param p Probability of success for each trial
|
||||||
|
* @param datatype Data type of the output variable
|
||||||
|
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||||
|
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray binomial(int nTrials, double p, DataType datatype, long... shape) {
|
||||||
|
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.BinomialDistribution(nTrials, p, datatype, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a new random INDArray, where values are randomly sampled according to a exponential distribution:<br>
|
||||||
|
* P(x) = lambda * exp(-lambda * x)<br>
|
||||||
|
*
|
||||||
|
* Inputs must satisfy the following constraints: <br>
|
||||||
|
* Must be positive: lambda > 0<br>
|
||||||
|
*
|
||||||
|
* @param lambda lambda parameter
|
||||||
|
* @param datatype Data type of the output variable
|
||||||
|
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||||
|
*/
|
||||||
|
public INDArray[] exponential(double lambda, DataType datatype, long... shape) {
|
||||||
|
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||||
|
Preconditions.checkArgument(lambda > 0, "Must be positive");
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.custom.RandomExponential(lambda, datatype, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a new random INDArray, where values are randomly sampled according to a Log Normal distribution,<br>
|
||||||
|
* i.e., {@code log(x) ~ N(mean, stdev)}<br>
|
||||||
|
*
|
||||||
|
* @param mean Mean value for the random array
|
||||||
|
* @param stddev Standard deviation for the random array
|
||||||
|
* @param datatype Data type of the output variable
|
||||||
|
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||||
|
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray logNormal(double mean, double stddev, DataType datatype, long... shape) {
|
||||||
|
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution(mean, stddev, datatype, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,<br>
|
||||||
|
* N(mean, stdev)<br>
|
||||||
|
*
|
||||||
|
* @param mean Mean value for the random array
|
||||||
|
* @param stddev Standard deviation for the random array
|
||||||
|
* @param datatype Data type of the output variable
|
||||||
|
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||||
|
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray normal(double mean, double stddev, DataType datatype, long... shape) {
|
||||||
|
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.GaussianDistribution(mean, stddev, datatype, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,<br>
|
||||||
|
* N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled<br>
|
||||||
|
*
|
||||||
|
* @param mean Mean value for the random array
|
||||||
|
* @param stddev Standard deviation for the random array
|
||||||
|
* @param datatype Data type of the output variable
|
||||||
|
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||||
|
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray normalTruncated(double mean, double stddev, DataType datatype, long... shape) {
|
||||||
|
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution(mean, stddev, datatype, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a new random INDArray, where values are randomly sampled according to a uniform distribution,<br>
|
||||||
|
* U(min,max)<br>
|
||||||
|
*
|
||||||
|
* @param min Minimum value
|
||||||
|
* @param max Maximum value.
|
||||||
|
* @param datatype Data type of the output variable
|
||||||
|
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||||
|
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray uniform(double min, double max, DataType datatype, long... shape) {
|
||||||
|
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(min, max, datatype, shape));
|
||||||
|
}
|
||||||
|
}
|
|
@ -1620,11 +1620,11 @@ public class SameDiffTests extends BaseNd4jTest {
|
||||||
switch (i) {
|
switch (i) {
|
||||||
case 0:
|
case 0:
|
||||||
t = sd.math().isNonDecreasing(in1);
|
t = sd.math().isNonDecreasing(in1);
|
||||||
Nd4j.exec(new IsNonDecreasing(new INDArray[]{ia}, new INDArray[]{expOut}));
|
Nd4j.exec(new IsNonDecreasing(ia, expOut));
|
||||||
break;
|
break;
|
||||||
case 1:
|
case 1:
|
||||||
t = sd.math().isStrictlyIncreasing(in1);
|
t = sd.math().isStrictlyIncreasing(in1);
|
||||||
Nd4j.exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut}));
|
Nd4j.exec(new IsStrictlyIncreasing(ia, expOut));
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
t = sd.isNumericTensor(in1);
|
t = sd.isNumericTensor(in1);
|
||||||
|
@ -1650,7 +1650,7 @@ public class SameDiffTests extends BaseNd4jTest {
|
||||||
INDArray ia = Nd4j.randn(minibatch, nOut);
|
INDArray ia = Nd4j.randn(minibatch, nOut);
|
||||||
INDArray expOut = Nd4j.create(DataType.BOOL, ia.shape());
|
INDArray expOut = Nd4j.create(DataType.BOOL, ia.shape());
|
||||||
|
|
||||||
Nd4j.exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut}));
|
Nd4j.exec(new IsStrictlyIncreasing(ia, expOut));
|
||||||
System.out.println(expOut);
|
System.out.println(expOut);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
|
||||||
import org.nd4j.autodiff.execution.conf.OutputMode;
|
import org.nd4j.autodiff.execution.conf.OutputMode;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.listeners.Listener;
|
import org.nd4j.autodiff.listeners.Listener;
|
||||||
|
import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.autodiff.samediff.internal.InferenceSession;
|
import org.nd4j.autodiff.samediff.internal.InferenceSession;
|
||||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
package org.nd4j.linalg.api;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
|
public class TestNamespaces extends BaseNd4jTest {
|
||||||
|
|
||||||
|
public TestNamespaces(Nd4jBackend backend) {
|
||||||
|
super(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBitwiseSimple(){
|
||||||
|
|
||||||
|
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
|
||||||
|
INDArray y = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
|
||||||
|
|
||||||
|
INDArray and = Nd4j.bitwise.and(x, y);
|
||||||
|
INDArray or = Nd4j.bitwise.or(x, y);
|
||||||
|
INDArray xor = Nd4j.bitwise.xor(x, y);
|
||||||
|
|
||||||
|
System.out.println(and);
|
||||||
|
System.out.println(or);
|
||||||
|
System.out.println(xor);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMathSimple(){
|
||||||
|
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1);
|
||||||
|
INDArray abs = Nd4j.math.abs(x);
|
||||||
|
System.out.println(x);
|
||||||
|
System.out.println(abs);
|
||||||
|
|
||||||
|
|
||||||
|
INDArray c1 = Nd4j.createFromArray(0, 2, 1);
|
||||||
|
INDArray c2 = Nd4j.createFromArray(1, 2, 1);
|
||||||
|
|
||||||
|
INDArray cm = Nd4j.math.confusionMatrix(c1, c2, 3);
|
||||||
|
System.out.println(cm);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRandomSimple(){
|
||||||
|
INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10);
|
||||||
|
System.out.println(normal);
|
||||||
|
INDArray uniform = Nd4j.random.uniform(0, 1, DataType.FLOAT, 10);
|
||||||
|
System.out.println(uniform);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testNeuralNetworkSimple(){
|
||||||
|
INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10));
|
||||||
|
System.out.println(out);
|
||||||
|
INDArray out2 = Nd4j.nn.softmax(Nd4j.random.normal(0, 1, DataType.FLOAT, 4, 5), 1);
|
||||||
|
System.out.println(out2);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public char ordering() {
|
||||||
|
return 'c';
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue