SDCNN cleanup pass (#230)

* SDCNN cleanup

Signed-off-by: Ryan Nett <rnett@skymind.io>

* NonNull annotations

Signed-off-by: Ryan Nett <rnett@skymind.io>

* better javadoc, NonNull fix for sconv

Signed-off-by: Ryan Nett <rnett@skymind.io>

* update builders to fix names

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* even more fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fix for null bias

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-09-04 00:44:01 -07:00 committed by GitHub
parent 6cc887bee9
commit e9454b8882
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 466 additions and 451 deletions

View File

@ -469,7 +469,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) { public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) {
LocalResponseNormalization lrn = LocalResponseNormalization.builder() LocalResponseNormalization lrn = LocalResponseNormalization.sameDiffBuilder()
.inputFunctions(new SDVariable[]{input}) .inputFunctions(new SDVariable[]{input})
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.config(lrnConfig) .config(lrnConfig)
@ -487,7 +487,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
Conv1D conv1D = Conv1D.builder() Conv1D conv1D = Conv1D.sameDiffBuilder()
.inputFunctions(new SDVariable[]{input, weights}) .inputFunctions(new SDVariable[]{input, weights})
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.config(conv1DConfig) .config(conv1DConfig)
@ -496,6 +496,34 @@ public class DifferentialFunctionFactory {
return conv1D.outputVariable(); return conv1D.outputVariable();
} }
/**
* Conv1d operation.
*
* @param input the inputs to conv1d
* @param weights conv1d weights
* @param bias conv1d bias
* @param conv1DConfig the configuration
* @return
*/
public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, Conv1DConfig conv1DConfig) {
SDVariable[] args;
if(bias == null){
args = new SDVariable[]{input, weights};
} else {
args = new SDVariable[]{input, weights, bias};
}
Conv1D conv1D = Conv1D.sameDiffBuilder()
.inputFunctions(args)
.sameDiff(sameDiff())
.config(conv1DConfig)
.build();
return conv1D.outputVariable();
}
/** /**
* Conv2d operation. * Conv2d operation.
* *
@ -504,7 +532,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
Conv2D conv2D = Conv2D.builder() Conv2D conv2D = Conv2D.sameDiffBuilder()
.inputFunctions(inputs) .inputFunctions(inputs)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.config(conv2DConfig) .config(conv2DConfig)
@ -530,7 +558,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
AvgPooling2D avgPooling2D = AvgPooling2D.builder() AvgPooling2D avgPooling2D = AvgPooling2D.sameDiffBuilder()
.input(input) .input(input)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.config(pooling2DConfig) .config(pooling2DConfig)
@ -547,7 +575,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
MaxPooling2D maxPooling2D = MaxPooling2D.builder() MaxPooling2D maxPooling2D = MaxPooling2D.sameDiffBuilder()
.input(input) .input(input)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.config(pooling2DConfig) .config(pooling2DConfig)
@ -590,7 +618,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
SConv2D sconv2D = SConv2D.sBuilder() SConv2D sconv2D = SConv2D.sameDiffSBuilder()
.inputFunctions(inputs) .inputFunctions(inputs)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.conv2DConfig(conv2DConfig) .conv2DConfig(conv2DConfig)
@ -609,7 +637,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
SConv2D depthWiseConv2D = SConv2D.sBuilder() SConv2D depthWiseConv2D = SConv2D.sameDiffSBuilder()
.inputFunctions(inputs) .inputFunctions(inputs)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.conv2DConfig(depthConv2DConfig) .conv2DConfig(depthConv2DConfig)
@ -627,7 +655,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
DeConv2D deconv2D = DeConv2D.builder() DeConv2D deconv2D = DeConv2D.sameDiffBuilder()
.inputs(inputs) .inputs(inputs)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.config(deconv2DConfig) .config(deconv2DConfig)
@ -654,9 +682,9 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) { public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) {
Conv3D conv3D = Conv3D.builder() Conv3D conv3D = Conv3D.sameDiffBuilder()
.inputFunctions(inputs) .inputFunctions(inputs)
.conv3DConfig(conv3DConfig) .config(conv3DConfig)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.build(); .build();

View File

@ -16,6 +16,7 @@
package org.nd4j.autodiff.samediff.ops; package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
@ -38,14 +39,9 @@ public class SDCNN extends SDOps {
} }
/** /**
* 2D Convolution layer operation - average pooling 2d * See {@link #avgPooling2d(String, SDVariable, Pooling2DConfig)}.
*
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param pooling2DConfig the configuration for
* @return Result after applying average pooling on the input
*/ */
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { public SDVariable avgPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
return avgPooling2d(null, input, pooling2DConfig); return avgPooling2d(null, input, pooling2DConfig);
} }
@ -58,22 +54,16 @@ public class SDCNN extends SDOps {
* @param pooling2DConfig the configuration * @param pooling2DConfig the configuration
* @return Result after applying average pooling on the input * @return Result after applying average pooling on the input
*/ */
public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) { public SDVariable avgPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
validateFloatingPoint("avgPooling2d", input); validateFloatingPoint("avgPooling2d", input);
SDVariable ret = f().avgPooling2d(input, pooling2DConfig); SDVariable ret = f().avgPooling2d(input, pooling2DConfig);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* 3D convolution layer operation - average pooling 3d * See {@link #avgPooling3d(String, SDVariable, Pooling3DConfig)}.
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param pooling3DConfig the configuration
* @return Result after applying average pooling on the input
*/ */
public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { public SDVariable avgPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
return avgPooling3d(null, input, pooling3DConfig); return avgPooling3d(null, input, pooling3DConfig);
} }
@ -87,7 +77,7 @@ public class SDCNN extends SDOps {
* @param pooling3DConfig the configuration * @param pooling3DConfig the configuration
* @return Result after applying average pooling on the input * @return Result after applying average pooling on the input
*/ */
public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) { public SDVariable avgPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
validateFloatingPoint("avgPooling3d", input); validateFloatingPoint("avgPooling3d", input);
SDVariable ret = f().avgPooling3d(input, pooling3DConfig); SDVariable ret = f().avgPooling3d(input, pooling3DConfig);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
@ -96,7 +86,7 @@ public class SDCNN extends SDOps {
/** /**
* @see #batchToSpace(String, SDVariable, int[], int[][]) * @see #batchToSpace(String, SDVariable, int[], int[][])
*/ */
public SDVariable batchToSpace(SDVariable x, int[] blocks, int[][] crops) { public SDVariable batchToSpace(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) {
return batchToSpace(null, x, blocks, crops); return batchToSpace(null, x, blocks, crops);
} }
@ -111,7 +101,7 @@ public class SDCNN extends SDOps {
* @return Output variable * @return Output variable
* @see #spaceToBatch(String, SDVariable, int[], int[][]) * @see #spaceToBatch(String, SDVariable, int[], int[][])
*/ */
public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[][] crops) { public SDVariable batchToSpace(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) {
validateNumerical("batchToSpace", x); validateNumerical("batchToSpace", x);
SDVariable ret = f().batchToSpace(x, blocks, crops); SDVariable ret = f().batchToSpace(x, blocks, crops);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
@ -119,14 +109,9 @@ public class SDCNN extends SDOps {
/** /**
* col2im operation for use in 2D convolution operations. Outputs a 4d array with shape * See {@link #col2Im(String, SDVariable, Conv2DConfig)}.
* [minibatch, inputChannels, height, width]
*
* @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
* @param config Convolution configuration for the col2im operation
* @return Col2Im output variable
*/ */
public SDVariable col2Im(SDVariable in, Conv2DConfig config) { public SDVariable col2Im(@NonNull SDVariable in, @NonNull Conv2DConfig config) {
return col2Im(null, in, config); return col2Im(null, in, config);
} }
@ -139,33 +124,22 @@ public class SDCNN extends SDOps {
* @param config Convolution configuration for the col2im operation * @param config Convolution configuration for the col2im operation
* @return Col2Im output variable * @return Col2Im output variable
*/ */
public SDVariable col2Im(String name, SDVariable in, Conv2DConfig config) { public SDVariable col2Im(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) {
SDVariable ret = f().col2Im(in, config); SDVariable ret = f().col2Im(in, config);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* 1D Convolution layer operation - Conv1d * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias.
*
* @param input the input array/activations for the conv1d op
* @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels]
* @param conv1DConfig the configuration
* @return
*/ */
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) {
return conv1d(null, input, weights, conv1DConfig); return conv1d((String) null, input, weights, conv1DConfig);
} }
/** /**
* Conv1d operation. * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias.
*
* @param name name of the operation in SameDiff
* @param input the inputs to conv1d
* @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels]
* @param conv1DConfig the configuration
* @return
*/ */
public SDVariable conv1d(String name, SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) {
validateFloatingPoint("conv1d", input); validateFloatingPoint("conv1d", input);
validateFloatingPoint("conv1d", weights); validateFloatingPoint("conv1d", weights);
SDVariable ret = f().conv1d(input, weights, conv1DConfig); SDVariable ret = f().conv1d(input, weights, conv1DConfig);
@ -173,21 +147,55 @@ public class SDCNN extends SDOps {
} }
/** /**
* 2D Convolution operation (without bias) * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}.
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
* @param config Conv2DConfig configuration
* @return result of conv2d op
*/ */
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig config) { public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
return conv1d(null, input, weights, bias, conv1DConfig);
}
/**
* Conv1d operation.
*
* @param name name of the operation in SameDiff
* @param input the inputs to conv1d
* @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels]
* @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null.
* @param conv1DConfig the configuration
* @return
*/
public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
validateFloatingPoint("conv1d", input);
validateFloatingPoint("conv1d", weights);
validateFloatingPoint("conv1d", bias);
SDVariable ret = f().conv1d(input, weights, bias, conv1DConfig);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) {
return conv2d(layerInput, weights, null, config); return conv2d(layerInput, weights, null, config);
} }
/**
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) {
return conv2d(name, layerInput, weights, null, config);
}
/**
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
*/
public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) {
return conv2d(null, layerInput, weights, bias, config);
}
/** /**
* 2D Convolution operation with optional bias * 2D Convolution operation with optional bias
* *
* @param name name of the operation in SameDiff
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
@ -195,7 +203,7 @@ public class SDCNN extends SDOps {
* @param config Conv2DConfig configuration * @param config Conv2DConfig configuration
* @return result of conv2d op * @return result of conv2d op
*/ */
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, Conv2DConfig config) { public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) {
validateFloatingPoint("conv2d", "input", layerInput); validateFloatingPoint("conv2d", "input", layerInput);
validateFloatingPoint("conv2d", "weights", weights); validateFloatingPoint("conv2d", "weights", weights);
validateFloatingPoint("conv2d", "bias", bias); validateFloatingPoint("conv2d", "bias", bias);
@ -204,18 +212,13 @@ public class SDCNN extends SDOps {
arr[1] = weights; arr[1] = weights;
if (bias != null) if (bias != null)
arr[2] = bias; arr[2] = bias;
return conv2d(arr, config); return conv2d(name, arr, config);
} }
/** /**
* 2D Convolution operation with optional bias * See {@link #conv2d(String, SDVariable[], Conv2DConfig)}.
*
* @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as
* described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
* @param config Conv2DConfig configuration
* @return result of convolution 2d operation
*/ */
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig config) { public SDVariable conv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) {
return conv2d(null, inputs, config); return conv2d(null, inputs, config);
} }
@ -228,7 +231,7 @@ public class SDCNN extends SDOps {
* @param config Conv2DConfig configuration * @param config Conv2DConfig configuration
* @return result of convolution 2d operation * @return result of convolution 2d operation
*/ */
public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig config) { public SDVariable conv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) {
for(SDVariable v : inputs) for(SDVariable v : inputs)
validateNumerical("conv2d", v); validateNumerical("conv2d", v);
SDVariable ret = f().conv2d(inputs, config); SDVariable ret = f().conv2d(inputs, config);
@ -236,19 +239,26 @@ public class SDCNN extends SDOps {
} }
/** /**
* Convolution 3D operation without bias * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias.
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
* @param conv3DConfig the configuration
* @return Conv3d output variable
*/ */
public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) { public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) {
return conv3d(null, input, weights, null, conv3DConfig); return conv3d(null, input, weights, null, conv3DConfig);
} }
/**
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias.
*/
public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) {
return conv3d(name, input, weights, null, conv3DConfig);
}
/**
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}.
*/
public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) {
return conv3d(null, input, weights, bias, conv3DConfig);
}
/** /**
* Convolution 3D operation with optional bias * Convolution 3D operation with optional bias
* *
@ -261,7 +271,7 @@ public class SDCNN extends SDOps {
* @param conv3DConfig the configuration * @param conv3DConfig the configuration
* @return Conv3d output variable * @return Conv3d output variable
*/ */
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) { public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) {
validateFloatingPoint("conv3d", "input", input); validateFloatingPoint("conv3d", "input", input);
validateFloatingPoint("conv3d", "weights", weights); validateFloatingPoint("conv3d", "weights", weights);
validateFloatingPoint("conv3d", "bias", bias); validateFloatingPoint("conv3d", "bias", bias);
@ -276,51 +286,30 @@ public class SDCNN extends SDOps {
} }
/** /**
* Convolution 3D operation with optional bias * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias.
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
* @param bias Optional 1D bias array with shape [outputChannels]. May be null.
* @param conv3DConfig the configuration
* @return Conv3d output variable
*/ */
public SDVariable conv3d(SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) { public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) {
return conv3d(null, input, weights, bias, conv3DConfig);
}
/**
* Convolution 3D operation without bias
*
* @param name Name of the output variable
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
* @param conv3DConfig the configuration
* @return Conv3d output variable
*/
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) {
return conv3d(name, input, weights, null, conv3DConfig);
}
/**
* 2D deconvolution operation without bias
*
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
* @param deconv2DConfig DeConv2DConfig configuration
* @return result of deconv2d op
*/
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, DeConv2DConfig deconv2DConfig) {
return deconv2d(layerInput, weights, null, deconv2DConfig); return deconv2d(layerInput, weights, null, deconv2DConfig);
} }
/**
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias.
*/
public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) {
return deconv2d(name, layerInput, weights, null, deconv2DConfig);
}
/**
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}.
*/
public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) {
return deconv2d(null, layerInput, weights, bias, deconv2DConfig);
}
/** /**
* 2D deconvolution operation with optional bias * 2D deconvolution operation with optional bias
* *
* @param name name of the operation in SameDiff
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth]. * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
@ -328,7 +317,7 @@ public class SDCNN extends SDOps {
* @param deconv2DConfig DeConv2DConfig configuration * @param deconv2DConfig DeConv2DConfig configuration
* @return result of deconv2d op * @return result of deconv2d op
*/ */
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, DeConv2DConfig deconv2DConfig) { public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) {
validateFloatingPoint("deconv2d", "input", layerInput); validateFloatingPoint("deconv2d", "input", layerInput);
validateFloatingPoint("deconv2d", "weights", weights); validateFloatingPoint("deconv2d", "weights", weights);
validateFloatingPoint("deconv2d", "bias", bias); validateFloatingPoint("deconv2d", "bias", bias);
@ -337,18 +326,13 @@ public class SDCNN extends SDOps {
arr[1] = weights; arr[1] = weights;
if (bias != null) if (bias != null)
arr[2] = bias; arr[2] = bias;
return deconv2d(arr, deconv2DConfig); return deconv2d(name, arr, deconv2DConfig);
} }
/** /**
* 2D deconvolution operation with or without optional bias * See {@link #deconv2d(String, SDVariable[], DeConv2DConfig)}.
*
* @param inputs Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights)
* or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable[], DeConv2DConfig)}
* @param deconv2DConfig the configuration
* @return result of deconv2d op
*/ */
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { public SDVariable deconv2d(@NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) {
return deconv2d(null, inputs, deconv2DConfig); return deconv2d(null, inputs, deconv2DConfig);
} }
@ -361,13 +345,34 @@ public class SDCNN extends SDOps {
* @param deconv2DConfig the configuration * @param deconv2DConfig the configuration
* @return result of deconv2d op * @return result of deconv2d op
*/ */
public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { public SDVariable deconv2d(String name, @NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) {
for(SDVariable v : inputs) for(SDVariable v : inputs)
validateNumerical("deconv2d", v); validateNumerical("deconv2d", v);
SDVariable ret = f().deconv2d(inputs, deconv2DConfig); SDVariable ret = f().deconv2d(inputs, deconv2DConfig);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/**
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias.
*/
public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
return deconv3d(input, weights, null, config);
}
/**
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias.
*/
public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
return deconv3d(name, input, weights, null, config);
}
/**
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}.
*/
public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
return deconv3d(null, input, weights, bias, config);
}
/** /**
* 3D CNN deconvolution operation with or without optional bias * 3D CNN deconvolution operation with or without optional bias
* *
@ -377,7 +382,7 @@ public class SDCNN extends SDOps {
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] * @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
* @param config Configuration * @param config Configuration
*/ */
public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
validateFloatingPoint("conv3d", input); validateFloatingPoint("conv3d", input);
validateFloatingPoint("conv3d", weights); validateFloatingPoint("conv3d", weights);
validateFloatingPoint("conv3d", bias); validateFloatingPoint("conv3d", bias);
@ -386,41 +391,9 @@ public class SDCNN extends SDOps {
} }
/** /**
* 3D CNN deconvolution operation with or without optional bias * See {@link #depthToSpace(String, SDVariable, int, String)}.
*
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
* @param config Configuration
*/ */
public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { public SDVariable depthToSpace(@NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) {
return deconv3d(null, input, weights, bias, config);
}
/**
* 3D CNN deconvolution operation with no bias
*
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
* @param config Configuration
*/
public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig config) {
return deconv3d(input, weights, null, config);
}
/**
* Convolution 2d layer batch to space operation on 4d input.<br>
* Reduces input channels dimension by rearranging data into a larger spatial dimensions<br>
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
* = [mb, 2, 4, 4]
*
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param blockSize Block size, in the height/width dimension
* @param dataFormat Data format: "NCHW" or "NHWC"
* @return Output variable
*/
public SDVariable depthToSpace(SDVariable x, int blockSize, String dataFormat) {
return depthToSpace(null, x, blockSize, dataFormat); return depthToSpace(null, x, blockSize, dataFormat);
} }
@ -438,27 +411,36 @@ public class SDCNN extends SDOps {
* @return Output variable * @return Output variable
* @see #depthToSpace(String, SDVariable, int, String) * @see #depthToSpace(String, SDVariable, int, String)
*/ */
public SDVariable depthToSpace(String name, SDVariable x, int blockSize, String dataFormat) { public SDVariable depthToSpace(String name, @NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) {
SDVariable ret = f().depthToSpace(x, blockSize, dataFormat); SDVariable ret = f().depthToSpace(x, blockSize, dataFormat);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* Depth-wise 2D convolution operation without bias * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
* @param config Conv2DConfig configuration
* @return result of conv2d op
*/ */
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, Conv2DConfig config) { public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) {
return depthWiseConv2d(layerInput, depthWeights, null, config); return depthWiseConv2d(layerInput, depthWeights, null, config);
} }
/**
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) {
return depthWiseConv2d(name, layerInput, depthWeights, null, config);
}
/**
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
*/
public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) {
return depthWiseConv2d(null, layerInput, depthWeights, bias, config);
}
/** /**
* Depth-wise 2D convolution operation with optional bias * Depth-wise 2D convolution operation with optional bias
* *
* @param name name of the operation in SameDiff
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
@ -466,7 +448,7 @@ public class SDCNN extends SDOps {
* @param config Conv2DConfig configuration * @param config Conv2DConfig configuration
* @return result of depthwise conv2d op * @return result of depthwise conv2d op
*/ */
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, Conv2DConfig config) { public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) {
validateFloatingPoint("depthwiseConv2d", "input", layerInput); validateFloatingPoint("depthwiseConv2d", "input", layerInput);
validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights); validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights);
validateFloatingPoint("depthwiseConv2d", "bias", bias); validateFloatingPoint("depthwiseConv2d", "bias", bias);
@ -475,19 +457,13 @@ public class SDCNN extends SDOps {
arr[1] = depthWeights; arr[1] = depthWeights;
if (bias != null) if (bias != null)
arr[2] = bias; arr[2] = bias;
return depthWiseConv2d(arr, config); return depthWiseConv2d(name, arr, config);
} }
/** /**
* Depth-wise convolution 2D operation. * See {@link #depthWiseConv2d(String, SDVariable[], Conv2DConfig)}.
*
* @param inputs the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights)
* or 3 elements (layerInput, depthWeights, bias) as described in
* {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
* @param depthConv2DConfig the configuration
* @return result of depthwise conv2d op
*/ */
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { public SDVariable depthWiseConv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) {
return depthWiseConv2d(null, inputs, depthConv2DConfig); return depthWiseConv2d(null, inputs, depthConv2DConfig);
} }
@ -501,7 +477,7 @@ public class SDCNN extends SDOps {
* @param depthConv2DConfig the configuration * @param depthConv2DConfig the configuration
* @return result of depthwise conv2d op * @return result of depthwise conv2d op
*/ */
public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { public SDVariable depthWiseConv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) {
for(SDVariable v : inputs) for(SDVariable v : inputs)
validateFloatingPoint("depthWiseConv2d", v); validateFloatingPoint("depthWiseConv2d", v);
SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig); SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig);
@ -509,17 +485,10 @@ public class SDCNN extends SDOps {
} }
/** /**
* TODO doc string * See {@link #dilation2D(String, SDVariable, SDVariable, int[], int[], boolean)}.
*
* @param df
* @param weights
* @param strides
* @param rates
* @param isSameMode
* @return
*/ */
public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides, public SDVariable dilation2D(@NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides,
int[] rates, boolean isSameMode) { @NonNull int[] rates, @NonNull boolean isSameMode) {
return dilation2D(null, df, weights, strides, rates, isSameMode); return dilation2D(null, df, weights, strides, rates, isSameMode);
} }
@ -534,8 +503,8 @@ public class SDCNN extends SDOps {
* @param isSameMode * @param isSameMode
* @return * @return
*/ */
public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides, public SDVariable dilation2D(String name, @NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides,
int[] rates, boolean isSameMode) { @NonNull int[] rates, @NonNull boolean isSameMode) {
SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode); SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
@ -555,21 +524,16 @@ public class SDCNN extends SDOps {
* @param sameMode If true: use same mode padding. If false * @param sameMode If true: use same mode padding. If false
* @return * @return
*/ */
public SDVariable extractImagePatches(String name, SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { public SDVariable extractImagePatches(String name, @NonNull SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode); SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* im2col operation for use in 2D convolution operations. Outputs a 6d array with shape * See {@link #im2Col(String, SDVariable, Conv2DConfig)}.
* [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
*
* @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width]
* @param config Convolution configuration for the im2col operation
* @return Im2Col output variable
*/ */
public SDVariable im2Col(SDVariable in, Conv2DConfig config) { public SDVariable im2Col(@NonNull SDVariable in, @NonNull Conv2DConfig config) {
return im2Col(null, in, config); return im2Col(null, in, config);
} }
@ -582,20 +546,16 @@ public class SDCNN extends SDOps {
* @param config Convolution configuration for the im2col operation * @param config Convolution configuration for the im2col operation
* @return Im2Col output variable * @return Im2Col output variable
*/ */
public SDVariable im2Col(String name, SDVariable in, Conv2DConfig config) { public SDVariable im2Col(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) {
SDVariable ret = f().im2Col(in, config); SDVariable ret = f().im2Col(in, config);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* 2D convolution layer operation - local response normalization * See {@link #localResponseNormalization(String, SDVariable, LocalResponseNormalizationConfig)}.
*
* @param inputs the inputs to lrn
* @param lrnConfig the configuration
* @return
*/ */
public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNormalizationConfig lrnConfig) { public SDVariable localResponseNormalization(@NonNull SDVariable inputs, @NonNull LocalResponseNormalizationConfig lrnConfig) {
return localResponseNormalization(null, inputs, lrnConfig); return localResponseNormalization(null, inputs, lrnConfig);
} }
@ -607,8 +567,8 @@ public class SDCNN extends SDOps {
* @param lrnConfig the configuration * @param lrnConfig the configuration
* @return * @return
*/ */
public SDVariable localResponseNormalization(String name, SDVariable input, public SDVariable localResponseNormalization(String name, @NonNull SDVariable input,
LocalResponseNormalizationConfig lrnConfig) { @NonNull LocalResponseNormalizationConfig lrnConfig) {
validateFloatingPoint("local response normalization", input); validateFloatingPoint("local response normalization", input);
SDVariable ret = f().localResponseNormalization(input, lrnConfig); SDVariable ret = f().localResponseNormalization(input, lrnConfig);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
@ -616,14 +576,9 @@ public class SDCNN extends SDOps {
/** /**
* 2D Convolution layer operation - max pooling 2d * See {@link #maxPooling2d(String, SDVariable, Pooling2DConfig)}.
*
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param pooling2DConfig the configuration
* @return Result after applying max pooling on the input
*/ */
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { public SDVariable maxPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
return maxPooling2d(null, input, pooling2DConfig); return maxPooling2d(null, input, pooling2DConfig);
} }
@ -636,22 +591,16 @@ public class SDCNN extends SDOps {
* @param pooling2DConfig the configuration * @param pooling2DConfig the configuration
* @return Result after applying max pooling on the input * @return Result after applying max pooling on the input
*/ */
public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) { public SDVariable maxPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
validateNumerical("maxPooling2d", input); validateNumerical("maxPooling2d", input);
SDVariable ret = f().maxPooling2d(input, pooling2DConfig); SDVariable ret = f().maxPooling2d(input, pooling2DConfig);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* 3D convolution layer operation - max pooling 3d operation. * See {@link #maxPooling3d(String, SDVariable, Pooling3DConfig)}.
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param pooling3DConfig the configuration
* @return Result after applying max pooling on the input
*/ */
public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { public SDVariable maxPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
return maxPooling3d(null, input, pooling3DConfig); return maxPooling3d(null, input, pooling3DConfig);
} }
@ -665,7 +614,7 @@ public class SDCNN extends SDOps {
* @param pooling3DConfig the configuration * @param pooling3DConfig the configuration
* @return Result after applying max pooling on the input * @return Result after applying max pooling on the input
*/ */
public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) { public SDVariable maxPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
validateNumerical("maxPooling3d", input); validateNumerical("maxPooling3d", input);
SDVariable ret = f().maxPooling3d(input, pooling3DConfig); SDVariable ret = f().maxPooling3d(input, pooling3DConfig);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
@ -673,21 +622,30 @@ public class SDCNN extends SDOps {
/** /**
* Separable 2D convolution operation without bias * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]
* May be null
* @param config Conv2DConfig configuration
* @return result of separable convolution 2d operation
*/ */
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights, public SDVariable separableConv2d(SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
Conv2DConfig config) { @NonNull Conv2DConfig config) {
return separableConv2d(layerInput, depthWeights, pointWeights, null, config); return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
} }
/**
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
@NonNull Conv2DConfig config) {
return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
}
/**
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
*/
public SDVariable separableConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
SDVariable bias, @NonNull Conv2DConfig config) {
return separableConv2d(null, layerInput, depthWeights, pointWeights, bias, config);
}
/** /**
* Separable 2D convolution operation with optional bias * Separable 2D convolution operation with optional bias
* *
@ -700,8 +658,8 @@ public class SDCNN extends SDOps {
* @param config Conv2DConfig configuration * @param config Conv2DConfig configuration
* @return result of separable convolution 2d operation * @return result of separable convolution 2d operation
*/ */
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights, public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
SDVariable bias, Conv2DConfig config) { SDVariable bias, @NonNull Conv2DConfig config) {
validateFloatingPoint("separableConv2d", "input", layerInput); validateFloatingPoint("separableConv2d", "input", layerInput);
validateFloatingPoint("separableConv2d", "depthWeights", depthWeights); validateFloatingPoint("separableConv2d", "depthWeights", depthWeights);
validateFloatingPoint("separableConv2d", "pointWeights", pointWeights); validateFloatingPoint("separableConv2d", "pointWeights", pointWeights);
@ -712,18 +670,13 @@ public class SDCNN extends SDOps {
arr[2] = pointWeights; arr[2] = pointWeights;
if (bias != null) if (bias != null)
arr[3] = bias; arr[3] = bias;
return sconv2d(arr, config); return sconv2d(name, arr, config);
} }
/** /**
* Separable 2D convolution operation with/without optional bias * See {@link #sconv2d(String, SDVariable[], Conv2DConfig)}.
*
* @param inputs the inputs to separable conv2 operation. Should be length 3 (layerInput, depthWeights, pointWeights)
* or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}
* @param conv2DConfig the configuration
* @return result of separable convolution 2d operation
*/ */
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { public SDVariable sconv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) {
return sconv2d(null, inputs, conv2DConfig); return sconv2d(null, inputs, conv2DConfig);
} }
@ -736,7 +689,7 @@ public class SDCNN extends SDOps {
* @param conv2DConfig the configuration * @param conv2DConfig the configuration
* @return result of separable convolution 2d operation * @return result of separable convolution 2d operation
*/ */
public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) { public SDVariable sconv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) {
for(SDVariable v : inputs) for(SDVariable v : inputs)
validateFloatingPoint("sconv2d", v); validateFloatingPoint("sconv2d", v);
SDVariable ret = f().sconv2d(inputs, conv2DConfig); SDVariable ret = f().sconv2d(inputs, conv2DConfig);
@ -747,7 +700,7 @@ public class SDCNN extends SDOps {
/** /**
* @see #spaceToBatch(String, SDVariable, int[], int[][]) * @see #spaceToBatch(String, SDVariable, int[], int[][])
*/ */
public SDVariable spaceToBatch(SDVariable x, int[] blocks, int[][] padding) { public SDVariable spaceToBatch(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) {
return spaceToBatch(null, x, blocks, padding); return spaceToBatch(null, x, blocks, padding);
} }
@ -762,7 +715,7 @@ public class SDCNN extends SDOps {
* @return Output variable * @return Output variable
* @see #batchToSpace(String, SDVariable, int[], int[][]) * @see #batchToSpace(String, SDVariable, int[], int[][])
*/ */
public SDVariable spaceToBatch(String name, SDVariable x, int[] blocks, int[][] padding) { public SDVariable spaceToBatch(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) {
SDVariable ret = f().spaceToBatch(x, blocks, padding); SDVariable ret = f().spaceToBatch(x, blocks, padding);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
@ -770,7 +723,7 @@ public class SDCNN extends SDOps {
/** /**
* @see #spaceToDepth(String, SDVariable, int, String) * @see #spaceToDepth(String, SDVariable, int, String)
*/ */
public SDVariable spaceToDepth(SDVariable x, int blockSize, String dataFormat) { public SDVariable spaceToDepth(@NonNull SDVariable x, int blockSize, @NonNull String dataFormat) {
return spaceToDepth(null, x, blockSize, dataFormat); return spaceToDepth(null, x, blockSize, dataFormat);
} }
@ -788,23 +741,39 @@ public class SDCNN extends SDOps {
* @return Output variable * @return Output variable
* @see #depthToSpace(String, SDVariable, int, String) * @see #depthToSpace(String, SDVariable, int, String)
*/ */
public SDVariable spaceToDepth(String name, SDVariable x, int blockSize, String dataFormat) { public SDVariable spaceToDepth(String name, @NonNull SDVariable x, int blockSize, @NonNull String dataFormat) {
SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat); SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format. * See {@link #upsampling2d(String, SDVariable, boolean, int, int)},
* scale is used for both height and width dimensions.
* *
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) * @param scale The scale for both height and width dimensions.
* @param scale Scale to upsample in both H and W dimensions
* @return Upsampled input
*/ */
public SDVariable upsampling2d(SDVariable input, int scale) { public SDVariable upsampling2d(@NonNull SDVariable input, int scale) {
return upsampling2d(null, input, true, scale, scale); return upsampling2d(null, input, true, scale, scale);
} }
/**
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)},
* scale is used for both height and width dimensions.
*
* @param scale The scale for both height and width dimensions.
*/
public SDVariable upsampling2d(String name, @NonNull SDVariable input, int scale) {
return upsampling2d(name, input, true, scale, scale);
}
/**
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)}.
*/
public SDVariable upsampling2d(@NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) {
return upsampling2d(null, input, nchw, scaleH, scaleW);
}
/** /**
* 2D Convolution layer operation - Upsampling 2d * 2D Convolution layer operation - Upsampling 2d
* *
@ -814,33 +783,8 @@ public class SDCNN extends SDOps {
* @param scaleW Scale to upsample in width dimension * @param scaleW Scale to upsample in width dimension
* @return Upsampled input * @return Upsampled input
*/ */
public SDVariable upsampling2d(String name, SDVariable input, boolean nchw, int scaleH, int scaleW) { public SDVariable upsampling2d(String name, @NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) {
SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW); SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/**
* 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format.
*
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
* @param scale Scale to upsample in both H and W dimensions
* @return Upsampled input
*/
public SDVariable upsampling2d(String name, SDVariable input, int scale) {
return upsampling2d(name, input, true, scale, scale);
}
/**
* 2D Convolution layer operation - Upsampling 2d
*
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
* or NHWC format (shape [minibatch, height, width, channels])
* @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format
* @param scaleH Scale to upsample in height dimension
* @param scaleW Scale to upsample in width dimension
* @return Upsampled input
*/
public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) {
return upsampling2d(null, input, nchw, scaleH, scaleW);
}
} }

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -53,19 +54,19 @@ public class AvgPooling2D extends DynamicCustomOp {
} }
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public AvgPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) { public AvgPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
super(null, sameDiff, new SDVariable[]{input}, false); super(sameDiff, new SDVariable[]{input});
if (arrayInput != null) {
addInputArgument(arrayInput);
}
if (arrayOutput != null) {
addOutputArgument(arrayOutput);
}
config.setType(Pooling2D.Pooling2DType.AVG); config.setType(Pooling2D.Pooling2DType.AVG);
this.config = config;
addArgs();
}
public AvgPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
super(new INDArray[]{input}, wrapOrNull(output));
config.setType(Pooling2D.Pooling2DType.AVG);
this.sameDiff = sameDiff;
this.config = config; this.config = config;
addArgs(); addArgs();
} }

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -39,6 +40,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
@ -59,18 +61,28 @@ public class Conv1D extends DynamicCustomOp {
protected Conv1DConfig config; protected Conv1DConfig config;
private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s "; private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s ";
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public Conv1D(SameDiff sameDiff, public Conv1D(SameDiff sameDiff,
SDVariable[] inputFunctions, SDVariable[] inputFunctions,
INDArray[] inputArrays, INDArray[] outputs,
Conv1DConfig config) { Conv1DConfig config) {
super(null, inputArrays, outputs); super(sameDiff, inputFunctions);
this.sameDiff = sameDiff; initConfig(config);
}
public Conv1D(INDArray[] inputs, INDArray[] outputs, Conv1DConfig config){
super(inputs, outputs);
initConfig(config);
}
public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv1DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
private void initConfig(Conv1DConfig config){
this.config = config; this.config = config;
Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP()); Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP());
addArgs(); addArgs();
sameDiff.putOpForId(this.getOwnName(), this);
sameDiff.addArgsFor(inputFunctions, this);
} }
protected void addArgs() { protected void addArgs() {

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -56,23 +57,32 @@ public class Conv2D extends DynamicCustomOp {
protected Conv2DConfig config; protected Conv2DConfig config;
private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s "; private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s ";
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public Conv2D(SameDiff sameDiff, public Conv2D(SameDiff sameDiff,
SDVariable[] inputFunctions, SDVariable[] inputFunctions,
INDArray[] inputArrays, INDArray[] outputs,
Conv2DConfig config) { Conv2DConfig config) {
super(null, inputArrays, outputs); super(sameDiff, inputFunctions);
this.sameDiff = sameDiff;
initConfig(config);
}
public Conv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
super(inputs, outputs);
initConfig(config);
}
public Conv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
protected void initConfig(Conv2DConfig config){
this.config = config; this.config = config;
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1, Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
INVALID_CONFIGURATION, INVALID_CONFIGURATION,
config.getSH(), config.getPH(), config.getDW()); config.getSH(), config.getPH(), config.getDW());
addArgs(); addArgs();
if(sameDiff != null) {
sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
sameDiff.addArgsFor(inputFunctions, this);
}
} }
protected void addArgs() { protected void addArgs() {
@ -252,7 +262,6 @@ public class Conv2D extends DynamicCustomOp {
Conv2DDerivative conv2DDerivative = Conv2DDerivative.derivativeBuilder() Conv2DDerivative conv2DDerivative = Conv2DDerivative.derivativeBuilder()
.sameDiff(sameDiff) .sameDiff(sameDiff)
.config(config) .config(config)
.outputs(outputArguments())
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()])) .inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
.build(); .build();
List<SDVariable> ret = Arrays.asList(conv2DDerivative.outputVariables()); List<SDVariable> ret = Arrays.asList(conv2DDerivative.outputVariables());

View File

@ -37,8 +37,8 @@ import java.util.List;
public class Conv2DDerivative extends Conv2D { public class Conv2DDerivative extends Conv2D {
@Builder(builderMethodName = "derivativeBuilder") @Builder(builderMethodName = "derivativeBuilder")
public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig config) { public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig config) {
super(sameDiff, inputFunctions, inputArrays, outputs, config); super(sameDiff, inputFunctions, config);
} }
public Conv2DDerivative() {} public Conv2DDerivative() {}

View File

@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
@ -33,6 +34,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
@ -55,25 +57,27 @@ public class Conv3D extends DynamicCustomOp {
public Conv3D() { public Conv3D() {
} }
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) {
Conv3DConfig conv3DConfig) { super(sameDiff, inputFunctions);
super(null, sameDiff, inputFunctions, false); initConfig(config);
setSameDiff(sameDiff); }
if (inputs != null) public Conv3D(INDArray[] inputs, INDArray[] outputs, Conv3DConfig config){
addInputArgument(inputs); super(inputs, outputs);
if (outputs != null) initConfig(config);
addOutputArgument(outputs); }
this.config = conv3DConfig;
public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv3DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
private void initConfig(Conv3DConfig config){
this.config = config;
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1, Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
INVALID_CONFIGURATION, INVALID_CONFIGURATION,
config.getSW(), config.getPH(), config.getDW()); config.getSW(), config.getPH(), config.getDW());
addArgs(); addArgs();
//for (val arg: iArgs())
// System.out.println(getIArgument(arg));
} }
@ -259,8 +263,6 @@ public class Conv3D extends DynamicCustomOp {
inputs.add(f1.get(0)); inputs.add(f1.get(0));
Conv3DDerivative conv3DDerivative = Conv3DDerivative.derivativeBuilder() Conv3DDerivative conv3DDerivative = Conv3DDerivative.derivativeBuilder()
.conv3DConfig(config) .conv3DConfig(config)
.inputFunctions(args())
.outputs(outputArguments())
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()])) .inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
.sameDiff(sameDiff) .sameDiff(sameDiff)
.build(); .build();

View File

@ -39,8 +39,8 @@ public class Conv3DDerivative extends Conv3D {
public Conv3DDerivative() {} public Conv3DDerivative() {}
@Builder(builderMethodName = "derivativeBuilder") @Builder(builderMethodName = "derivativeBuilder")
public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, Conv3DConfig conv3DConfig) { public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig conv3DConfig) {
super(sameDiff, inputFunctions, inputs, outputs, conv3DConfig); super(sameDiff, inputFunctions, conv3DConfig);
} }
@Override @Override

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -31,6 +32,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
@ -51,25 +53,25 @@ public class DeConv2D extends DynamicCustomOp {
protected DeConv2DConfig config; protected DeConv2DConfig config;
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public DeConv2D(SameDiff sameDiff, public DeConv2D(SameDiff sameDiff,
SDVariable[] inputs, SDVariable[] inputs,
INDArray[] inputArrays, INDArray[] outputs,
DeConv2DConfig config) { DeConv2DConfig config) {
super(null, inputArrays, outputs); super(sameDiff, inputs);
this.sameDiff = sameDiff;
this.config = config; this.config = config;
if (inputArrays != null) { addArgs();
addInputArgument(inputArrays);
}
if (outputs != null) {
addOutputArgument(outputs);
} }
public DeConv2D(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){
super(inputs, outputs);
this.config = config;
addArgs(); addArgs();
sameDiff.putOpForId(this.getOwnName(), this); }
sameDiff.addArgsFor(inputs, this);
public DeConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv2DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
} }
@Override @Override

View File

@ -40,8 +40,8 @@ public class DeConv2DDerivative extends DeConv2D {
public DeConv2DDerivative() {} public DeConv2DDerivative() {}
@Builder(builderMethodName = "derivativeBuilder") @Builder(builderMethodName = "derivativeBuilder")
public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) { public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, DeConv2DConfig config) {
super(sameDiff, inputs, inputArrays, outputs, config); super(sameDiff, inputs, config);
} }
@Override @Override

View File

@ -53,25 +53,21 @@ public class DeConv2DTF extends DynamicCustomOp {
protected DeConv2DConfig config; protected DeConv2DConfig config;
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public DeConv2DTF(SameDiff sameDiff, public DeConv2DTF(SameDiff sameDiff,
SDVariable[] inputs, SDVariable[] inputs,
INDArray[] inputArrays, INDArray[] outputs,
DeConv2DConfig config) { DeConv2DConfig config) {
super(null, inputArrays, outputs); super(sameDiff, inputs);
this.sameDiff = sameDiff;
this.config = config;
addArgs();
}
public DeConv2DTF(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){
super(inputs, outputs);
this.config = config; this.config = config;
if (inputArrays != null) {
addInputArgument(inputArrays);
}
if (outputs != null) {
addOutputArgument(outputs);
}
addArgs(); addArgs();
sameDiff.putOpForId(this.getOwnName(), this);
sameDiff.addArgsFor(inputs, this);
} }
@Override @Override

View File

@ -28,6 +28,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
@ -53,12 +54,23 @@ public class DeConv3D extends DynamicCustomOp {
protected DeConv3DConfig config; protected DeConv3DConfig config;
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, DeConv3DConfig config) { public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
super(sameDiff, toArr(input, weights, bias)); super(sameDiff, toArr(input, weights, bias));
this.config = config; this.config = config;
addArgs(); addArgs();
} }
public DeConv3D(INDArray[] inputs, INDArray[] outputs, DeConv3DConfig config){
super(inputs, outputs);
this.config = config;
addArgs();
}
public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv3DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){ private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){
if(bias != null){ if(bias != null){
return new SDVariable[]{input, weights, bias}; return new SDVariable[]{input, weights, bias};

View File

@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -35,6 +36,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
@ -53,17 +55,25 @@ public class DepthwiseConv2D extends DynamicCustomOp {
protected Conv2DConfig config; protected Conv2DConfig config;
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public DepthwiseConv2D(SameDiff sameDiff, public DepthwiseConv2D(SameDiff sameDiff,
SDVariable[] inputFunctions, SDVariable[] inputFunctions,
INDArray[] inputArrays, INDArray[] outputs,
Conv2DConfig config) { Conv2DConfig config) {
super(null, inputArrays, outputs); super(sameDiff, inputFunctions);
this.sameDiff = sameDiff;
this.config = config; this.config = config;
addArgs(); addArgs();
sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point }
sameDiff.addArgsFor(inputFunctions, this);
public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
super(inputs, outputs);
this.config = config;
addArgs();
}
public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
} }
public DepthwiseConv2D() { public DepthwiseConv2D() {

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -48,18 +49,19 @@ public class LocalResponseNormalization extends DynamicCustomOp {
protected LocalResponseNormalizationConfig config; protected LocalResponseNormalizationConfig config;
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace,
INDArray[] inputs, INDArray[] outputs,boolean inPlace,
LocalResponseNormalizationConfig config) { LocalResponseNormalizationConfig config) {
super(null,sameDiff, inputFunctions, inPlace); super(null,sameDiff, inputFunctions, inPlace);
this.config = config; this.config = config;
if(inputs != null) { addArgs();
addInputArgument(inputs);
}
if(outputs!= null) {
addOutputArgument(outputs);
} }
public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config){
super(new INDArray[]{input}, wrapOrNull(output));
this.config = config;
addArgs(); addArgs();
} }

View File

@ -33,8 +33,8 @@ import java.util.List;
@Slf4j @Slf4j
public class LocalResponseNormalizationDerivative extends LocalResponseNormalization { public class LocalResponseNormalizationDerivative extends LocalResponseNormalization {
@Builder(builderMethodName = "derivativeBuilder") @Builder(builderMethodName = "derivativeBuilder")
public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, boolean inPlace, LocalResponseNormalizationConfig config) { public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace, LocalResponseNormalizationConfig config) {
super(sameDiff, inputFunctions, inputs, outputs, inPlace, config); super(sameDiff, inputFunctions, inPlace, config);
} }
public LocalResponseNormalizationDerivative() {} public LocalResponseNormalizationDerivative() {}

View File

@ -51,27 +51,18 @@ public class MaxPooling2D extends DynamicCustomOp {
public MaxPooling2D() { public MaxPooling2D() {
} }
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
@SuppressWarnings("Used in lombok") @SuppressWarnings("Used in lombok")
public MaxPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) { public MaxPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
super(null, sameDiff, new SDVariable[]{input}, false); super(null, sameDiff, new SDVariable[]{input}, false);
if (arrayInput != null) {
addInputArgument(arrayInput);
}
if (arrayOutput != null) {
addOutputArgument(arrayOutput);
}
config.setType(Pooling2D.Pooling2DType.MAX); config.setType(Pooling2D.Pooling2DType.MAX);
this.config = config; this.config = config;
this.sameDiff = sameDiff;
addArgs(); addArgs();
} }
public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){ public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){
super(null, new INDArray[]{input}, output == null ? null : new INDArray[]{output}); super(null, new INDArray[]{input}, wrapOrNull(output));
config.setType(Pooling2D.Pooling2DType.MAX); config.setType(Pooling2D.Pooling2DType.MAX);
this.config = config; this.config = config;

View File

@ -16,8 +16,14 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -33,9 +39,6 @@ import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
import java.lang.reflect.Field;
import java.util.*;
/** /**
* Pooling2D operation * Pooling2D operation
@ -70,21 +73,27 @@ public class Pooling2D extends DynamicCustomOp {
public Pooling2D() {} public Pooling2D() {}
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
@SuppressWarnings("Used in lombok") @SuppressWarnings("Used in lombok")
public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,INDArray[] arrayInputs, INDArray[] arrayOutputs,Pooling2DConfig config) { public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,
super(null,sameDiff, inputs, false); Pooling2DConfig config) {
if(arrayInputs != null) { super(null, sameDiff, inputs, false);
addInputArgument(arrayInputs);
}
if(arrayOutputs != null) {
addOutputArgument(arrayOutputs);
}
this.config = config; this.config = config;
addArgs();
}
public Pooling2D(@NonNull INDArray[] inputs, INDArray[] outputs, @NonNull Pooling2DConfig config){
super(inputs, outputs);
this.config = config;
addArgs();
}
public Pooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
super(new INDArray[]{input}, wrapOrNull(output));
this.config = config;
addArgs(); addArgs();
} }

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -36,8 +37,12 @@ import java.util.List;
@Slf4j @Slf4j
public class Pooling2DDerivative extends Pooling2D { public class Pooling2DDerivative extends Pooling2D {
@Builder(builderMethodName = "derivativeBuilder") @Builder(builderMethodName = "derivativeBuilder")
public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] arrayInputs, INDArray[] arrayOutputs, Pooling2DConfig config) { public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, Pooling2DConfig config) {
super(sameDiff, inputs, arrayInputs, arrayOutputs, config); super(sameDiff, inputs, config);
}
public Pooling2DDerivative(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Pooling2DConfig config){
super(new INDArray[]{input, grad}, wrapOrNull(output), config);
} }
public Pooling2DDerivative() {} public Pooling2DDerivative() {}

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -39,9 +40,17 @@ import java.util.List;
@Slf4j @Slf4j
public class SConv2D extends Conv2D { public class SConv2D extends Conv2D {
@Builder(builderMethodName = "sBuilder") @Builder(builderMethodName = "sameDiffSBuilder")
public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) { public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig); super(sameDiff, inputFunctions, conv2DConfig);
}
public SConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
super(inputs, outputs, config);
}
public SConv2D(@NonNull INDArray input, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
this(wrapFilterNull(input, depthWeights, pointWeights, bias), wrapOrNull(output), config);
} }
public SConv2D() {} public SConv2D() {}

View File

@ -38,8 +38,8 @@ import java.util.List;
public class SConv2DDerivative extends SConv2D { public class SConv2DDerivative extends SConv2D {
@Builder(builderMethodName = "sDerviativeBuilder") @Builder(builderMethodName = "sDerviativeBuilder")
public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) { public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig); super(sameDiff, inputFunctions, conv2DConfig);
} }
public SConv2DDerivative() {} public SConv2DDerivative() {}

View File

@ -235,10 +235,7 @@ public class Convolution {
public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw,
int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor, int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor,
double extra, int virtualHeight, int virtualWidth, INDArray out) { double extra, int virtualHeight, int virtualWidth, INDArray out) {
Pooling2D pooling = Pooling2D.builder() Pooling2D pooling = new Pooling2D(img, out, Pooling2DConfig.builder()
.arrayInputs(new INDArray[]{img})
.arrayOutputs(new INDArray[]{out})
.config(Pooling2DConfig.builder()
.dH(dh) .dH(dh)
.dW(dw) .dW(dw)
.extra(extra) .extra(extra)
@ -251,8 +248,7 @@ public class Convolution {
.sW(sx) .sW(sx)
.type(type) .type(type)
.divisor(divisor) .divisor(divisor)
.build()) .build());
.build();
Nd4j.getExecutioner().execAndReturn(pooling); Nd4j.getExecutioner().execAndReturn(pooling);
return out; return out;
} }

View File

@ -389,10 +389,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
INDArray input = Nd4j.create(inSize); INDArray input = Nd4j.create(inSize);
AvgPooling2D avgPooling2D = AvgPooling2D.builder() AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf);
.arrayInput(input)
.config(conf)
.build();
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D); val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
@ -410,10 +407,7 @@ public class LayerOpValidation extends BaseOpValidation {
//Test backprop: //Test backprop:
Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder() Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, null, conf);
.arrayInputs(new INDArray[]{input, grad})
.config(conf)
.build();
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv); val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
assertEquals(1, outSizesBP.size()); assertEquals(1, outSizesBP.size());
@ -435,10 +429,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
INDArray input = Nd4j.create(inSize); INDArray input = Nd4j.create(inSize);
AvgPooling2D avgPooling2D = AvgPooling2D.builder() AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf);
.arrayInput(input)
.config(conf)
.build();
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D); val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
assertEquals(1, outSizes.size()); assertEquals(1, outSizes.size());
@ -454,11 +445,7 @@ public class LayerOpValidation extends BaseOpValidation {
INDArray grad = Nd4j.create(exp); INDArray grad = Nd4j.create(exp);
//Test backprop: //Test backprop:
Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder() Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, Nd4j.create(inSize), conf);
.arrayInputs(new INDArray[]{input, grad}) //Original input, and output gradient (eps - same shape as output)
.arrayOutputs(new INDArray[]{Nd4j.create(inSize)}) //Output for BP: same shape as original input
.config(conf)
.build();
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv); val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
assertEquals(1, outSizesBP.size()); assertEquals(1, outSizesBP.size());
@ -749,7 +736,7 @@ public class LayerOpValidation extends BaseOpValidation {
.isSameMode(false) .isSameMode(false)
.build(); .build();
SDVariable out = sd.cnn().conv2d(vars, c); SDVariable out = sd.cnn().conv2d("conv", vars, c);
out = sd.nn().tanh("out", out); out = sd.nn().tanh("out", out);
INDArray outArr = sd.execAndEndResult(); INDArray outArr = sd.execAndEndResult();