SDCNN cleanup pass (#230)
* SDCNN cleanup Signed-off-by: Ryan Nett <rnett@skymind.io> * NonNull annotations Signed-off-by: Ryan Nett <rnett@skymind.io> * better javadoc, NonNull fix for sconv Signed-off-by: Ryan Nett <rnett@skymind.io> * update builders to fix names Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * even more fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * fix for null bias Signed-off-by: Ryan Nett <rnett@skymind.io>master
parent
6cc887bee9
commit
e9454b8882
|
@ -469,7 +469,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) {
|
public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) {
|
||||||
LocalResponseNormalization lrn = LocalResponseNormalization.builder()
|
LocalResponseNormalization lrn = LocalResponseNormalization.sameDiffBuilder()
|
||||||
.inputFunctions(new SDVariable[]{input})
|
.inputFunctions(new SDVariable[]{input})
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.config(lrnConfig)
|
.config(lrnConfig)
|
||||||
|
@ -487,7 +487,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
|
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
|
||||||
Conv1D conv1D = Conv1D.builder()
|
Conv1D conv1D = Conv1D.sameDiffBuilder()
|
||||||
.inputFunctions(new SDVariable[]{input, weights})
|
.inputFunctions(new SDVariable[]{input, weights})
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.config(conv1DConfig)
|
.config(conv1DConfig)
|
||||||
|
@ -496,6 +496,34 @@ public class DifferentialFunctionFactory {
|
||||||
return conv1D.outputVariable();
|
return conv1D.outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Conv1d operation.
|
||||||
|
*
|
||||||
|
* @param input the inputs to conv1d
|
||||||
|
* @param weights conv1d weights
|
||||||
|
* @param bias conv1d bias
|
||||||
|
* @param conv1DConfig the configuration
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, Conv1DConfig conv1DConfig) {
|
||||||
|
|
||||||
|
SDVariable[] args;
|
||||||
|
|
||||||
|
if(bias == null){
|
||||||
|
args = new SDVariable[]{input, weights};
|
||||||
|
} else {
|
||||||
|
args = new SDVariable[]{input, weights, bias};
|
||||||
|
}
|
||||||
|
|
||||||
|
Conv1D conv1D = Conv1D.sameDiffBuilder()
|
||||||
|
.inputFunctions(args)
|
||||||
|
.sameDiff(sameDiff())
|
||||||
|
.config(conv1DConfig)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return conv1D.outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Conv2d operation.
|
* Conv2d operation.
|
||||||
*
|
*
|
||||||
|
@ -504,7 +532,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
||||||
Conv2D conv2D = Conv2D.builder()
|
Conv2D conv2D = Conv2D.sameDiffBuilder()
|
||||||
.inputFunctions(inputs)
|
.inputFunctions(inputs)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.config(conv2DConfig)
|
.config(conv2DConfig)
|
||||||
|
@ -530,7 +558,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
||||||
AvgPooling2D avgPooling2D = AvgPooling2D.builder()
|
AvgPooling2D avgPooling2D = AvgPooling2D.sameDiffBuilder()
|
||||||
.input(input)
|
.input(input)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.config(pooling2DConfig)
|
.config(pooling2DConfig)
|
||||||
|
@ -547,7 +575,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
||||||
MaxPooling2D maxPooling2D = MaxPooling2D.builder()
|
MaxPooling2D maxPooling2D = MaxPooling2D.sameDiffBuilder()
|
||||||
.input(input)
|
.input(input)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.config(pooling2DConfig)
|
.config(pooling2DConfig)
|
||||||
|
@ -590,7 +618,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
||||||
SConv2D sconv2D = SConv2D.sBuilder()
|
SConv2D sconv2D = SConv2D.sameDiffSBuilder()
|
||||||
.inputFunctions(inputs)
|
.inputFunctions(inputs)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.conv2DConfig(conv2DConfig)
|
.conv2DConfig(conv2DConfig)
|
||||||
|
@ -609,7 +637,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
|
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
|
||||||
SConv2D depthWiseConv2D = SConv2D.sBuilder()
|
SConv2D depthWiseConv2D = SConv2D.sameDiffSBuilder()
|
||||||
.inputFunctions(inputs)
|
.inputFunctions(inputs)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.conv2DConfig(depthConv2DConfig)
|
.conv2DConfig(depthConv2DConfig)
|
||||||
|
@ -627,7 +655,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
|
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
|
||||||
DeConv2D deconv2D = DeConv2D.builder()
|
DeConv2D deconv2D = DeConv2D.sameDiffBuilder()
|
||||||
.inputs(inputs)
|
.inputs(inputs)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.config(deconv2DConfig)
|
.config(deconv2DConfig)
|
||||||
|
@ -654,9 +682,9 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) {
|
public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) {
|
||||||
Conv3D conv3D = Conv3D.builder()
|
Conv3D conv3D = Conv3D.sameDiffBuilder()
|
||||||
.inputFunctions(inputs)
|
.inputFunctions(inputs)
|
||||||
.conv3DConfig(conv3DConfig)
|
.config(conv3DConfig)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
|
||||||
|
@ -38,14 +39,9 @@ public class SDCNN extends SDOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - average pooling 2d
|
* See {@link #avgPooling2d(String, SDVariable, Pooling2DConfig)}.
|
||||||
*
|
|
||||||
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param pooling2DConfig the configuration for
|
|
||||||
* @return Result after applying average pooling on the input
|
|
||||||
*/
|
*/
|
||||||
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
public SDVariable avgPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||||
return avgPooling2d(null, input, pooling2DConfig);
|
return avgPooling2d(null, input, pooling2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,22 +54,16 @@ public class SDCNN extends SDOps {
|
||||||
* @param pooling2DConfig the configuration
|
* @param pooling2DConfig the configuration
|
||||||
* @return Result after applying average pooling on the input
|
* @return Result after applying average pooling on the input
|
||||||
*/
|
*/
|
||||||
public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) {
|
public SDVariable avgPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||||
validateFloatingPoint("avgPooling2d", input);
|
validateFloatingPoint("avgPooling2d", input);
|
||||||
SDVariable ret = f().avgPooling2d(input, pooling2DConfig);
|
SDVariable ret = f().avgPooling2d(input, pooling2DConfig);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 3D convolution layer operation - average pooling 3d
|
* See {@link #avgPooling3d(String, SDVariable, Pooling3DConfig)}.
|
||||||
*
|
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels])
|
|
||||||
* @param pooling3DConfig the configuration
|
|
||||||
* @return Result after applying average pooling on the input
|
|
||||||
*/
|
*/
|
||||||
public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
|
public SDVariable avgPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
|
||||||
return avgPooling3d(null, input, pooling3DConfig);
|
return avgPooling3d(null, input, pooling3DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,7 +77,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param pooling3DConfig the configuration
|
* @param pooling3DConfig the configuration
|
||||||
* @return Result after applying average pooling on the input
|
* @return Result after applying average pooling on the input
|
||||||
*/
|
*/
|
||||||
public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) {
|
public SDVariable avgPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
|
||||||
validateFloatingPoint("avgPooling3d", input);
|
validateFloatingPoint("avgPooling3d", input);
|
||||||
SDVariable ret = f().avgPooling3d(input, pooling3DConfig);
|
SDVariable ret = f().avgPooling3d(input, pooling3DConfig);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
@ -96,7 +86,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* @see #batchToSpace(String, SDVariable, int[], int[][])
|
* @see #batchToSpace(String, SDVariable, int[], int[][])
|
||||||
*/
|
*/
|
||||||
public SDVariable batchToSpace(SDVariable x, int[] blocks, int[][] crops) {
|
public SDVariable batchToSpace(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) {
|
||||||
return batchToSpace(null, x, blocks, crops);
|
return batchToSpace(null, x, blocks, crops);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,7 +101,7 @@ public class SDCNN extends SDOps {
|
||||||
* @return Output variable
|
* @return Output variable
|
||||||
* @see #spaceToBatch(String, SDVariable, int[], int[][])
|
* @see #spaceToBatch(String, SDVariable, int[], int[][])
|
||||||
*/
|
*/
|
||||||
public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[][] crops) {
|
public SDVariable batchToSpace(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) {
|
||||||
validateNumerical("batchToSpace", x);
|
validateNumerical("batchToSpace", x);
|
||||||
SDVariable ret = f().batchToSpace(x, blocks, crops);
|
SDVariable ret = f().batchToSpace(x, blocks, crops);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
@ -119,14 +109,9 @@ public class SDCNN extends SDOps {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* col2im operation for use in 2D convolution operations. Outputs a 4d array with shape
|
* See {@link #col2Im(String, SDVariable, Conv2DConfig)}.
|
||||||
* [minibatch, inputChannels, height, width]
|
|
||||||
*
|
|
||||||
* @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
|
|
||||||
* @param config Convolution configuration for the col2im operation
|
|
||||||
* @return Col2Im output variable
|
|
||||||
*/
|
*/
|
||||||
public SDVariable col2Im(SDVariable in, Conv2DConfig config) {
|
public SDVariable col2Im(@NonNull SDVariable in, @NonNull Conv2DConfig config) {
|
||||||
return col2Im(null, in, config);
|
return col2Im(null, in, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,33 +124,22 @@ public class SDCNN extends SDOps {
|
||||||
* @param config Convolution configuration for the col2im operation
|
* @param config Convolution configuration for the col2im operation
|
||||||
* @return Col2Im output variable
|
* @return Col2Im output variable
|
||||||
*/
|
*/
|
||||||
public SDVariable col2Im(String name, SDVariable in, Conv2DConfig config) {
|
public SDVariable col2Im(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) {
|
||||||
SDVariable ret = f().col2Im(in, config);
|
SDVariable ret = f().col2Im(in, config);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 1D Convolution layer operation - Conv1d
|
* See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias.
|
||||||
*
|
|
||||||
* @param input the input array/activations for the conv1d op
|
|
||||||
* @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels]
|
|
||||||
* @param conv1DConfig the configuration
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
|
public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) {
|
||||||
return conv1d(null, input, weights, conv1DConfig);
|
return conv1d((String) null, input, weights, conv1DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Conv1d operation.
|
* See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias.
|
||||||
*
|
|
||||||
* @param name name of the operation in SameDiff
|
|
||||||
* @param input the inputs to conv1d
|
|
||||||
* @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels]
|
|
||||||
* @param conv1DConfig the configuration
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public SDVariable conv1d(String name, SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
|
public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) {
|
||||||
validateFloatingPoint("conv1d", input);
|
validateFloatingPoint("conv1d", input);
|
||||||
validateFloatingPoint("conv1d", weights);
|
validateFloatingPoint("conv1d", weights);
|
||||||
SDVariable ret = f().conv1d(input, weights, conv1DConfig);
|
SDVariable ret = f().conv1d(input, weights, conv1DConfig);
|
||||||
|
@ -173,21 +147,55 @@ public class SDCNN extends SDOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution operation (without bias)
|
* See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}.
|
||||||
*
|
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
|
|
||||||
* @param config Conv2DConfig configuration
|
|
||||||
* @return result of conv2d op
|
|
||||||
*/
|
*/
|
||||||
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig config) {
|
public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
|
||||||
|
return conv1d(null, input, weights, bias, conv1DConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Conv1d operation.
|
||||||
|
*
|
||||||
|
* @param name name of the operation in SameDiff
|
||||||
|
* @param input the inputs to conv1d
|
||||||
|
* @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels]
|
||||||
|
* @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null.
|
||||||
|
* @param conv1DConfig the configuration
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
|
||||||
|
validateFloatingPoint("conv1d", input);
|
||||||
|
validateFloatingPoint("conv1d", weights);
|
||||||
|
validateFloatingPoint("conv1d", bias);
|
||||||
|
SDVariable ret = f().conv1d(input, weights, bias, conv1DConfig);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) {
|
||||||
return conv2d(layerInput, weights, null, config);
|
return conv2d(layerInput, weights, null, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) {
|
||||||
|
return conv2d(name, layerInput, weights, null, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
|
||||||
|
*/
|
||||||
|
public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) {
|
||||||
|
return conv2d(null, layerInput, weights, bias, config);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution operation with optional bias
|
* 2D Convolution operation with optional bias
|
||||||
*
|
*
|
||||||
|
* @param name name of the operation in SameDiff
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||||
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
|
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
|
||||||
|
@ -195,7 +203,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param config Conv2DConfig configuration
|
* @param config Conv2DConfig configuration
|
||||||
* @return result of conv2d op
|
* @return result of conv2d op
|
||||||
*/
|
*/
|
||||||
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, Conv2DConfig config) {
|
public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) {
|
||||||
validateFloatingPoint("conv2d", "input", layerInput);
|
validateFloatingPoint("conv2d", "input", layerInput);
|
||||||
validateFloatingPoint("conv2d", "weights", weights);
|
validateFloatingPoint("conv2d", "weights", weights);
|
||||||
validateFloatingPoint("conv2d", "bias", bias);
|
validateFloatingPoint("conv2d", "bias", bias);
|
||||||
|
@ -204,18 +212,13 @@ public class SDCNN extends SDOps {
|
||||||
arr[1] = weights;
|
arr[1] = weights;
|
||||||
if (bias != null)
|
if (bias != null)
|
||||||
arr[2] = bias;
|
arr[2] = bias;
|
||||||
return conv2d(arr, config);
|
return conv2d(name, arr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution operation with optional bias
|
* See {@link #conv2d(String, SDVariable[], Conv2DConfig)}.
|
||||||
*
|
|
||||||
* @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as
|
|
||||||
* described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
|
|
||||||
* @param config Conv2DConfig configuration
|
|
||||||
* @return result of convolution 2d operation
|
|
||||||
*/
|
*/
|
||||||
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig config) {
|
public SDVariable conv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) {
|
||||||
return conv2d(null, inputs, config);
|
return conv2d(null, inputs, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -228,7 +231,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param config Conv2DConfig configuration
|
* @param config Conv2DConfig configuration
|
||||||
* @return result of convolution 2d operation
|
* @return result of convolution 2d operation
|
||||||
*/
|
*/
|
||||||
public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig config) {
|
public SDVariable conv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) {
|
||||||
for(SDVariable v : inputs)
|
for(SDVariable v : inputs)
|
||||||
validateNumerical("conv2d", v);
|
validateNumerical("conv2d", v);
|
||||||
SDVariable ret = f().conv2d(inputs, config);
|
SDVariable ret = f().conv2d(inputs, config);
|
||||||
|
@ -236,19 +239,26 @@ public class SDCNN extends SDOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convolution 3D operation without bias
|
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias.
|
||||||
*
|
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels])
|
|
||||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
|
|
||||||
* @param conv3DConfig the configuration
|
|
||||||
* @return Conv3d output variable
|
|
||||||
*/
|
*/
|
||||||
public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) {
|
public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) {
|
||||||
return conv3d(null, input, weights, null, conv3DConfig);
|
return conv3d(null, input, weights, null, conv3DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) {
|
||||||
|
return conv3d(name, input, weights, null, conv3DConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}.
|
||||||
|
*/
|
||||||
|
public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) {
|
||||||
|
return conv3d(null, input, weights, bias, conv3DConfig);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convolution 3D operation with optional bias
|
* Convolution 3D operation with optional bias
|
||||||
*
|
*
|
||||||
|
@ -261,7 +271,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param conv3DConfig the configuration
|
* @param conv3DConfig the configuration
|
||||||
* @return Conv3d output variable
|
* @return Conv3d output variable
|
||||||
*/
|
*/
|
||||||
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) {
|
public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) {
|
||||||
validateFloatingPoint("conv3d", "input", input);
|
validateFloatingPoint("conv3d", "input", input);
|
||||||
validateFloatingPoint("conv3d", "weights", weights);
|
validateFloatingPoint("conv3d", "weights", weights);
|
||||||
validateFloatingPoint("conv3d", "bias", bias);
|
validateFloatingPoint("conv3d", "bias", bias);
|
||||||
|
@ -276,51 +286,30 @@ public class SDCNN extends SDOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convolution 3D operation with optional bias
|
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias.
|
||||||
*
|
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels])
|
|
||||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
|
|
||||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null.
|
|
||||||
* @param conv3DConfig the configuration
|
|
||||||
* @return Conv3d output variable
|
|
||||||
*/
|
*/
|
||||||
public SDVariable conv3d(SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) {
|
public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||||
return conv3d(null, input, weights, bias, conv3DConfig);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convolution 3D operation without bias
|
|
||||||
*
|
|
||||||
* @param name Name of the output variable
|
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels])
|
|
||||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
|
|
||||||
* @param conv3DConfig the configuration
|
|
||||||
* @return Conv3d output variable
|
|
||||||
*/
|
|
||||||
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) {
|
|
||||||
return conv3d(name, input, weights, null, conv3DConfig);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 2D deconvolution operation without bias
|
|
||||||
*
|
|
||||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
|
|
||||||
* @param deconv2DConfig DeConv2DConfig configuration
|
|
||||||
* @return result of deconv2d op
|
|
||||||
*/
|
|
||||||
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, DeConv2DConfig deconv2DConfig) {
|
|
||||||
return deconv2d(layerInput, weights, null, deconv2DConfig);
|
return deconv2d(layerInput, weights, null, deconv2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||||
|
return deconv2d(name, layerInput, weights, null, deconv2DConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}.
|
||||||
|
*/
|
||||||
|
public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||||
|
return deconv2d(null, layerInput, weights, bias, deconv2DConfig);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D deconvolution operation with optional bias
|
* 2D deconvolution operation with optional bias
|
||||||
*
|
*
|
||||||
|
* @param name name of the operation in SameDiff
|
||||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
|
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
|
||||||
|
@ -328,7 +317,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param deconv2DConfig DeConv2DConfig configuration
|
* @param deconv2DConfig DeConv2DConfig configuration
|
||||||
* @return result of deconv2d op
|
* @return result of deconv2d op
|
||||||
*/
|
*/
|
||||||
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, DeConv2DConfig deconv2DConfig) {
|
public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||||
validateFloatingPoint("deconv2d", "input", layerInput);
|
validateFloatingPoint("deconv2d", "input", layerInput);
|
||||||
validateFloatingPoint("deconv2d", "weights", weights);
|
validateFloatingPoint("deconv2d", "weights", weights);
|
||||||
validateFloatingPoint("deconv2d", "bias", bias);
|
validateFloatingPoint("deconv2d", "bias", bias);
|
||||||
|
@ -337,18 +326,13 @@ public class SDCNN extends SDOps {
|
||||||
arr[1] = weights;
|
arr[1] = weights;
|
||||||
if (bias != null)
|
if (bias != null)
|
||||||
arr[2] = bias;
|
arr[2] = bias;
|
||||||
return deconv2d(arr, deconv2DConfig);
|
return deconv2d(name, arr, deconv2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D deconvolution operation with or without optional bias
|
* See {@link #deconv2d(String, SDVariable[], DeConv2DConfig)}.
|
||||||
*
|
|
||||||
* @param inputs Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights)
|
|
||||||
* or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable[], DeConv2DConfig)}
|
|
||||||
* @param deconv2DConfig the configuration
|
|
||||||
* @return result of deconv2d op
|
|
||||||
*/
|
*/
|
||||||
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
|
public SDVariable deconv2d(@NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||||
return deconv2d(null, inputs, deconv2DConfig);
|
return deconv2d(null, inputs, deconv2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -361,13 +345,34 @@ public class SDCNN extends SDOps {
|
||||||
* @param deconv2DConfig the configuration
|
* @param deconv2DConfig the configuration
|
||||||
* @return result of deconv2d op
|
* @return result of deconv2d op
|
||||||
*/
|
*/
|
||||||
public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
|
public SDVariable deconv2d(String name, @NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||||
for(SDVariable v : inputs)
|
for(SDVariable v : inputs)
|
||||||
validateNumerical("deconv2d", v);
|
validateNumerical("deconv2d", v);
|
||||||
SDVariable ret = f().deconv2d(inputs, deconv2DConfig);
|
SDVariable ret = f().deconv2d(inputs, deconv2DConfig);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
|
||||||
|
return deconv3d(input, weights, null, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
|
||||||
|
return deconv3d(name, input, weights, null, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}.
|
||||||
|
*/
|
||||||
|
public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
|
||||||
|
return deconv3d(null, input, weights, bias, config);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 3D CNN deconvolution operation with or without optional bias
|
* 3D CNN deconvolution operation with or without optional bias
|
||||||
*
|
*
|
||||||
|
@ -377,7 +382,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
|
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
|
||||||
* @param config Configuration
|
* @param config Configuration
|
||||||
*/
|
*/
|
||||||
public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) {
|
public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
|
||||||
validateFloatingPoint("conv3d", input);
|
validateFloatingPoint("conv3d", input);
|
||||||
validateFloatingPoint("conv3d", weights);
|
validateFloatingPoint("conv3d", weights);
|
||||||
validateFloatingPoint("conv3d", bias);
|
validateFloatingPoint("conv3d", bias);
|
||||||
|
@ -386,41 +391,9 @@ public class SDCNN extends SDOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 3D CNN deconvolution operation with or without optional bias
|
* See {@link #depthToSpace(String, SDVariable, int, String)}.
|
||||||
*
|
|
||||||
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
|
||||||
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
|
|
||||||
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
|
|
||||||
* @param config Configuration
|
|
||||||
*/
|
*/
|
||||||
public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) {
|
public SDVariable depthToSpace(@NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) {
|
||||||
return deconv3d(null, input, weights, bias, config);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 3D CNN deconvolution operation with no bias
|
|
||||||
*
|
|
||||||
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
|
||||||
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
|
|
||||||
* @param config Configuration
|
|
||||||
*/
|
|
||||||
public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig config) {
|
|
||||||
return deconv3d(input, weights, null, config);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convolution 2d layer batch to space operation on 4d input.<br>
|
|
||||||
* Reduces input channels dimension by rearranging data into a larger spatial dimensions<br>
|
|
||||||
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
|
|
||||||
* = [mb, 2, 4, 4]
|
|
||||||
*
|
|
||||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param blockSize Block size, in the height/width dimension
|
|
||||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
|
||||||
* @return Output variable
|
|
||||||
*/
|
|
||||||
public SDVariable depthToSpace(SDVariable x, int blockSize, String dataFormat) {
|
|
||||||
return depthToSpace(null, x, blockSize, dataFormat);
|
return depthToSpace(null, x, blockSize, dataFormat);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -438,27 +411,36 @@ public class SDCNN extends SDOps {
|
||||||
* @return Output variable
|
* @return Output variable
|
||||||
* @see #depthToSpace(String, SDVariable, int, String)
|
* @see #depthToSpace(String, SDVariable, int, String)
|
||||||
*/
|
*/
|
||||||
public SDVariable depthToSpace(String name, SDVariable x, int blockSize, String dataFormat) {
|
public SDVariable depthToSpace(String name, @NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) {
|
||||||
SDVariable ret = f().depthToSpace(x, blockSize, dataFormat);
|
SDVariable ret = f().depthToSpace(x, blockSize, dataFormat);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Depth-wise 2D convolution operation without bias
|
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||||
*
|
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
|
|
||||||
* @param config Conv2DConfig configuration
|
|
||||||
* @return result of conv2d op
|
|
||||||
*/
|
*/
|
||||||
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, Conv2DConfig config) {
|
public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) {
|
||||||
return depthWiseConv2d(layerInput, depthWeights, null, config);
|
return depthWiseConv2d(layerInput, depthWeights, null, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) {
|
||||||
|
return depthWiseConv2d(name, layerInput, depthWeights, null, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
|
||||||
|
*/
|
||||||
|
public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) {
|
||||||
|
return depthWiseConv2d(null, layerInput, depthWeights, bias, config);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Depth-wise 2D convolution operation with optional bias
|
* Depth-wise 2D convolution operation with optional bias
|
||||||
*
|
*
|
||||||
|
* @param name name of the operation in SameDiff
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||||
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
|
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
|
||||||
|
@ -466,7 +448,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param config Conv2DConfig configuration
|
* @param config Conv2DConfig configuration
|
||||||
* @return result of depthwise conv2d op
|
* @return result of depthwise conv2d op
|
||||||
*/
|
*/
|
||||||
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, Conv2DConfig config) {
|
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) {
|
||||||
validateFloatingPoint("depthwiseConv2d", "input", layerInput);
|
validateFloatingPoint("depthwiseConv2d", "input", layerInput);
|
||||||
validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights);
|
validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights);
|
||||||
validateFloatingPoint("depthwiseConv2d", "bias", bias);
|
validateFloatingPoint("depthwiseConv2d", "bias", bias);
|
||||||
|
@ -475,19 +457,13 @@ public class SDCNN extends SDOps {
|
||||||
arr[1] = depthWeights;
|
arr[1] = depthWeights;
|
||||||
if (bias != null)
|
if (bias != null)
|
||||||
arr[2] = bias;
|
arr[2] = bias;
|
||||||
return depthWiseConv2d(arr, config);
|
return depthWiseConv2d(name, arr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Depth-wise convolution 2D operation.
|
* See {@link #depthWiseConv2d(String, SDVariable[], Conv2DConfig)}.
|
||||||
*
|
|
||||||
* @param inputs the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights)
|
|
||||||
* or 3 elements (layerInput, depthWeights, bias) as described in
|
|
||||||
* {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
|
|
||||||
* @param depthConv2DConfig the configuration
|
|
||||||
* @return result of depthwise conv2d op
|
|
||||||
*/
|
*/
|
||||||
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
|
public SDVariable depthWiseConv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) {
|
||||||
return depthWiseConv2d(null, inputs, depthConv2DConfig);
|
return depthWiseConv2d(null, inputs, depthConv2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -501,7 +477,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param depthConv2DConfig the configuration
|
* @param depthConv2DConfig the configuration
|
||||||
* @return result of depthwise conv2d op
|
* @return result of depthwise conv2d op
|
||||||
*/
|
*/
|
||||||
public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
|
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) {
|
||||||
for(SDVariable v : inputs)
|
for(SDVariable v : inputs)
|
||||||
validateFloatingPoint("depthWiseConv2d", v);
|
validateFloatingPoint("depthWiseConv2d", v);
|
||||||
SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig);
|
SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig);
|
||||||
|
@ -509,17 +485,10 @@ public class SDCNN extends SDOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TODO doc string
|
* See {@link #dilation2D(String, SDVariable, SDVariable, int[], int[], boolean)}.
|
||||||
*
|
|
||||||
* @param df
|
|
||||||
* @param weights
|
|
||||||
* @param strides
|
|
||||||
* @param rates
|
|
||||||
* @param isSameMode
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides,
|
public SDVariable dilation2D(@NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides,
|
||||||
int[] rates, boolean isSameMode) {
|
@NonNull int[] rates, @NonNull boolean isSameMode) {
|
||||||
return dilation2D(null, df, weights, strides, rates, isSameMode);
|
return dilation2D(null, df, weights, strides, rates, isSameMode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -534,8 +503,8 @@ public class SDCNN extends SDOps {
|
||||||
* @param isSameMode
|
* @param isSameMode
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides,
|
public SDVariable dilation2D(String name, @NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides,
|
||||||
int[] rates, boolean isSameMode) {
|
@NonNull int[] rates, @NonNull boolean isSameMode) {
|
||||||
SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode);
|
SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
@ -555,21 +524,16 @@ public class SDCNN extends SDOps {
|
||||||
* @param sameMode If true: use same mode padding. If false
|
* @param sameMode If true: use same mode padding. If false
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable extractImagePatches(String name, SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
|
public SDVariable extractImagePatches(String name, @NonNull SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
|
||||||
SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode);
|
SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* im2col operation for use in 2D convolution operations. Outputs a 6d array with shape
|
* See {@link #im2Col(String, SDVariable, Conv2DConfig)}.
|
||||||
* [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
|
|
||||||
*
|
|
||||||
* @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width]
|
|
||||||
* @param config Convolution configuration for the im2col operation
|
|
||||||
* @return Im2Col output variable
|
|
||||||
*/
|
*/
|
||||||
public SDVariable im2Col(SDVariable in, Conv2DConfig config) {
|
public SDVariable im2Col(@NonNull SDVariable in, @NonNull Conv2DConfig config) {
|
||||||
return im2Col(null, in, config);
|
return im2Col(null, in, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -582,20 +546,16 @@ public class SDCNN extends SDOps {
|
||||||
* @param config Convolution configuration for the im2col operation
|
* @param config Convolution configuration for the im2col operation
|
||||||
* @return Im2Col output variable
|
* @return Im2Col output variable
|
||||||
*/
|
*/
|
||||||
public SDVariable im2Col(String name, SDVariable in, Conv2DConfig config) {
|
public SDVariable im2Col(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) {
|
||||||
SDVariable ret = f().im2Col(in, config);
|
SDVariable ret = f().im2Col(in, config);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D convolution layer operation - local response normalization
|
* See {@link #localResponseNormalization(String, SDVariable, LocalResponseNormalizationConfig)}.
|
||||||
*
|
|
||||||
* @param inputs the inputs to lrn
|
|
||||||
* @param lrnConfig the configuration
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNormalizationConfig lrnConfig) {
|
public SDVariable localResponseNormalization(@NonNull SDVariable inputs, @NonNull LocalResponseNormalizationConfig lrnConfig) {
|
||||||
return localResponseNormalization(null, inputs, lrnConfig);
|
return localResponseNormalization(null, inputs, lrnConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -607,8 +567,8 @@ public class SDCNN extends SDOps {
|
||||||
* @param lrnConfig the configuration
|
* @param lrnConfig the configuration
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable localResponseNormalization(String name, SDVariable input,
|
public SDVariable localResponseNormalization(String name, @NonNull SDVariable input,
|
||||||
LocalResponseNormalizationConfig lrnConfig) {
|
@NonNull LocalResponseNormalizationConfig lrnConfig) {
|
||||||
validateFloatingPoint("local response normalization", input);
|
validateFloatingPoint("local response normalization", input);
|
||||||
SDVariable ret = f().localResponseNormalization(input, lrnConfig);
|
SDVariable ret = f().localResponseNormalization(input, lrnConfig);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
@ -616,14 +576,9 @@ public class SDCNN extends SDOps {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - max pooling 2d
|
* See {@link #maxPooling2d(String, SDVariable, Pooling2DConfig)}.
|
||||||
*
|
|
||||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param pooling2DConfig the configuration
|
|
||||||
* @return Result after applying max pooling on the input
|
|
||||||
*/
|
*/
|
||||||
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
public SDVariable maxPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||||
return maxPooling2d(null, input, pooling2DConfig);
|
return maxPooling2d(null, input, pooling2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -636,22 +591,16 @@ public class SDCNN extends SDOps {
|
||||||
* @param pooling2DConfig the configuration
|
* @param pooling2DConfig the configuration
|
||||||
* @return Result after applying max pooling on the input
|
* @return Result after applying max pooling on the input
|
||||||
*/
|
*/
|
||||||
public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) {
|
public SDVariable maxPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||||
validateNumerical("maxPooling2d", input);
|
validateNumerical("maxPooling2d", input);
|
||||||
SDVariable ret = f().maxPooling2d(input, pooling2DConfig);
|
SDVariable ret = f().maxPooling2d(input, pooling2DConfig);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 3D convolution layer operation - max pooling 3d operation.
|
* See {@link #maxPooling3d(String, SDVariable, Pooling3DConfig)}.
|
||||||
*
|
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels])
|
|
||||||
* @param pooling3DConfig the configuration
|
|
||||||
* @return Result after applying max pooling on the input
|
|
||||||
*/
|
*/
|
||||||
public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
|
public SDVariable maxPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
|
||||||
return maxPooling3d(null, input, pooling3DConfig);
|
return maxPooling3d(null, input, pooling3DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -665,7 +614,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param pooling3DConfig the configuration
|
* @param pooling3DConfig the configuration
|
||||||
* @return Result after applying max pooling on the input
|
* @return Result after applying max pooling on the input
|
||||||
*/
|
*/
|
||||||
public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) {
|
public SDVariable maxPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
|
||||||
validateNumerical("maxPooling3d", input);
|
validateNumerical("maxPooling3d", input);
|
||||||
SDVariable ret = f().maxPooling3d(input, pooling3DConfig);
|
SDVariable ret = f().maxPooling3d(input, pooling3DConfig);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
@ -673,21 +622,30 @@ public class SDCNN extends SDOps {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Separable 2D convolution operation without bias
|
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||||
*
|
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
|
|
||||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]
|
|
||||||
* May be null
|
|
||||||
* @param config Conv2DConfig configuration
|
|
||||||
* @return result of separable convolution 2d operation
|
|
||||||
*/
|
*/
|
||||||
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights,
|
public SDVariable separableConv2d(SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
|
||||||
Conv2DConfig config) {
|
@NonNull Conv2DConfig config) {
|
||||||
return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
|
return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
|
||||||
|
@NonNull Conv2DConfig config) {
|
||||||
|
return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
|
||||||
|
*/
|
||||||
|
public SDVariable separableConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
|
||||||
|
SDVariable bias, @NonNull Conv2DConfig config) {
|
||||||
|
return separableConv2d(null, layerInput, depthWeights, pointWeights, bias, config);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Separable 2D convolution operation with optional bias
|
* Separable 2D convolution operation with optional bias
|
||||||
*
|
*
|
||||||
|
@ -700,8 +658,8 @@ public class SDCNN extends SDOps {
|
||||||
* @param config Conv2DConfig configuration
|
* @param config Conv2DConfig configuration
|
||||||
* @return result of separable convolution 2d operation
|
* @return result of separable convolution 2d operation
|
||||||
*/
|
*/
|
||||||
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights,
|
public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
|
||||||
SDVariable bias, Conv2DConfig config) {
|
SDVariable bias, @NonNull Conv2DConfig config) {
|
||||||
validateFloatingPoint("separableConv2d", "input", layerInput);
|
validateFloatingPoint("separableConv2d", "input", layerInput);
|
||||||
validateFloatingPoint("separableConv2d", "depthWeights", depthWeights);
|
validateFloatingPoint("separableConv2d", "depthWeights", depthWeights);
|
||||||
validateFloatingPoint("separableConv2d", "pointWeights", pointWeights);
|
validateFloatingPoint("separableConv2d", "pointWeights", pointWeights);
|
||||||
|
@ -712,18 +670,13 @@ public class SDCNN extends SDOps {
|
||||||
arr[2] = pointWeights;
|
arr[2] = pointWeights;
|
||||||
if (bias != null)
|
if (bias != null)
|
||||||
arr[3] = bias;
|
arr[3] = bias;
|
||||||
return sconv2d(arr, config);
|
return sconv2d(name, arr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Separable 2D convolution operation with/without optional bias
|
* See {@link #sconv2d(String, SDVariable[], Conv2DConfig)}.
|
||||||
*
|
|
||||||
* @param inputs the inputs to separable conv2 operation. Should be length 3 (layerInput, depthWeights, pointWeights)
|
|
||||||
* or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}
|
|
||||||
* @param conv2DConfig the configuration
|
|
||||||
* @return result of separable convolution 2d operation
|
|
||||||
*/
|
*/
|
||||||
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
public SDVariable sconv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) {
|
||||||
return sconv2d(null, inputs, conv2DConfig);
|
return sconv2d(null, inputs, conv2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -736,7 +689,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param conv2DConfig the configuration
|
* @param conv2DConfig the configuration
|
||||||
* @return result of separable convolution 2d operation
|
* @return result of separable convolution 2d operation
|
||||||
*/
|
*/
|
||||||
public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
public SDVariable sconv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) {
|
||||||
for(SDVariable v : inputs)
|
for(SDVariable v : inputs)
|
||||||
validateFloatingPoint("sconv2d", v);
|
validateFloatingPoint("sconv2d", v);
|
||||||
SDVariable ret = f().sconv2d(inputs, conv2DConfig);
|
SDVariable ret = f().sconv2d(inputs, conv2DConfig);
|
||||||
|
@ -747,7 +700,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* @see #spaceToBatch(String, SDVariable, int[], int[][])
|
* @see #spaceToBatch(String, SDVariable, int[], int[][])
|
||||||
*/
|
*/
|
||||||
public SDVariable spaceToBatch(SDVariable x, int[] blocks, int[][] padding) {
|
public SDVariable spaceToBatch(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) {
|
||||||
return spaceToBatch(null, x, blocks, padding);
|
return spaceToBatch(null, x, blocks, padding);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -762,7 +715,7 @@ public class SDCNN extends SDOps {
|
||||||
* @return Output variable
|
* @return Output variable
|
||||||
* @see #batchToSpace(String, SDVariable, int[], int[][])
|
* @see #batchToSpace(String, SDVariable, int[], int[][])
|
||||||
*/
|
*/
|
||||||
public SDVariable spaceToBatch(String name, SDVariable x, int[] blocks, int[][] padding) {
|
public SDVariable spaceToBatch(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) {
|
||||||
SDVariable ret = f().spaceToBatch(x, blocks, padding);
|
SDVariable ret = f().spaceToBatch(x, blocks, padding);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
@ -770,7 +723,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* @see #spaceToDepth(String, SDVariable, int, String)
|
* @see #spaceToDepth(String, SDVariable, int, String)
|
||||||
*/
|
*/
|
||||||
public SDVariable spaceToDepth(SDVariable x, int blockSize, String dataFormat) {
|
public SDVariable spaceToDepth(@NonNull SDVariable x, int blockSize, @NonNull String dataFormat) {
|
||||||
return spaceToDepth(null, x, blockSize, dataFormat);
|
return spaceToDepth(null, x, blockSize, dataFormat);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -788,23 +741,39 @@ public class SDCNN extends SDOps {
|
||||||
* @return Output variable
|
* @return Output variable
|
||||||
* @see #depthToSpace(String, SDVariable, int, String)
|
* @see #depthToSpace(String, SDVariable, int, String)
|
||||||
*/
|
*/
|
||||||
public SDVariable spaceToDepth(String name, SDVariable x, int blockSize, String dataFormat) {
|
public SDVariable spaceToDepth(String name, @NonNull SDVariable x, int blockSize, @NonNull String dataFormat) {
|
||||||
SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat);
|
SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format.
|
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)},
|
||||||
|
* scale is used for both height and width dimensions.
|
||||||
*
|
*
|
||||||
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
|
* @param scale The scale for both height and width dimensions.
|
||||||
* @param scale Scale to upsample in both H and W dimensions
|
|
||||||
* @return Upsampled input
|
|
||||||
*/
|
*/
|
||||||
public SDVariable upsampling2d(SDVariable input, int scale) {
|
public SDVariable upsampling2d(@NonNull SDVariable input, int scale) {
|
||||||
return upsampling2d(null, input, true, scale, scale);
|
return upsampling2d(null, input, true, scale, scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)},
|
||||||
|
* scale is used for both height and width dimensions.
|
||||||
|
*
|
||||||
|
* @param scale The scale for both height and width dimensions.
|
||||||
|
*/
|
||||||
|
public SDVariable upsampling2d(String name, @NonNull SDVariable input, int scale) {
|
||||||
|
return upsampling2d(name, input, true, scale, scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)}.
|
||||||
|
*/
|
||||||
|
public SDVariable upsampling2d(@NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) {
|
||||||
|
return upsampling2d(null, input, nchw, scaleH, scaleW);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - Upsampling 2d
|
* 2D Convolution layer operation - Upsampling 2d
|
||||||
*
|
*
|
||||||
|
@ -814,33 +783,8 @@ public class SDCNN extends SDOps {
|
||||||
* @param scaleW Scale to upsample in width dimension
|
* @param scaleW Scale to upsample in width dimension
|
||||||
* @return Upsampled input
|
* @return Upsampled input
|
||||||
*/
|
*/
|
||||||
public SDVariable upsampling2d(String name, SDVariable input, boolean nchw, int scaleH, int scaleW) {
|
public SDVariable upsampling2d(String name, @NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) {
|
||||||
SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW);
|
SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format.
|
|
||||||
*
|
|
||||||
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
|
|
||||||
* @param scale Scale to upsample in both H and W dimensions
|
|
||||||
* @return Upsampled input
|
|
||||||
*/
|
|
||||||
public SDVariable upsampling2d(String name, SDVariable input, int scale) {
|
|
||||||
return upsampling2d(name, input, true, scale, scale);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 2D Convolution layer operation - Upsampling 2d
|
|
||||||
*
|
|
||||||
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
|
|
||||||
* or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format
|
|
||||||
* @param scaleH Scale to upsample in height dimension
|
|
||||||
* @param scaleW Scale to upsample in width dimension
|
|
||||||
* @return Upsampled input
|
|
||||||
*/
|
|
||||||
public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) {
|
|
||||||
return upsampling2d(null, input, nchw, scaleH, scaleW);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -53,19 +54,19 @@ public class AvgPooling2D extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public AvgPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) {
|
public AvgPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
|
||||||
super(null, sameDiff, new SDVariable[]{input}, false);
|
super(sameDiff, new SDVariable[]{input});
|
||||||
if (arrayInput != null) {
|
|
||||||
addInputArgument(arrayInput);
|
|
||||||
}
|
|
||||||
if (arrayOutput != null) {
|
|
||||||
addOutputArgument(arrayOutput);
|
|
||||||
}
|
|
||||||
config.setType(Pooling2D.Pooling2DType.AVG);
|
config.setType(Pooling2D.Pooling2DType.AVG);
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public AvgPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
||||||
|
super(new INDArray[]{input}, wrapOrNull(output));
|
||||||
|
config.setType(Pooling2D.Pooling2DType.AVG);
|
||||||
|
|
||||||
this.sameDiff = sameDiff;
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -39,6 +40,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -59,18 +61,28 @@ public class Conv1D extends DynamicCustomOp {
|
||||||
protected Conv1DConfig config;
|
protected Conv1DConfig config;
|
||||||
private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s ";
|
private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s ";
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public Conv1D(SameDiff sameDiff,
|
public Conv1D(SameDiff sameDiff,
|
||||||
SDVariable[] inputFunctions,
|
SDVariable[] inputFunctions,
|
||||||
INDArray[] inputArrays, INDArray[] outputs,
|
|
||||||
Conv1DConfig config) {
|
Conv1DConfig config) {
|
||||||
super(null, inputArrays, outputs);
|
super(sameDiff, inputFunctions);
|
||||||
this.sameDiff = sameDiff;
|
initConfig(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Conv1D(INDArray[] inputs, INDArray[] outputs, Conv1DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
initConfig(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv1DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initConfig(Conv1DConfig config){
|
||||||
this.config = config;
|
this.config = config;
|
||||||
Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP());
|
Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP());
|
||||||
addArgs();
|
addArgs();
|
||||||
sameDiff.putOpForId(this.getOwnName(), this);
|
|
||||||
sameDiff.addArgsFor(inputFunctions, this);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addArgs() {
|
protected void addArgs() {
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -56,23 +57,32 @@ public class Conv2D extends DynamicCustomOp {
|
||||||
protected Conv2DConfig config;
|
protected Conv2DConfig config;
|
||||||
private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s ";
|
private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s ";
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public Conv2D(SameDiff sameDiff,
|
public Conv2D(SameDiff sameDiff,
|
||||||
SDVariable[] inputFunctions,
|
SDVariable[] inputFunctions,
|
||||||
INDArray[] inputArrays, INDArray[] outputs,
|
|
||||||
Conv2DConfig config) {
|
Conv2DConfig config) {
|
||||||
super(null, inputArrays, outputs);
|
super(sameDiff, inputFunctions);
|
||||||
this.sameDiff = sameDiff;
|
|
||||||
|
initConfig(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Conv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
initConfig(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Conv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void initConfig(Conv2DConfig config){
|
||||||
this.config = config;
|
this.config = config;
|
||||||
|
|
||||||
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
|
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
|
||||||
INVALID_CONFIGURATION,
|
INVALID_CONFIGURATION,
|
||||||
config.getSH(), config.getPH(), config.getDW());
|
config.getSH(), config.getPH(), config.getDW());
|
||||||
addArgs();
|
addArgs();
|
||||||
if(sameDiff != null) {
|
|
||||||
sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
|
|
||||||
sameDiff.addArgsFor(inputFunctions, this);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addArgs() {
|
protected void addArgs() {
|
||||||
|
@ -252,7 +262,6 @@ public class Conv2D extends DynamicCustomOp {
|
||||||
Conv2DDerivative conv2DDerivative = Conv2DDerivative.derivativeBuilder()
|
Conv2DDerivative conv2DDerivative = Conv2DDerivative.derivativeBuilder()
|
||||||
.sameDiff(sameDiff)
|
.sameDiff(sameDiff)
|
||||||
.config(config)
|
.config(config)
|
||||||
.outputs(outputArguments())
|
|
||||||
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
|
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
|
||||||
.build();
|
.build();
|
||||||
List<SDVariable> ret = Arrays.asList(conv2DDerivative.outputVariables());
|
List<SDVariable> ret = Arrays.asList(conv2DDerivative.outputVariables());
|
||||||
|
|
|
@ -37,8 +37,8 @@ import java.util.List;
|
||||||
public class Conv2DDerivative extends Conv2D {
|
public class Conv2DDerivative extends Conv2D {
|
||||||
|
|
||||||
@Builder(builderMethodName = "derivativeBuilder")
|
@Builder(builderMethodName = "derivativeBuilder")
|
||||||
public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig config) {
|
public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig config) {
|
||||||
super(sameDiff, inputFunctions, inputArrays, outputs, config);
|
super(sameDiff, inputFunctions, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Conv2DDerivative() {}
|
public Conv2DDerivative() {}
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
|
@ -33,6 +34,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -55,25 +57,27 @@ public class Conv3D extends DynamicCustomOp {
|
||||||
public Conv3D() {
|
public Conv3D() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs,
|
public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) {
|
||||||
Conv3DConfig conv3DConfig) {
|
super(sameDiff, inputFunctions);
|
||||||
super(null, sameDiff, inputFunctions, false);
|
initConfig(config);
|
||||||
setSameDiff(sameDiff);
|
}
|
||||||
|
|
||||||
if (inputs != null)
|
public Conv3D(INDArray[] inputs, INDArray[] outputs, Conv3DConfig config){
|
||||||
addInputArgument(inputs);
|
super(inputs, outputs);
|
||||||
if (outputs != null)
|
initConfig(config);
|
||||||
addOutputArgument(outputs);
|
}
|
||||||
this.config = conv3DConfig;
|
|
||||||
|
public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv3DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initConfig(Conv3DConfig config){
|
||||||
|
this.config = config;
|
||||||
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
|
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
|
||||||
INVALID_CONFIGURATION,
|
INVALID_CONFIGURATION,
|
||||||
config.getSW(), config.getPH(), config.getDW());
|
config.getSW(), config.getPH(), config.getDW());
|
||||||
addArgs();
|
addArgs();
|
||||||
|
|
||||||
|
|
||||||
//for (val arg: iArgs())
|
|
||||||
// System.out.println(getIArgument(arg));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -259,8 +263,6 @@ public class Conv3D extends DynamicCustomOp {
|
||||||
inputs.add(f1.get(0));
|
inputs.add(f1.get(0));
|
||||||
Conv3DDerivative conv3DDerivative = Conv3DDerivative.derivativeBuilder()
|
Conv3DDerivative conv3DDerivative = Conv3DDerivative.derivativeBuilder()
|
||||||
.conv3DConfig(config)
|
.conv3DConfig(config)
|
||||||
.inputFunctions(args())
|
|
||||||
.outputs(outputArguments())
|
|
||||||
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
|
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
|
||||||
.sameDiff(sameDiff)
|
.sameDiff(sameDiff)
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -39,8 +39,8 @@ public class Conv3DDerivative extends Conv3D {
|
||||||
public Conv3DDerivative() {}
|
public Conv3DDerivative() {}
|
||||||
|
|
||||||
@Builder(builderMethodName = "derivativeBuilder")
|
@Builder(builderMethodName = "derivativeBuilder")
|
||||||
public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, Conv3DConfig conv3DConfig) {
|
public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig conv3DConfig) {
|
||||||
super(sameDiff, inputFunctions, inputs, outputs, conv3DConfig);
|
super(sameDiff, inputFunctions, conv3DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -31,6 +32,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
@ -51,25 +53,25 @@ public class DeConv2D extends DynamicCustomOp {
|
||||||
|
|
||||||
protected DeConv2DConfig config;
|
protected DeConv2DConfig config;
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public DeConv2D(SameDiff sameDiff,
|
public DeConv2D(SameDiff sameDiff,
|
||||||
SDVariable[] inputs,
|
SDVariable[] inputs,
|
||||||
INDArray[] inputArrays, INDArray[] outputs,
|
|
||||||
DeConv2DConfig config) {
|
DeConv2DConfig config) {
|
||||||
super(null, inputArrays, outputs);
|
super(sameDiff, inputs);
|
||||||
this.sameDiff = sameDiff;
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
|
|
||||||
if (inputArrays != null) {
|
addArgs();
|
||||||
addInputArgument(inputArrays);
|
|
||||||
}
|
|
||||||
if (outputs != null) {
|
|
||||||
addOutputArgument(outputs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public DeConv2D(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
addArgs();
|
addArgs();
|
||||||
sameDiff.putOpForId(this.getOwnName(), this);
|
}
|
||||||
sameDiff.addArgsFor(inputs, this);
|
|
||||||
|
public DeConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv2DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -40,8 +40,8 @@ public class DeConv2DDerivative extends DeConv2D {
|
||||||
public DeConv2DDerivative() {}
|
public DeConv2DDerivative() {}
|
||||||
|
|
||||||
@Builder(builderMethodName = "derivativeBuilder")
|
@Builder(builderMethodName = "derivativeBuilder")
|
||||||
public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) {
|
public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, DeConv2DConfig config) {
|
||||||
super(sameDiff, inputs, inputArrays, outputs, config);
|
super(sameDiff, inputs, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -53,25 +53,21 @@ public class DeConv2DTF extends DynamicCustomOp {
|
||||||
|
|
||||||
protected DeConv2DConfig config;
|
protected DeConv2DConfig config;
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public DeConv2DTF(SameDiff sameDiff,
|
public DeConv2DTF(SameDiff sameDiff,
|
||||||
SDVariable[] inputs,
|
SDVariable[] inputs,
|
||||||
INDArray[] inputArrays, INDArray[] outputs,
|
|
||||||
DeConv2DConfig config) {
|
DeConv2DConfig config) {
|
||||||
super(null, inputArrays, outputs);
|
super(sameDiff, inputs);
|
||||||
this.sameDiff = sameDiff;
|
|
||||||
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public DeConv2DTF(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
|
|
||||||
if (inputArrays != null) {
|
|
||||||
addInputArgument(inputArrays);
|
|
||||||
}
|
|
||||||
if (outputs != null) {
|
|
||||||
addOutputArgument(outputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
addArgs();
|
addArgs();
|
||||||
sameDiff.putOpForId(this.getOwnName(), this);
|
|
||||||
sameDiff.addArgsFor(inputs, this);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -28,6 +28,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
@ -53,12 +54,23 @@ public class DeConv3D extends DynamicCustomOp {
|
||||||
|
|
||||||
protected DeConv3DConfig config;
|
protected DeConv3DConfig config;
|
||||||
|
|
||||||
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, DeConv3DConfig config) {
|
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
|
||||||
super(sameDiff, toArr(input, weights, bias));
|
super(sameDiff, toArr(input, weights, bias));
|
||||||
this.config = config;
|
this.config = config;
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public DeConv3D(INDArray[] inputs, INDArray[] outputs, DeConv3DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv3DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||||
|
}
|
||||||
|
|
||||||
private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){
|
private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){
|
||||||
if(bias != null){
|
if(bias != null){
|
||||||
return new SDVariable[]{input, weights, bias};
|
return new SDVariable[]{input, weights, bias};
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -35,6 +36,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -53,17 +55,25 @@ public class DepthwiseConv2D extends DynamicCustomOp {
|
||||||
|
|
||||||
protected Conv2DConfig config;
|
protected Conv2DConfig config;
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public DepthwiseConv2D(SameDiff sameDiff,
|
public DepthwiseConv2D(SameDiff sameDiff,
|
||||||
SDVariable[] inputFunctions,
|
SDVariable[] inputFunctions,
|
||||||
INDArray[] inputArrays, INDArray[] outputs,
|
|
||||||
Conv2DConfig config) {
|
Conv2DConfig config) {
|
||||||
super(null, inputArrays, outputs);
|
super(sameDiff, inputFunctions);
|
||||||
this.sameDiff = sameDiff;
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
addArgs();
|
addArgs();
|
||||||
sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
|
}
|
||||||
sameDiff.addArgsFor(inputFunctions, this);
|
|
||||||
|
public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||||
}
|
}
|
||||||
|
|
||||||
public DepthwiseConv2D() {
|
public DepthwiseConv2D() {
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -48,18 +49,19 @@ public class LocalResponseNormalization extends DynamicCustomOp {
|
||||||
protected LocalResponseNormalizationConfig config;
|
protected LocalResponseNormalizationConfig config;
|
||||||
|
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions,
|
public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace,
|
||||||
INDArray[] inputs, INDArray[] outputs,boolean inPlace,
|
|
||||||
LocalResponseNormalizationConfig config) {
|
LocalResponseNormalizationConfig config) {
|
||||||
super(null,sameDiff, inputFunctions, inPlace);
|
super(null,sameDiff, inputFunctions, inPlace);
|
||||||
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
if(inputs != null) {
|
addArgs();
|
||||||
addInputArgument(inputs);
|
|
||||||
}
|
|
||||||
if(outputs!= null) {
|
|
||||||
addOutputArgument(outputs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config){
|
||||||
|
super(new INDArray[]{input}, wrapOrNull(output));
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,8 +33,8 @@ import java.util.List;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class LocalResponseNormalizationDerivative extends LocalResponseNormalization {
|
public class LocalResponseNormalizationDerivative extends LocalResponseNormalization {
|
||||||
@Builder(builderMethodName = "derivativeBuilder")
|
@Builder(builderMethodName = "derivativeBuilder")
|
||||||
public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, boolean inPlace, LocalResponseNormalizationConfig config) {
|
public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace, LocalResponseNormalizationConfig config) {
|
||||||
super(sameDiff, inputFunctions, inputs, outputs, inPlace, config);
|
super(sameDiff, inputFunctions, inPlace, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
public LocalResponseNormalizationDerivative() {}
|
public LocalResponseNormalizationDerivative() {}
|
||||||
|
|
|
@ -51,27 +51,18 @@ public class MaxPooling2D extends DynamicCustomOp {
|
||||||
public MaxPooling2D() {
|
public MaxPooling2D() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
@SuppressWarnings("Used in lombok")
|
@SuppressWarnings("Used in lombok")
|
||||||
public MaxPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) {
|
public MaxPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
|
||||||
super(null, sameDiff, new SDVariable[]{input}, false);
|
super(null, sameDiff, new SDVariable[]{input}, false);
|
||||||
if (arrayInput != null) {
|
|
||||||
addInputArgument(arrayInput);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (arrayOutput != null) {
|
|
||||||
addOutputArgument(arrayOutput);
|
|
||||||
}
|
|
||||||
config.setType(Pooling2D.Pooling2DType.MAX);
|
config.setType(Pooling2D.Pooling2DType.MAX);
|
||||||
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
this.sameDiff = sameDiff;
|
|
||||||
|
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
||||||
super(null, new INDArray[]{input}, output == null ? null : new INDArray[]{output});
|
super(null, new INDArray[]{input}, wrapOrNull(output));
|
||||||
config.setType(Pooling2D.Pooling2DType.MAX);
|
config.setType(Pooling2D.Pooling2DType.MAX);
|
||||||
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
|
|
|
@ -16,8 +16,14 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -33,9 +39,6 @@ import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
|
||||||
import java.lang.reflect.Field;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Pooling2D operation
|
* Pooling2D operation
|
||||||
|
@ -70,21 +73,27 @@ public class Pooling2D extends DynamicCustomOp {
|
||||||
|
|
||||||
public Pooling2D() {}
|
public Pooling2D() {}
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
@SuppressWarnings("Used in lombok")
|
@SuppressWarnings("Used in lombok")
|
||||||
public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,INDArray[] arrayInputs, INDArray[] arrayOutputs,Pooling2DConfig config) {
|
public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,
|
||||||
super(null,sameDiff, inputs, false);
|
Pooling2DConfig config) {
|
||||||
if(arrayInputs != null) {
|
super(null, sameDiff, inputs, false);
|
||||||
addInputArgument(arrayInputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
if(arrayOutputs != null) {
|
|
||||||
addOutputArgument(arrayOutputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Pooling2D(@NonNull INDArray[] inputs, INDArray[] outputs, @NonNull Pooling2DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Pooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
||||||
|
super(new INDArray[]{input}, wrapOrNull(output));
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -36,8 +37,12 @@ import java.util.List;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class Pooling2DDerivative extends Pooling2D {
|
public class Pooling2DDerivative extends Pooling2D {
|
||||||
@Builder(builderMethodName = "derivativeBuilder")
|
@Builder(builderMethodName = "derivativeBuilder")
|
||||||
public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] arrayInputs, INDArray[] arrayOutputs, Pooling2DConfig config) {
|
public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, Pooling2DConfig config) {
|
||||||
super(sameDiff, inputs, arrayInputs, arrayOutputs, config);
|
super(sameDiff, inputs, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Pooling2DDerivative(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Pooling2DConfig config){
|
||||||
|
super(new INDArray[]{input, grad}, wrapOrNull(output), config);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Pooling2DDerivative() {}
|
public Pooling2DDerivative() {}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -39,9 +40,17 @@ import java.util.List;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class SConv2D extends Conv2D {
|
public class SConv2D extends Conv2D {
|
||||||
|
|
||||||
@Builder(builderMethodName = "sBuilder")
|
@Builder(builderMethodName = "sameDiffSBuilder")
|
||||||
public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) {
|
public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
|
||||||
super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig);
|
super(sameDiff, inputFunctions, conv2DConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
|
||||||
|
super(inputs, outputs, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SConv2D(@NonNull INDArray input, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
|
||||||
|
this(wrapFilterNull(input, depthWeights, pointWeights, bias), wrapOrNull(output), config);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SConv2D() {}
|
public SConv2D() {}
|
||||||
|
|
|
@ -38,8 +38,8 @@ import java.util.List;
|
||||||
public class SConv2DDerivative extends SConv2D {
|
public class SConv2DDerivative extends SConv2D {
|
||||||
|
|
||||||
@Builder(builderMethodName = "sDerviativeBuilder")
|
@Builder(builderMethodName = "sDerviativeBuilder")
|
||||||
public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) {
|
public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
|
||||||
super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig);
|
super(sameDiff, inputFunctions, conv2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SConv2DDerivative() {}
|
public SConv2DDerivative() {}
|
||||||
|
|
|
@ -235,10 +235,7 @@ public class Convolution {
|
||||||
public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw,
|
public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw,
|
||||||
int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor,
|
int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor,
|
||||||
double extra, int virtualHeight, int virtualWidth, INDArray out) {
|
double extra, int virtualHeight, int virtualWidth, INDArray out) {
|
||||||
Pooling2D pooling = Pooling2D.builder()
|
Pooling2D pooling = new Pooling2D(img, out, Pooling2DConfig.builder()
|
||||||
.arrayInputs(new INDArray[]{img})
|
|
||||||
.arrayOutputs(new INDArray[]{out})
|
|
||||||
.config(Pooling2DConfig.builder()
|
|
||||||
.dH(dh)
|
.dH(dh)
|
||||||
.dW(dw)
|
.dW(dw)
|
||||||
.extra(extra)
|
.extra(extra)
|
||||||
|
@ -251,8 +248,7 @@ public class Convolution {
|
||||||
.sW(sx)
|
.sW(sx)
|
||||||
.type(type)
|
.type(type)
|
||||||
.divisor(divisor)
|
.divisor(divisor)
|
||||||
.build())
|
.build());
|
||||||
.build();
|
|
||||||
Nd4j.getExecutioner().execAndReturn(pooling);
|
Nd4j.getExecutioner().execAndReturn(pooling);
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
|
@ -389,10 +389,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
INDArray input = Nd4j.create(inSize);
|
INDArray input = Nd4j.create(inSize);
|
||||||
AvgPooling2D avgPooling2D = AvgPooling2D.builder()
|
AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf);
|
||||||
.arrayInput(input)
|
|
||||||
.config(conf)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
|
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
|
||||||
|
|
||||||
|
@ -410,10 +407,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
|
|
||||||
//Test backprop:
|
//Test backprop:
|
||||||
Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder()
|
Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, null, conf);
|
||||||
.arrayInputs(new INDArray[]{input, grad})
|
|
||||||
.config(conf)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
|
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
|
||||||
assertEquals(1, outSizesBP.size());
|
assertEquals(1, outSizesBP.size());
|
||||||
|
@ -435,10 +429,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
INDArray input = Nd4j.create(inSize);
|
INDArray input = Nd4j.create(inSize);
|
||||||
AvgPooling2D avgPooling2D = AvgPooling2D.builder()
|
AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf);
|
||||||
.arrayInput(input)
|
|
||||||
.config(conf)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
|
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
|
||||||
assertEquals(1, outSizes.size());
|
assertEquals(1, outSizes.size());
|
||||||
|
@ -454,11 +445,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
INDArray grad = Nd4j.create(exp);
|
INDArray grad = Nd4j.create(exp);
|
||||||
|
|
||||||
//Test backprop:
|
//Test backprop:
|
||||||
Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder()
|
Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, Nd4j.create(inSize), conf);
|
||||||
.arrayInputs(new INDArray[]{input, grad}) //Original input, and output gradient (eps - same shape as output)
|
|
||||||
.arrayOutputs(new INDArray[]{Nd4j.create(inSize)}) //Output for BP: same shape as original input
|
|
||||||
.config(conf)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
|
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
|
||||||
assertEquals(1, outSizesBP.size());
|
assertEquals(1, outSizesBP.size());
|
||||||
|
@ -749,7 +736,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
.isSameMode(false)
|
.isSameMode(false)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
SDVariable out = sd.cnn().conv2d(vars, c);
|
SDVariable out = sd.cnn().conv2d("conv", vars, c);
|
||||||
out = sd.nn().tanh("out", out);
|
out = sd.nn().tanh("out", out);
|
||||||
|
|
||||||
INDArray outArr = sd.execAndEndResult();
|
INDArray outArr = sd.execAndEndResult();
|
||||||
|
|
Loading…
Reference in New Issue