Image namespace (#176)

* created NDImage.java and fixed constructor in AdjustContrast.java

* created NDImage.java and fixed constructor in AdjustContrast.java

* created NDImage.java and fixed constructor in AdjustContrast.java v2

* regenerated NDImage from cleaned Image,kt also cleaned AdjustContrast.java

* draft of NDCNN

* draft of NDCNN

* started NDRNN

* started NDRNN

* looking like finished with namespace

* Regenerate namespaces

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add ND4J namespace methods for new namespaces

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fixes, cleanup

Signed-off-by: Alex Black <blacka101@gmail.com>

* More fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Andrii Tuzhykov <andrew@unrealists.com>
Co-authored-by: Andrii Tuzhykov <andrew@konduit.ai>
Co-authored-by: AlexDBlack <blacka101@gmail.com>
master
Andrii T 2020-03-09 04:35:17 +02:00 committed by GitHub
parent a80fb99a5f
commit a2ec3dbc97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 1239 additions and 81 deletions

View File

@ -33,6 +33,10 @@ public class AdjustContrast extends BaseAdjustContrast {
super(sameDiff,new SDVariable[]{in,factor}); super(sameDiff,new SDVariable[]{in,factor});
} }
public AdjustContrast(@NonNull INDArray in, double factor) {
this(in, factor, null);
}
@Override @Override
public String opName() { public String opName() {
return "adjust_contrast"; return "adjust_contrast";

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -36,6 +37,7 @@ import java.util.*;
*/ */
@NoArgsConstructor @NoArgsConstructor
public class CropAndResize extends DynamicCustomOp { public class CropAndResize extends DynamicCustomOp {
public enum Method {BILINEAR, NEAREST}; public enum Method {BILINEAR, NEAREST};
protected Method method = Method.BILINEAR; protected Method method = Method.BILINEAR;
protected double extrapolationValue = 0.0; protected double extrapolationValue = 0.0;
@ -48,6 +50,7 @@ public class CropAndResize extends DynamicCustomOp {
addArgs(); addArgs();
} }
public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices, public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
@NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue, @NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue,
INDArray output){ INDArray output){
@ -62,6 +65,12 @@ public class CropAndResize extends DynamicCustomOp {
outputArguments.add(output); outputArguments.add(output);
} }
public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
@NonNull INDArray cropOutSize, double extrapolationValue) {
this(image, cropBoxes, boxIndices, cropOutSize, Method.BILINEAR, extrapolationValue, null);
}
@Override @Override
public String opName() { public String opName() {
return "crop_and_resize"; return "crop_and_resize";

View File

@ -72,6 +72,18 @@ public class ExtractImagePatches extends DynamicCustomOp {
addArgs(); addArgs();
} }
public ExtractImagePatches(INDArray input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
super(new INDArray[]{input},null);
int[] kSises = {kH,kW};
int[] strides = {sH,sW};
int[] rates = {rH, rW};
this.kSizes = kSises;
this.strides = strides;
this.rates = rates;
this.isSameMode = sameMode;
addArgs();
}
@Override @Override
public String opName() { public String opName() {

View File

@ -49,10 +49,15 @@ public class AvgPooling2D extends DynamicCustomOp {
protected Pooling2DConfig config; protected Pooling2DConfig config;
public enum Pooling2DType { public enum Pooling2DType {
MAX, AVG, PNORM, MAX, AVG, PNORM,
} }
public AvgPooling2D(@NonNull INDArray input, Pooling2DConfig config) {
this(input, null, config);
}
@Builder(builderMethodName = "sameDiffBuilder") @Builder(builderMethodName = "sameDiffBuilder")
public AvgPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) { public AvgPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {

View File

@ -17,6 +17,8 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import onnx.Onnx; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -38,9 +40,8 @@ import java.util.Map;
*/ */
@Slf4j @Slf4j
@Getter @Getter
@NoArgsConstructor
public class AvgPooling3D extends Pooling3D { public class AvgPooling3D extends Pooling3D {
public AvgPooling3D() {
}
public AvgPooling3D(SameDiff sameDiff, SDVariable input, Pooling3DConfig config) { public AvgPooling3D(SameDiff sameDiff, SDVariable input, Pooling3DConfig config) {
super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.AVG); super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.AVG);
@ -50,6 +51,11 @@ public class AvgPooling3D extends Pooling3D {
super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.AVG); super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.AVG);
} }
public AvgPooling3D(@NonNull INDArray input, Pooling3DConfig pooling3DConfig) {
super(null,null,new INDArray[]{input},null,false, pooling3DConfig, Pooling3DType.AVG);
}
@Override @Override
public boolean isConfigProperties() { public boolean isConfigProperties() {
return true; return true;

View File

@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -53,7 +54,7 @@ public class Col2Im extends DynamicCustomOp {
addArgs(); addArgs();
} }
public Col2Im(SameDiff sd, SDVariable input, Conv2DConfig config){ public Col2Im(@NonNull SameDiff sd, @NonNull SDVariable input, @NonNull Conv2DConfig config){
super(null, sd, new SDVariable[]{input}); super(null, sd, new SDVariable[]{input});
this.conv2DConfig = config; this.conv2DConfig = config;
addArgs(); addArgs();
@ -61,6 +62,13 @@ public class Col2Im extends DynamicCustomOp {
public Col2Im() {} public Col2Im() {}
public Col2Im(@NonNull INDArray in, @NonNull Conv2DConfig conv2DConfig) {
super("col2Im",in,null,null,null);
this.conv2DConfig = conv2DConfig;
}
protected void addArgs() { protected void addArgs() {
addIArgument(conv2DConfig.getSH()); addIArgument(conv2DConfig.getSH());
addIArgument(conv2DConfig.getSW()); addIArgument(conv2DConfig.getSW());

View File

@ -64,6 +64,14 @@ public class Conv1D extends DynamicCustomOp {
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
} }
public Conv1D( @NonNull INDArray input, @NonNull INDArray weights, INDArray bias, Conv1DConfig conv1DConfig) {
this(wrapFilterNull(input, weights, bias), null, conv1DConfig);
}
public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, Conv1DConfig conv1DConfig) {
this(new INDArray[]{input, weights}, null, conv1DConfig);
}
private void initConfig(Conv1DConfig config){ private void initConfig(Conv1DConfig config){
this.config = config; this.config = config;
Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP()); Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP());

View File

@ -75,6 +75,14 @@ public class Conv2D extends DynamicCustomOp {
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
} }
public Conv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, @NonNull Conv2DConfig conv2DConfig) {
this(new INDArray[]{layerInput, weights}, null, conv2DConfig);
}
public Conv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, INDArray bias, @NonNull Conv2DConfig conv2DConfig) {
this(wrapFilterNull(layerInput, weights,bias), null, conv2DConfig);
}
protected void initConfig(Conv2DConfig config){ protected void initConfig(Conv2DConfig config){
this.config = config; this.config = config;

View File

@ -70,6 +70,14 @@ public class Conv3D extends DynamicCustomOp {
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
} }
public Conv3D(@NonNull INDArray input,@NonNull INDArray weights, @NonNull Conv3DConfig conv3DConfig) {
this(new INDArray[]{input, weights}, null, conv3DConfig);
}
public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull Conv3DConfig conv3DConfig) {
this(wrapFilterNull(input, weights, bias) , null, conv3DConfig);
}
private void initConfig(Conv3DConfig config){ private void initConfig(Conv3DConfig config){
this.config = config; this.config = config;
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1, Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,

View File

@ -73,6 +73,15 @@ public class DeConv2D extends DynamicCustomOp {
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
} }
public DeConv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, DeConv2DConfig deConv2DConfig) {
this(wrapFilterNull(layerInput, weights), null, deConv2DConfig);
}
public DeConv2D(INDArray layerInput, INDArray weights, INDArray bias, DeConv2DConfig deConv2DConfig) {
this(wrapFilterNull(layerInput, weights, bias), null, deConv2DConfig);
}
@Override @Override
public long[] iArgs() { public long[] iArgs() {
if (iArguments.size() == 0) if (iArguments.size() == 0)

View File

@ -48,7 +48,7 @@ public class DeConv3D extends DynamicCustomOp {
protected DeConv3DConfig config; protected DeConv3DConfig config;
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { public DeConv3D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
super(sameDiff, toArr(input, weights, bias)); super(sameDiff, toArr(input, weights, bias));
this.config = config; this.config = config;
addArgs(); addArgs();
@ -65,6 +65,14 @@ public class DeConv3D extends DynamicCustomOp {
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
} }
public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, @NonNull DeConv3DConfig deConv3DConfig) {
this(new INDArray[]{input, weights}, null, deConv3DConfig);
}
public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull DeConv3DConfig deConv3DConfig) {
this(wrapFilterNull(input, weights, bias), null, deConv3DConfig);
}
private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){ private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){
if(bias != null){ if(bias != null){
return new SDVariable[]{input, weights, bias}; return new SDVariable[]{input, weights, bias};

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.NonNull;
import lombok.val; import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -24,6 +25,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.enums.DataFormat;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
@ -58,7 +60,7 @@ public class DepthToSpace extends DynamicCustomOp {
addIArgument(blockSize, isNHWC ? 1 : 0); addIArgument(blockSize, isNHWC ? 1 : 0);
} }
public DepthToSpace(INDArray in, INDArray out, int blockSize, String dataFormat) { public DepthToSpace(@NonNull INDArray in, INDArray out, int blockSize, @NonNull String dataFormat) {
super(null, in, out, null, null); super(null, in, out, null, null);
this.blockSize = blockSize; this.blockSize = blockSize;
this.dataFormat = dataFormat; this.dataFormat = dataFormat;
@ -66,6 +68,10 @@ public class DepthToSpace extends DynamicCustomOp {
addIArgument(blockSize, isNHWC ? 1 : 0); addIArgument(blockSize, isNHWC ? 1 : 0);
} }
public DepthToSpace(@NonNull INDArray x, int blockSize, DataFormat dataFormat) {
this(x, null, blockSize, dataFormat.toString());
}
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {

View File

@ -16,11 +16,8 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.*;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -52,6 +49,7 @@ import java.util.*;
*/ */
@Slf4j @Slf4j
@Getter @Getter
@NoArgsConstructor
public class DepthwiseConv2D extends DynamicCustomOp { public class DepthwiseConv2D extends DynamicCustomOp {
protected Conv2DConfig config; protected Conv2DConfig config;
@ -77,7 +75,16 @@ public class DepthwiseConv2D extends DynamicCustomOp {
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
} }
public DepthwiseConv2D() { public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, Conv2DConfig conv2DConfig) {
this(wrapFilterNull(layerInput, depthWeights), null, conv2DConfig);
}
public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, INDArray bias, Conv2DConfig conv2DConfig) {
this(wrapFilterNull(layerInput, depthWeights, bias), null, conv2DConfig);
}
public DepthwiseConv2D(INDArray inputs, Conv2DConfig conv2DConfig) {
this(wrapFilterNull(inputs), null, conv2DConfig);
} }
@Override @Override

View File

@ -58,6 +58,13 @@ public class Im2col extends DynamicCustomOp {
public Im2col() {} public Im2col() {}
public Im2col(INDArray in, Conv2DConfig conv2DConfig) {
super("im2Col",in,null,null,null);
this.conv2DConfig = conv2DConfig;
addArgs();
}
protected void addArgs() { protected void addArgs() {
addIArgument(conv2DConfig.getKH()); addIArgument(conv2DConfig.getKH());
addIArgument(conv2DConfig.getKW()); addIArgument(conv2DConfig.getKW());
@ -68,7 +75,6 @@ public class Im2col extends DynamicCustomOp {
addIArgument(conv2DConfig.getDH()); addIArgument(conv2DConfig.getDH());
addIArgument(conv2DConfig.getDW()); addIArgument(conv2DConfig.getDW());
addIArgument(ArrayUtil.fromBoolean(conv2DConfig.isSameMode())); addIArgument(ArrayUtil.fromBoolean(conv2DConfig.isSameMode()));
} }

View File

@ -65,6 +65,13 @@ public class LocalResponseNormalization extends DynamicCustomOp {
addArgs(); addArgs();
} }
public LocalResponseNormalization(@NonNull INDArray input, @NonNull LocalResponseNormalizationConfig LocalResponseNormalizationConfig){
super(new INDArray[]{input}, null);
this.config = config;
addArgs();
}
@Override @Override
public Map<String, Object> propertiesForFunction() { public Map<String, Object> propertiesForFunction() {

View File

@ -60,14 +60,17 @@ public class MaxPooling2D extends DynamicCustomOp {
addArgs(); addArgs();
} }
public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){ public MaxPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
super(null, new INDArray[]{input}, wrapOrNull(output)); super(null, new INDArray[]{input}, wrapOrNull(output));
config.setType(Pooling2D.Pooling2DType.MAX); config.setType(Pooling2D.Pooling2DType.MAX);
this.config = config; this.config = config;
addArgs(); addArgs();
} }
public MaxPooling2D(@NonNull INDArray input, @NonNull Pooling2DConfig pooling2DConfig) {
this(input, null, pooling2DConfig);
}
@Override @Override
public boolean isConfigProperties() { public boolean isConfigProperties() {
return true; return true;

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import onnx.Onnx; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -38,9 +39,9 @@ import java.util.Map;
*/ */
@Slf4j @Slf4j
@Getter @Getter
@NoArgsConstructor
public class MaxPooling3D extends Pooling3D { public class MaxPooling3D extends Pooling3D {
public MaxPooling3D() {
}
public MaxPooling3D(SameDiff sameDiff, SDVariable input, Pooling3DConfig config) { public MaxPooling3D(SameDiff sameDiff, SDVariable input, Pooling3DConfig config) {
super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.MAX); super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.MAX);
@ -50,6 +51,10 @@ public class MaxPooling3D extends Pooling3D {
super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.MAX); super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.MAX);
} }
public MaxPooling3D(INDArray input, Pooling3DConfig pooling3DConfig) {
super(null, null, new INDArray[]{input},null, false, pooling3DConfig, Pooling3DType.MAX);
}
@Override @Override
public boolean isConfigProperties() { public boolean isConfigProperties() {
return true; return true;

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -37,6 +38,7 @@ import java.util.*;
* Pooling3D operation * Pooling3D operation
*/ */
@Slf4j @Slf4j
@NoArgsConstructor
public abstract class Pooling3D extends DynamicCustomOp { public abstract class Pooling3D extends DynamicCustomOp {
protected Pooling3DConfig config; protected Pooling3DConfig config;
@ -52,8 +54,6 @@ public abstract class Pooling3D extends DynamicCustomOp {
return super.iArgs(); return super.iArgs();
} }
public Pooling3D() {}
public Pooling3D(SameDiff sameDiff, SDVariable[] inputs,INDArray[] inputArrays, INDArray[] outputs,boolean inPlace, public Pooling3D(SameDiff sameDiff, SDVariable[] inputs,INDArray[] inputArrays, INDArray[] outputs,boolean inPlace,
Pooling3DConfig pooling3DConfig, Pooling3DType type) { Pooling3DConfig pooling3DConfig, Pooling3DType type) {
super(null,sameDiff, inputs, inPlace); super(null,sameDiff, inputs, inPlace);

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -33,6 +34,7 @@ import java.util.List;
* Pooling3DDerivative operation * Pooling3DDerivative operation
*/ */
@Slf4j @Slf4j
@NoArgsConstructor
public class Pooling3DDerivative extends Pooling3D { public class Pooling3DDerivative extends Pooling3D {
@Builder(builderMethodName = "derivativeBuilder") @Builder(builderMethodName = "derivativeBuilder")
@ -41,9 +43,6 @@ public class Pooling3DDerivative extends Pooling3D {
super(sameDiff, inputs, inputArrays, outputs, inPlace, pooling3DConfig, type); super(sameDiff, inputs, inputArrays, outputs, inPlace, pooling3DConfig, type);
} }
public Pooling3DDerivative() {}
@Override @Override
public String opName() { public String opName() {

View File

@ -48,12 +48,18 @@ public class SConv2D extends Conv2D {
super(inputs, outputs, config); super(inputs, outputs, config);
} }
public SConv2D(@NonNull INDArray input, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){ public SConv2D(@NonNull INDArray layerInput, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, @NonNull Conv2DConfig Conv2DConfig){
this(wrapFilterNull(input, depthWeights, pointWeights, bias), wrapOrNull(output), config); this(wrapFilterNull(layerInput, depthWeights, pointWeights, bias), null, Conv2DConfig);
}
public SConv2D(@NonNull INDArray layerInput, @NonNull INDArray depthWeights, INDArray pointWeights, @NonNull Conv2DConfig Conv2DConfig){
this(wrapFilterNull(layerInput, depthWeights, pointWeights), null, Conv2DConfig);
} }
public SConv2D() {} public SConv2D() {}
@Override @Override
public String opName() { public String opName() {
return "sconv2d"; return "sconv2d";

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.NonNull;
import lombok.val; import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -24,6 +25,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.enums.DataFormat;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
@ -65,6 +67,13 @@ public class SpaceToDepth extends DynamicCustomOp {
addIArgument(blockSize, isNHWC ? 1 : 0); addIArgument(blockSize, isNHWC ? 1 : 0);
} }
public SpaceToDepth(@NonNull INDArray x, int blockSize, @NonNull DataFormat dataFormat) {
this(x, null, blockSize,dataFormat.toString());
}
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {
// Gradient to SpaceToDepth is just DepthToSpace of same block size and data format. // Gradient to SpaceToDepth is just DepthToSpace of same block size and data format.

View File

@ -17,12 +17,15 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections; import java.util.Collections;
@ -34,6 +37,7 @@ import java.util.List;
*/ */
@Slf4j @Slf4j
@Getter @Getter
@NoArgsConstructor
public class Upsampling2d extends DynamicCustomOp { public class Upsampling2d extends DynamicCustomOp {
@ -53,7 +57,20 @@ public class Upsampling2d extends DynamicCustomOp {
} }
public Upsampling2d() {} public Upsampling2d(INDArray input, int scale) {
this(input, scale, scale, true);
}
public Upsampling2d(INDArray input, int scaleH, int scaleW, boolean nchw) {
super(new INDArray[]{input}, null);
this.nchw = nchw;
this.scaleH = scaleH;
this.scaleW = scaleW;
addIArgument(scaleH);
addIArgument(scaleW);
addIArgument(nchw ? 1 : 0);
}
@Override @Override

View File

@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
@ -46,6 +47,12 @@ public class GRUCell extends DynamicCustomOp {
this.weights = weights; this.weights = weights;
} }
public GRUCell(INDArray x, INDArray hLast, GRUWeights gruWeights) {
super(null, null, gruWeights.argsWithInputs(x, hLast));
this.weights = gruWeights;
}
@Override @Override
public String opName() { public String opName() {
return "gruCell"; return "gruCell";

View File

@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
@ -89,6 +90,15 @@ public class LSTMBlockCell extends DynamicCustomOp {
addTArgument(configuration.tArgs()); addTArgument(configuration.tArgs());
} }
public LSTMBlockCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) {
super(null, null, lstmWeights.argsWithInputs(x, cLast, yLast));
this.configuration = lstmConfiguration;
this.weights = lstmWeights;
addIArgument(configuration.iArgs(false));
addTArgument(configuration.tArgs());
}
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 8, "Expected exactly 8 inputs to LSTMBlockCell, got %s", inputDataTypes); Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 8, "Expected exactly 8 inputs to LSTMBlockCell, got %s", inputDataTypes);

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent; package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import lombok.NoArgsConstructor;
import onnx.Onnx; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -31,13 +32,11 @@ import java.util.Map;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@NoArgsConstructor
public class LSTMCell extends DynamicCustomOp { public class LSTMCell extends DynamicCustomOp {
private LSTMCellConfiguration configuration; private LSTMCellConfiguration configuration;
public LSTMCell() {
}
public LSTMCell(SameDiff sameDiff, LSTMCellConfiguration configuration) { public LSTMCell(SameDiff sameDiff, LSTMCellConfiguration configuration) {
super(null, sameDiff, configuration.args()); super(null, sameDiff, configuration.args());
this.configuration = configuration; this.configuration = configuration;
@ -66,16 +65,4 @@ public class LSTMCell extends DynamicCustomOp {
public String tensorflowName() { public String tensorflowName() {
return super.tensorflowName(); return super.tensorflowName();
} }
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
}
@Override
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
} }

View File

@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
@ -91,6 +92,14 @@ public class LSTMLayer extends DynamicCustomOp {
addTArgument(configuration.tArgs()); addTArgument(configuration.tArgs());
} }
public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) {
super(null, null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast));
this.configuration = lstmConfiguration;
this.weights = lstmWeights;
addIArgument(configuration.iArgs(true));
addTArgument(configuration.tArgs());
}
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 9, "Expected exactly 9 inputs to LSTMLayer, got %s", inputDataTypes); Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 9, "Expected exactly 9 inputs to LSTMLayer, got %s", inputDataTypes);

View File

@ -16,14 +16,14 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent; package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import java.util.Arrays;
import java.util.List;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull; import lombok.NonNull;
import onnx.Onnx; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
@ -37,6 +37,7 @@ import java.util.Map;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@NoArgsConstructor
public class SRU extends DynamicCustomOp { public class SRU extends DynamicCustomOp {
@Getter @Getter
@ -45,14 +46,23 @@ public class SRU extends DynamicCustomOp {
@Getter @Getter
private SDVariable mask; private SDVariable mask;
public SRU() { }
public SRU(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) { public SRU(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
super(null, sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getBias(), initialC, mask)); super(null, sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getBias(), initialC, mask));
this.mask = mask; this.mask = mask;
this.weights = weights; this.weights = weights;
} }
public SRU(INDArray x, INDArray initialC, INDArray mask, SRUWeights sruWeights) {
super(wrapFilterNull(x, sruWeights.getIWeights(), sruWeights.getIBias(), initialC, mask), null);
this.mask = (SDVariable) mask;
this.weights = sruWeights;
}
public SRU(INDArray x, INDArray initialC, SRUWeights sruWeights) {
super(wrapFilterNull(x, sruWeights.getIWeights(), sruWeights.getIBias(), initialC), null);
this.weights = sruWeights;
}
@Override @Override
public String opName() { public String opName() {
return "sru"; return "sru";

View File

@ -22,6 +22,7 @@ import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
@ -46,6 +47,14 @@ public class SRUCell extends DynamicCustomOp {
this.weights = weights; this.weights = weights;
} }
public SRUCell(INDArray x, INDArray cLast, SRUWeights sruWeights) {
super(null, null, sruWeights.argsWithInputs(x, cLast));
this.weights = sruWeights;
}
@Override @Override
public String opName() { public String opName() {
return "sruCell"; return "sruCell";

View File

@ -5,6 +5,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell; import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
/** /**
@ -21,31 +22,36 @@ public class GRUWeights extends RNNWeights {
* *
* The reset weights are the [:, 0:numUnits] subset and the update weights are the [:, numUnits:2*numUnits] subset. * The reset weights are the [:, 0:numUnits] subset and the update weights are the [:, numUnits:2*numUnits] subset.
*/ */
@NonNull
private SDVariable ruWeight; private SDVariable ruWeight;
private INDArray iRuWeights;
/** /**
* Cell gate weights, with a shape of [inSize + numUnits, numUnits] * Cell gate weights, with a shape of [inSize + numUnits, numUnits]
*/ */
@NonNull
private SDVariable cWeight; private SDVariable cWeight;
private INDArray iCWeight;
/** /**
* Reset and Update gate bias, with a shape of [2*numUnits]. May be null. * Reset and Update gate bias, with a shape of [2*numUnits]. May be null.
* *
* The reset bias is the [0:numUnits] subset and the update bias is the [numUnits:2*numUnits] subset. * The reset bias is the [0:numUnits] subset and the update bias is the [numUnits:2*numUnits] subset.
*/ */
@NonNull
private SDVariable ruBias; private SDVariable ruBias;
private INDArray iRUBias;
/** /**
* Cell gate bias, with a shape of [numUnits]. May be null. * Cell gate bias, with a shape of [numUnits]. May be null.
*/ */
@NonNull
private SDVariable cBias; private SDVariable cBias;
private INDArray iCBias;
@Override @Override
public SDVariable[] args() { public SDVariable[] args() {
return filterNonNull(ruWeight, cWeight, ruBias, cBias); return filterNonNull(ruWeight, cWeight, ruBias, cBias);
} }
@Override
public INDArray[] arrayArgs() {
return filterNonNull(iRuWeights, iCWeight, iRUBias, iCBias);
}
} }

View File

@ -5,6 +5,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
@ -23,35 +24,40 @@ public class LSTMWeights extends RNNWeights {
* Input to hidden and hidden to hidden are concatenated in dimension 0, * Input to hidden and hidden to hidden are concatenated in dimension 0,
* so the input to hidden weights are [:inSize, :] and the hidden to hidden weights are [inSize:, :]. * so the input to hidden weights are [:inSize, :] and the hidden to hidden weights are [inSize:, :].
*/ */
@NonNull
private SDVariable weights; private SDVariable weights;
private INDArray iWeights;
/** /**
* Cell peephole (t-1) connections to input modulation gate, with a shape of [numUnits]. * Cell peephole (t-1) connections to input modulation gate, with a shape of [numUnits].
*/ */
@NonNull
private SDVariable inputPeepholeWeights; private SDVariable inputPeepholeWeights;
private INDArray iInputPeepholeWeights;
/** /**
* Cell peephole (t-1) connections to forget gate, with a shape of [numUnits]. * Cell peephole (t-1) connections to forget gate, with a shape of [numUnits].
*/ */
@NonNull
private SDVariable forgetPeepholeWeights; private SDVariable forgetPeepholeWeights;
private INDArray iForgetPeepholeWeights;
/** /**
* Cell peephole (t) connections to output gate, with a shape of [numUnits]. * Cell peephole (t) connections to output gate, with a shape of [numUnits].
*/ */
@NonNull
private SDVariable outputPeepholeWeights; private SDVariable outputPeepholeWeights;
private INDArray iOutputPeepholeWeights;
/** /**
* Input to hidden and hidden to hidden biases, with shape [1, 4*numUnits]. * Input to hidden and hidden to hidden biases, with shape [1, 4*numUnits].
*/ */
@NonNull
private SDVariable bias; private SDVariable bias;
private INDArray iBias;
@Override @Override
public SDVariable[] args() { public SDVariable[] args() {
return filterNonNull(weights, inputPeepholeWeights, forgetPeepholeWeights, outputPeepholeWeights, bias); return filterNonNull(weights, inputPeepholeWeights, forgetPeepholeWeights, outputPeepholeWeights, bias);
} }
@Override
public INDArray[] arrayArgs() {
return filterNonNull(iWeights, iInputPeepholeWeights, iForgetPeepholeWeights, iOutputPeepholeWeights, iBias);
}
} }

View File

@ -1,35 +1,38 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights; package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
import java.lang.reflect.Array;
import java.util.Arrays; import java.util.Arrays;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
public abstract class RNNWeights { public abstract class RNNWeights {
public abstract SDVariable[] args(); public abstract SDVariable[] args();
protected static SDVariable[] filterNonNull(SDVariable... args){ public abstract INDArray[] arrayArgs();
protected static <T> T[] filterNonNull(T... args){
int count = 0; int count = 0;
for(SDVariable v : args){ for( int i=0; i<args.length; i++ ) {
if(v != null){ if (args[i] != null) count++;
count++; }
T[] out = (T[]) Array.newInstance(args.getClass().getComponentType(), count);
int j=0;
for( int i=0; i<args.length; i++ ){
if(args[i] != null){
out[j++] = args[i];
} }
} }
return out;
SDVariable[] res = new SDVariable[count];
int i = 0;
for(SDVariable v : args){
if(v != null){
res[i] = v;
i++;
}
}
return res;
} }
public SDVariable[] argsWithInputs(SDVariable... inputs){ public SDVariable[] argsWithInputs(SDVariable... inputs){
return ArrayUtil.combine(inputs, args()); return ArrayUtil.combine(inputs, args());
} }
public INDArray[] argsWithInputs(INDArray... inputs) {
return ArrayUtil.combine(inputs, arrayArgs());
}
} }

View File

@ -5,6 +5,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU; import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell; import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
@ -21,17 +22,24 @@ public class SRUWeights extends RNNWeights {
/** /**
* Weights, with shape [inSize, 3*inSize]. * Weights, with shape [inSize, 3*inSize].
*/ */
@NonNull
private SDVariable weights; private SDVariable weights;
private INDArray iWeights;
/** /**
* Biases, with shape [2*inSize]. * Biases, with shape [2*inSize].
*/ */
@NonNull
private SDVariable bias; private SDVariable bias;
private INDArray iBias;
@Override @Override
public SDVariable[] args() { public SDVariable[] args() {
return new SDVariable[]{weights, bias}; return new SDVariable[]{weights, bias};
} }
@Override
public INDArray[] arrayArgs() {
return new INDArray[]{iWeights, iBias};
}
} }

View File

@ -21,6 +21,7 @@ import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -62,6 +63,18 @@ public class BatchToSpace extends DynamicCustomOp {
addIArgument(b); addIArgument(b);
} }
public BatchToSpace(INDArray x, int[] blocks, int[] croppingTop, int[] croppingBottom) {
super(null,x,null,null,null);
this.blocks = blocks;
this.crops = new int[][]{croppingTop,croppingBottom};
for (val b : blocks)
addIArgument(b);
for (int e = 0; e < crops.length; e++)
addIArgument(crops[e][0], crops[e][1]);
}
@Override @Override
public String opName() { public String opName() {
return "batch_to_space"; return "batch_to_space";

View File

@ -16,6 +16,8 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom; package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.val; import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -40,6 +42,7 @@ import java.util.*;
* *
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
@NoArgsConstructor
public class Dilation2D extends DynamicCustomOp { public class Dilation2D extends DynamicCustomOp {
protected boolean isSameMode; protected boolean isSameMode;
@ -49,18 +52,11 @@ public class Dilation2D extends DynamicCustomOp {
// strides // strides
protected int s0, s1, s2, s3; protected int s0, s1, s2, s3;
public Dilation2D() {
}
public Dilation2D(SameDiff sameDiff, SDVariable[] inputAndWeights, int[] strides, public Dilation2D(SameDiff sameDiff, SDVariable[] inputAndWeights, int[] strides,
int[] rates, boolean isSameMode, boolean inPlace ) { int[] rates, boolean isSameMode, boolean inPlace ) {
super(null, sameDiff, inputAndWeights, inPlace); super(null, sameDiff, inputAndWeights, inPlace);
Preconditions.checkArgument(rates.length == 4, "Dilation rate length must be 4, got an array with length %s with values %s", rates.length, rates);
if (rates.length < 4) Preconditions.checkArgument(strides.length == 4, "Dilation strides length must be 4, got an array with length %s with values %s", strides.length, strides);
throw new IllegalArgumentException("Dilation rate length must be 4.");
if (strides.length < 4)
throw new IllegalArgumentException("Strides length must be 4.");
r0 = rates[0]; r0 = rates[0];
r1 = rates[1]; r1 = rates[1];
@ -73,14 +69,31 @@ public class Dilation2D extends DynamicCustomOp {
this.isSameMode = isSameMode; this.isSameMode = isSameMode;
addArgs(); addArgs();
} }
public Dilation2D(INDArray[] inputArrays, INDArray[] outputs) { public Dilation2D(INDArray[] inputArrays, INDArray[] outputs) {
super(null, inputArrays, outputs); super(null, inputArrays, outputs);
} }
public Dilation2D(@NonNull INDArray df, @NonNull INDArray weights, @NonNull int[] strides, @NonNull int[] rates, boolean isSameMode) {
super(null, new INDArray[]{df, weights},null);
Preconditions.checkArgument(rates.length == 4, "Dilation rate length must be 4, got an array with length %s with values %s", rates.length, rates);
Preconditions.checkArgument(strides.length == 4, "Dilation strides length must be 4, got an array with length %s with values %s", strides.length, strides);
this.isSameMode = isSameMode;
r0 = rates[0];
r1 = rates[1];
r2 = rates[2];
r3 = rates[3];
s0 = strides[0];
s1 = strides[1];
s2 = strides[2];
s3 = strides[3];
addArgs();
}
protected void addArgs() { protected void addArgs() {
addIArgument(isSameMode ? 1 : 0, addIArgument(isSameMode ? 1 : 0,
r0, r1, r2, r3, r0, r1, r2, r3,

View File

@ -21,6 +21,7 @@ import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -62,6 +63,19 @@ public class SpaceToBatch extends DynamicCustomOp {
addIArgument(blocks[0]); addIArgument(blocks[0]);
} }
public SpaceToBatch(INDArray x, int[] blocks, int[] paddingTop, int[] paddingBottom) {
super(null,x,null,null,null);
this.blocks = blocks;
this.padding = new int[][]{paddingTop,paddingBottom};
for (val b : blocks)
addIArgument(b);
for (int e = 0; e < padding.length; e++)
addIArgument(padding[e][0], padding[e][1]);
}
@Override @Override
public String opName() { public String opName() {
return "space_to_batch"; return "space_to_batch";

View File

@ -136,6 +136,21 @@ public class Nd4j {
*/ */
public static final NDLoss loss = new NDLoss(); public static final NDLoss loss = new NDLoss();
/**
* Convolutional network namespace - operations related to convolutional neural networks
*/
public static final NDCNN cnn = new NDCNN();
/**
* Recurrent neural network namespace - operations related to recurrent neural networks
*/
public static final NDRNN rnn = new NDRNN();
/**
* Image namespace - operations related to images
*/
public static final NDImage image = new NDImage();
/** /**
* Bitwise namespace - operations related to bitwise manipulation of arrays * Bitwise namespace - operations related to bitwise manipulation of arrays
*/ */
@ -169,6 +184,27 @@ public class Nd4j {
*/ */
public static NDLoss loss(){ return loss; } public static NDLoss loss(){ return loss; }
/**
* Convolutional network namespace - operations related to convolutional neural networks
*/
public static NDCNN cnn(){
return cnn;
}
/**
* Recurrent neural network namespace - operations related to recurrent neural networks
*/
public static NDRNN rnn(){
return rnn;
}
/**
* Image namespace - operations related to images
*/
public static NDImage image(){
return image;
}
private final static String DATA_BUFFER_OPS = "databufferfactory"; private final static String DATA_BUFFER_OPS = "databufferfactory";
private final static String CONVOLUTION_OPS = "convops"; private final static String CONVOLUTION_OPS = "convops";
/**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/ /**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/

View File

@ -0,0 +1,27 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.enums;
/**
* Data format: "NCHW" or "NHWC" */
public enum DataFormat {
NCHW,
NHWC
}

View File

@ -0,0 +1,499 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.enums.DataFormat;
public class NDCNN {
public NDCNN() {
}
/**
* 2D Convolution layer operation - average pooling 2d<br>
*
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
* @return output Result after applying average pooling on the input (NUMERIC type)
*/
public INDArray avgPooling2d(INDArray input, Pooling2DConfig Pooling2DConfig) {
NDValidation.validateNumerical("avgPooling2d", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D(input, Pooling2DConfig))[0];
}
/**
* 3D convolution layer operation - average pooling 3d <br>
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param Pooling3DConfig Configuration Object
* @return output after applying average pooling on the input (NUMERIC type)
*/
public INDArray avgPooling3d(INDArray input, Pooling3DConfig Pooling3DConfig) {
NDValidation.validateNumerical("avgPooling3d", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D(input, Pooling3DConfig))[0];
}
/**
* Convolution 2d layer batch to space operation on 4d input.<br>
* Reduces input batch dimension by rearranging data into a larger spatial dimensions<br>
*
* @param x Input variable. 4d input (NUMERIC type)
* @param blocks Block size, in the height/width dimension (Size: Exactly(count=2))
* @param croppingTop (Size: Exactly(count=2))
* @param croppingBottom (Size: Exactly(count=2))
* @return output Output variable (NUMERIC type)
*/
public INDArray batchToSpace(INDArray x, int[] blocks, int[] croppingTop, int... croppingBottom) {
NDValidation.validateNumerical("batchToSpace", "x", x);
Preconditions.checkArgument(blocks.length == 2, "blocks has incorrect size/length. Expected: blocks.length == 2, got %s", blocks.length);
Preconditions.checkArgument(croppingTop.length == 2, "croppingTop has incorrect size/length. Expected: croppingTop.length == 2, got %s", croppingTop.length);
Preconditions.checkArgument(croppingBottom.length == 2, "croppingBottom has incorrect size/length. Expected: croppingBottom.length == 2, got %s", croppingBottom.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace(x, blocks, croppingTop, croppingBottom))[0];
}
/**
* col2im operation for use in 2D convolution operations. Outputs a 4d array with shape<br>
* [minibatch, inputChannels, height, width]<br>
*
* @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] (NUMERIC type)
* @param Conv2DConfig Configuration Object
* @return output Col2Im output variable (NUMERIC type)
*/
public INDArray col2Im(INDArray in, Conv2DConfig Conv2DConfig) {
NDValidation.validateNumerical("col2Im", "in", in);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im(in, Conv2DConfig))[0];
}
/**
* Conv1d operation.<br>
*
* @param input the inputs to conv1d (NUMERIC type)
* @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] (NUMERIC type)
* @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null. (NUMERIC type)
* @param Conv1DConfig Configuration Object
* @return output result of conv1d op (NUMERIC type)
*/
public INDArray conv1d(INDArray input, INDArray weights, INDArray bias,
Conv1DConfig Conv1DConfig) {
NDValidation.validateNumerical("conv1d", "input", input);
NDValidation.validateNumerical("conv1d", "weights", weights);
NDValidation.validateNumerical("conv1d", "bias", bias);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D(input, weights, bias, Conv1DConfig))[0];
}
/**
* Conv1d operation.<br>
*
* @param input the inputs to conv1d (NUMERIC type)
* @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] (NUMERIC type)
* @param Conv1DConfig Configuration Object
* @return output result of conv1d op (NUMERIC type)
*/
public INDArray conv1d(INDArray input, INDArray weights, Conv1DConfig Conv1DConfig) {
NDValidation.validateNumerical("conv1d", "input", input);
NDValidation.validateNumerical("conv1d", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D(input, weights, null, Conv1DConfig))[0];
}
/**
* 2D Convolution operation with optional bias<br>
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type)
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] (NUMERIC type)
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
* @param Conv2DConfig Configuration Object
* @return output result of conv2d op (NUMERIC type)
*/
public INDArray conv2d(INDArray layerInput, INDArray weights, INDArray bias,
Conv2DConfig Conv2DConfig) {
NDValidation.validateNumerical("conv2d", "layerInput", layerInput);
NDValidation.validateNumerical("conv2d", "weights", weights);
NDValidation.validateNumerical("conv2d", "bias", bias);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D(layerInput, weights, bias, Conv2DConfig))[0];
}
/**
* 2D Convolution operation with optional bias<br>
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type)
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] (NUMERIC type)
* @param Conv2DConfig Configuration Object
* @return output result of conv2d op (NUMERIC type)
*/
public INDArray conv2d(INDArray layerInput, INDArray weights, Conv2DConfig Conv2DConfig) {
NDValidation.validateNumerical("conv2d", "layerInput", layerInput);
NDValidation.validateNumerical("conv2d", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D(layerInput, weights, null, Conv2DConfig))[0];
}
/**
* Convolution 3D operation with optional bias <br>
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
* @param Conv3DConfig Configuration Object
* @return output Conv3d output variable (NUMERIC type)
*/
public INDArray conv3d(INDArray input, INDArray weights, INDArray bias,
Conv3DConfig Conv3DConfig) {
NDValidation.validateNumerical("conv3d", "input", input);
NDValidation.validateNumerical("conv3d", "weights", weights);
NDValidation.validateNumerical("conv3d", "bias", bias);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D(input, weights, bias, Conv3DConfig))[0];
}
/**
* Convolution 3D operation with optional bias <br>
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
* @param Conv3DConfig Configuration Object
* @return output Conv3d output variable (NUMERIC type)
*/
public INDArray conv3d(INDArray input, INDArray weights, Conv3DConfig Conv3DConfig) {
NDValidation.validateNumerical("conv3d", "input", input);
NDValidation.validateNumerical("conv3d", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D(input, weights, null, Conv3DConfig))[0];
}
/**
* 2D deconvolution operation with optional bias<br>
*
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
* @param DeConv2DConfig Configuration Object
* @return output result of deconv2d op (NUMERIC type)
*/
public INDArray deconv2d(INDArray layerInput, INDArray weights, INDArray bias,
DeConv2DConfig DeConv2DConfig) {
NDValidation.validateNumerical("deconv2d", "layerInput", layerInput);
NDValidation.validateNumerical("deconv2d", "weights", weights);
NDValidation.validateNumerical("deconv2d", "bias", bias);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D(layerInput, weights, bias, DeConv2DConfig))[0];
}
/**
* 2D deconvolution operation with optional bias<br>
*
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
* @param DeConv2DConfig Configuration Object
* @return output result of deconv2d op (NUMERIC type)
*/
public INDArray deconv2d(INDArray layerInput, INDArray weights, DeConv2DConfig DeConv2DConfig) {
NDValidation.validateNumerical("deconv2d", "layerInput", layerInput);
NDValidation.validateNumerical("deconv2d", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D(layerInput, weights, null, DeConv2DConfig))[0];
}
/**
* 3D CNN deconvolution operation with or without optional bias<br>
*
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) (NUMERIC type)
* @param weights Weights array - shape [kD, kH, kW, oC, iC] (NUMERIC type)
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] (NUMERIC type)
* @param DeConv3DConfig Configuration Object
* @return output result of 3D CNN deconvolution operation (NUMERIC type)
*/
public INDArray deconv3d(INDArray input, INDArray weights, INDArray bias,
DeConv3DConfig DeConv3DConfig) {
NDValidation.validateNumerical("deconv3d", "input", input);
NDValidation.validateNumerical("deconv3d", "weights", weights);
NDValidation.validateNumerical("deconv3d", "bias", bias);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D(input, weights, bias, DeConv3DConfig))[0];
}
/**
* 3D CNN deconvolution operation with or without optional bias<br>
*
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) (NUMERIC type)
* @param weights Weights array - shape [kD, kH, kW, oC, iC] (NUMERIC type)
* @param DeConv3DConfig Configuration Object
* @return output result of 3D CNN deconvolution operation (NUMERIC type)
*/
public INDArray deconv3d(INDArray input, INDArray weights, DeConv3DConfig DeConv3DConfig) {
NDValidation.validateNumerical("deconv3d", "input", input);
NDValidation.validateNumerical("deconv3d", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D(input, weights, null, DeConv3DConfig))[0];
}
/**
* Convolution 2d layer batch to space operation on 4d input.<br>
* Reduces input channels dimension by rearranging data into a larger spatial dimensions<br>
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
* = [mb, 2, 4, 4]<br>
*
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param blockSize Block size, in the height/width dimension
* @param dataFormat Data format: "NCHW" or "NHWC"
* @return output Output variable (NUMERIC type)
*/
public INDArray depthToSpace(INDArray x, int blockSize, DataFormat dataFormat) {
NDValidation.validateNumerical("depthToSpace", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace(x, blockSize, dataFormat))[0];
}
/**
* Depth-wise 2D convolution operation with optional bias <br>
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type)
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
* @param Conv2DConfig Configuration Object
* @return output result of depthwise conv2d op (NUMERIC type)
*/
public INDArray depthWiseConv2d(INDArray layerInput, INDArray depthWeights, INDArray bias,
Conv2DConfig Conv2DConfig) {
NDValidation.validateNumerical("depthWiseConv2d", "layerInput", layerInput);
NDValidation.validateNumerical("depthWiseConv2d", "depthWeights", depthWeights);
NDValidation.validateNumerical("depthWiseConv2d", "bias", bias);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D(layerInput, depthWeights, bias, Conv2DConfig))[0];
}
/**
* Depth-wise 2D convolution operation with optional bias <br>
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type)
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
* @param Conv2DConfig Configuration Object
* @return output result of depthwise conv2d op (NUMERIC type)
*/
public INDArray depthWiseConv2d(INDArray layerInput, INDArray depthWeights,
Conv2DConfig Conv2DConfig) {
NDValidation.validateNumerical("depthWiseConv2d", "layerInput", layerInput);
NDValidation.validateNumerical("depthWiseConv2d", "depthWeights", depthWeights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D(layerInput, depthWeights, null, Conv2DConfig))[0];
}
/**
* TODO doc string<br>
*
* @param df (NUMERIC type)
* @param weights df (NUMERIC type)
* @param strides weights (Size: Exactly(count=2))
* @param rates strides (Size: Exactly(count=2))
* @param isSameMode isSameMode
* @return output Computed the grayscale dilation of 4-D input and 3-D filters tensors. (NUMERIC type)
*/
public INDArray dilation2D(INDArray df, INDArray weights, int[] strides, int[] rates,
boolean isSameMode) {
NDValidation.validateNumerical("dilation2D", "df", df);
NDValidation.validateNumerical("dilation2D", "weights", weights);
Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length);
Preconditions.checkArgument(rates.length == 2, "rates has incorrect size/length. Expected: rates.length == 2, got %s", rates.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D(df, weights, strides, rates, isSameMode))[0];
}
/**
* Extract image patches <br>
*
* @param input Input array. Must be rank 4, with shape [minibatch, height, width, channels] (NUMERIC type)
* @param kH Kernel height
* @param kW Kernel width
* @param sH Stride height
* @param sW Stride width
* @param rH Rate height
* @param rW Rate width
* @param sameMode If true: use same mode padding. If false
* @return output The result is a 4D tensor which is indexed by batch, row, and column. (NUMERIC type)
*/
public INDArray extractImagePatches(INDArray input, int kH, int kW, int sH, int sW, int rH,
int rW, boolean sameMode) {
NDValidation.validateNumerical("extractImagePatches", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode))[0];
}
/**
* im2col operation for use in 2D convolution operations. Outputs a 6d array with shape<br>
* [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] <br>
*
* @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width] (NUMERIC type)
* @param Conv2DConfig Configuration Object
* @return output Im2Col output variable (NUMERIC type)
*/
public INDArray im2Col(INDArray in, Conv2DConfig Conv2DConfig) {
NDValidation.validateNumerical("im2Col", "in", in);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col(in, Conv2DConfig))[0];
}
/**
* 2D convolution layer operation - local response normalization<br>
*
* @param input the inputs to lrn (NUMERIC type)
* @param LocalResponseNormalizationConfig Configuration Object
* @return output Result after Local Response Normalization (NUMERIC type)
*/
public INDArray localResponseNormalization(INDArray input,
LocalResponseNormalizationConfig LocalResponseNormalizationConfig) {
NDValidation.validateNumerical("localResponseNormalization", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization(input, LocalResponseNormalizationConfig))[0];
}
/**
* 2D Convolution layer operation - max pooling 2d <br>
*
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
* @return output Result after applying max pooling on the input (NUMERIC type)
*/
public INDArray maxPooling2d(INDArray input, Pooling2DConfig Pooling2DConfig) {
NDValidation.validateNumerical("maxPooling2d", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D(input, Pooling2DConfig))[0];
}
/**
* 3D convolution layer operation - max pooling 3d operation.<br>
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param Pooling3DConfig Configuration Object
* @return output Result after applying max pooling on the input (NUMERIC type)
*/
public INDArray maxPooling3d(INDArray input, Pooling3DConfig Pooling3DConfig) {
NDValidation.validateNumerical("maxPooling3d", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling3D(input, Pooling3DConfig))[0];
}
/**
* Separable 2D convolution operation with optional bias <br>
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
* @param Conv2DConfig Configuration Object
* @return output result of separable convolution 2d operation (NUMERIC type)
*/
public INDArray separableConv2d(INDArray layerInput, INDArray depthWeights, INDArray pointWeights,
INDArray bias, Conv2DConfig Conv2DConfig) {
NDValidation.validateNumerical("separableConv2d", "layerInput", layerInput);
NDValidation.validateNumerical("separableConv2d", "depthWeights", depthWeights);
NDValidation.validateNumerical("separableConv2d", "pointWeights", pointWeights);
NDValidation.validateNumerical("separableConv2d", "bias", bias);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D(layerInput, depthWeights, pointWeights, bias, Conv2DConfig))[0];
}
/**
* Separable 2D convolution operation with optional bias <br>
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
* @param Conv2DConfig Configuration Object
* @return output result of separable convolution 2d operation (NUMERIC type)
*/
public INDArray separableConv2d(INDArray layerInput, INDArray depthWeights, INDArray pointWeights,
Conv2DConfig Conv2DConfig) {
NDValidation.validateNumerical("separableConv2d", "layerInput", layerInput);
NDValidation.validateNumerical("separableConv2d", "depthWeights", depthWeights);
NDValidation.validateNumerical("separableConv2d", "pointWeights", pointWeights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D(layerInput, depthWeights, pointWeights, null, Conv2DConfig))[0];
}
/**
* Convolution 2d layer space to batch operation on 4d input.<br>
* Increases input batch dimension by rearranging data from spatial dimensions into batch dimension <br>
*
* @param x Input variable. 4d input (NUMERIC type)
* @param blocks Block size, in the height/width dimension (Size: Exactly(count=2))
* @param paddingTop Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] (Size: Exactly(count=2))
* @param paddingBottom Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] (Size: Exactly(count=2))
* @return output Output variable (NUMERIC type)
*/
public INDArray spaceToBatch(INDArray x, int[] blocks, int[] paddingTop, int... paddingBottom) {
NDValidation.validateNumerical("spaceToBatch", "x", x);
Preconditions.checkArgument(blocks.length == 2, "blocks has incorrect size/length. Expected: blocks.length == 2, got %s", blocks.length);
Preconditions.checkArgument(paddingTop.length == 2, "paddingTop has incorrect size/length. Expected: paddingTop.length == 2, got %s", paddingTop.length);
Preconditions.checkArgument(paddingBottom.length == 2, "paddingBottom has incorrect size/length. Expected: paddingBottom.length == 2, got %s", paddingBottom.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch(x, blocks, paddingTop, paddingBottom))[0];
}
/**
* Convolution 2d layer space to depth operation on 4d input.<br>
* Increases input channels (reduced spatial dimensions) by rearranging data into a larger channels dimension<br>
* Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
* = [mb, 2, 4, 4] <br>
*
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param blockSize Block size, in the height/width dimension
* @param dataFormat Data format: "NCHW" or "NHWC"
* @return output Output variable (NUMERIC type)
*/
public INDArray spaceToDepth(INDArray x, int blockSize, DataFormat dataFormat) {
NDValidation.validateNumerical("spaceToDepth", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth(x, blockSize, dataFormat))[0];
}
/**
* Upsampling layer for 2D inputs.<br>
* scale is used for both height and width dimensions. <br>
*
* @param input Input in NCHW format (NUMERIC type)
* @param scale The scale for both height and width dimensions.
* @return output Upsampled input (NUMERIC type)
*/
public INDArray upsampling2d(INDArray input, int scale) {
NDValidation.validateNumerical("upsampling2d", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(input, scale))[0];
}
/**
* 2D Convolution layer operation - Upsampling 2d <br>
*
* @param input Input in NCHW format (NUMERIC type)
* @param scaleH Scale to upsample in height dimension
* @param scaleW Scale to upsample in width dimension
* @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format
* @return output Upsampled input (NUMERIC type)
*/
public INDArray upsampling2d(INDArray input, int scaleH, int scaleW, boolean nchw) {
NDValidation.validateNumerical("upsampling2d", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(input, scaleH, scaleW, nchw))[0];
}
}

View File

@ -0,0 +1,221 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;
public class NDImage {
public NDImage() {
}
/**
* Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.<br>
*
* @param image Input image, with shape [batch, height, width, channels] (NUMERIC type)
* @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type)
* @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type)
* @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type)
* @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default
* @return output Cropped and resized images (NUMERIC type)
*/
public INDArray cropAndResize(INDArray image, INDArray cropBoxes, INDArray boxIndices,
INDArray cropOutSize, double extrapolationValue) {
NDValidation.validateNumerical("CropAndResize", "image", image);
NDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes);
NDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices);
NDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.CropAndResize(image, cropBoxes, boxIndices, cropOutSize, extrapolationValue))[0];
}
/**
* Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.<br>
*
* @param image Input image, with shape [batch, height, width, channels] (NUMERIC type)
* @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type)
* @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type)
* @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type)
* @return output Cropped and resized images (NUMERIC type)
*/
public INDArray cropAndResize(INDArray image, INDArray cropBoxes, INDArray boxIndices,
INDArray cropOutSize) {
NDValidation.validateNumerical("CropAndResize", "image", image);
NDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes);
NDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices);
NDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.CropAndResize(image, cropBoxes, boxIndices, cropOutSize, 0.0))[0];
}
/**
* Adjusts contrast of RGB or grayscale images.<br>
*
* @param in images to adjust. 3D shape or higher (NUMERIC type)
* @param factor multiplier for adjusting contrast
* @return output Contrast-adjusted image (NUMERIC type)
*/
public INDArray adjustContrast(INDArray in, double factor) {
NDValidation.validateNumerical("adjustContrast", "in", in);
return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.AdjustContrast(in, factor))[0];
}
/**
* Adjust hue of RGB image <br>
*
* @param in image as 3D array (NUMERIC type)
* @param delta value to add to hue channel
* @return output adjusted image (NUMERIC type)
*/
public INDArray adjustHue(INDArray in, double delta) {
NDValidation.validateNumerical("adjustHue", "in", in);
return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.AdjustHue(in, delta))[0];
}
/**
* Adjust saturation of RGB images<br>
*
* @param in RGB image as 3D array (NUMERIC type)
* @param factor factor for saturation
* @return output adjusted image (NUMERIC type)
*/
public INDArray adjustSaturation(INDArray in, double factor) {
NDValidation.validateNumerical("adjustSaturation", "in", in);
return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.AdjustSaturation(in, factor))[0];
}
/**
* Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension. <br>
*
* @param image Input image to extract image patches from - shape [batch, height, width, channels] (NUMERIC type)
* @param kSizes Kernel size - size of the image patches, [height, width] (Size: Exactly(count=2))
* @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] (Size: Exactly(count=2))
* @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels
* in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken
* along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension (Size: AtLeast(min=0))
* @param sameMode Padding algorithm. If true: use Same padding
* @return output The extracted image patches (NUMERIC type)
*/
public INDArray extractImagePatches(INDArray image, int[] kSizes, int[] strides, int[] rates,
boolean sameMode) {
NDValidation.validateNumerical("extractImagePatches", "image", image);
Preconditions.checkArgument(kSizes.length == 2, "kSizes has incorrect size/length. Expected: kSizes.length == 2, got %s", kSizes.length);
Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length);
Preconditions.checkArgument(rates.length >= 0, "rates has incorrect size/length. Expected: rates.length >= 0, got %s", rates.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(image, kSizes, strides, rates, sameMode))[0];
}
/**
* Converting image from HSV to RGB format <br>
*
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public INDArray hsvToRgb(INDArray input) {
NDValidation.validateNumerical("hsvToRgb", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.HsvToRgb(input))[0];
}
/**
* Greedily selects a subset of bounding boxes in descending order of score<br>
*
* @param boxes Might be null. Name for the output variable (NUMERIC type)
* @param scores vector of shape [num_boxes] (NUMERIC type)
* @param maxOutSize scalar representing the maximum number of boxes to be selected
* @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
* @param scoreThreshold threshold for deciding when to remove boxes based on score
* @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type)
*/
public INDArray nonMaxSuppression(INDArray boxes, INDArray scores, int maxOutSize,
double iouThreshold, double scoreThreshold) {
NDValidation.validateNumerical("nonMaxSuppression", "boxes", boxes);
NDValidation.validateNumerical("nonMaxSuppression", "scores", scores);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression(boxes, scores, maxOutSize, iouThreshold, scoreThreshold))[0];
}
/**
* Randomly crops image<br>
*
* @param input input array (NUMERIC type)
* @param shape shape for crop (INT type)
* @return output cropped array (NUMERIC type)
*/
public INDArray randomCrop(INDArray input, INDArray shape) {
NDValidation.validateNumerical("randomCrop", "input", input);
NDValidation.validateInteger("randomCrop", "shape", shape);
return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.RandomCrop(input, shape))[0];
}
/**
* Converting array from HSV to RGB format<br>
*
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public INDArray rgbToHsv(INDArray input) {
NDValidation.validateNumerical("rgbToHsv", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.RgbToHsv(input))[0];
}
/**
* Converting array from RGB to YIQ format <br>
*
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public INDArray rgbToYiq(INDArray input) {
NDValidation.validateNumerical("rgbToYiq", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.RgbToYiq(input))[0];
}
/**
* Converting array from RGB to YUV format <br>
*
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public INDArray rgbToYuv(INDArray input) {
NDValidation.validateNumerical("rgbToYuv", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.RgbToYuv(input))[0];
}
/**
* Converting image from YIQ to RGB format <br>
*
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public INDArray yiqToRgb(INDArray input) {
NDValidation.validateNumerical("yiqToRgb", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.YiqToRgb(input))[0];
}
/**
* Converting image from YUV to RGB format <br>
*
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public INDArray yuvToRgb(INDArray input) {
NDValidation.validateNumerical("yuvToRgb", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.YuvToRgb(input))[0];
}
}

View File

@ -0,0 +1,130 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;
public class NDRNN {
public NDRNN() {
}
/**
* The GRU cell. Does a single time step operation<br>
*
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
* @param GRUWeights Configuration Object
* @return output The cell's outputs. (NUMERIC type)
*/
public INDArray gru(INDArray x, INDArray hLast, GRUWeights GRUWeights) {
NDValidation.validateNumerical("gru", "x", x);
NDValidation.validateNumerical("gru", "hLast", hLast);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(x, hLast, GRUWeights))[0];
}
/**
* The LSTM cell. Does a single time step operation.<br>
*
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type)
* @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type)
* @param LSTMWeights Configuration Object
* @param LSTMConfiguration Configuration Object
* @return output The cell's outputs (NUMERIC type)
*/
public INDArray lstmCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights,
LSTMConfiguration LSTMConfiguration) {
NDValidation.validateNumerical("lstmCell", "x", x);
NDValidation.validateNumerical("lstmCell", "cLast", cLast);
NDValidation.validateNumerical("lstmCell", "yLast", yLast);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0];
}
/**
* The LSTM layer. Does multiple time steps.<br>
*
* @param maxTSLength (NUMERIC type)
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type)
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type)
* @param LSTMWeights Configuration Object
* @param LSTMConfiguration Configuration Object
* @return output The layer's outputs. (NUMERIC type)
*/
public INDArray lstmLayer(INDArray maxTSLength, INDArray x, INDArray cLast, INDArray yLast,
LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
NDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
NDValidation.validateNumerical("lstmLayer", "x", x);
NDValidation.validateNumerical("lstmLayer", "cLast", cLast);
NDValidation.validateNumerical("lstmLayer", "yLast", yLast);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0];
}
/**
* The SRU layer. Does a single time step operation.<br>
*
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type)
* @param mask An optional dropout mask, with shape [batchSize, inSize] (NUMERIC type)
* @param SRUWeights Configuration Object
* @return output The cell's outputs.. (NUMERIC type)
*/
public INDArray sru(INDArray x, INDArray initialC, INDArray mask, SRUWeights SRUWeights) {
NDValidation.validateNumerical("sru", "x", x);
NDValidation.validateNumerical("sru", "initialC", initialC);
NDValidation.validateNumerical("sru", "mask", mask);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(x, initialC, mask, SRUWeights))[0];
}
/**
* The SRU layer. Does a single time step operation.<br>
*
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type)
* @param SRUWeights Configuration Object
* @return output The cell's outputs.. (NUMERIC type)
*/
public INDArray sru(INDArray x, INDArray initialC, SRUWeights SRUWeights) {
NDValidation.validateNumerical("sru", "x", x);
NDValidation.validateNumerical("sru", "initialC", initialC);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(x, initialC, null, SRUWeights))[0];
}
/**
* The SRU layer. Does a single time step operation.<br>
*
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param cLast Previous cell state, with shape [batchSize, inSize] (NUMERIC type)
* @param SRUWeights Configuration Object
* @return output The cell's outputs. (NUMERIC type)
*/
public INDArray sruCell(INDArray x, INDArray cLast, SRUWeights SRUWeights) {
NDValidation.validateNumerical("sruCell", "x", x);
NDValidation.validateNumerical("sruCell", "cLast", cLast);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell(x, cLast, SRUWeights))[0];
}
}