DL4J + Keras import: Causal Conv1D support (#107)
* Keras causal conv1d support first steps Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Causal conv mode Signed-off-by: AlexDBlack <blacka101@gmail.com> * Gradient check and fixes for causal conv1d Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix Conv1D import and testing Signed-off-by: AlexDBlack <blacka101@gmail.com> * Cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small keras test fix Signed-off-by: Alex Black <blacka101@gmail.com> * Don't allow setting causal convolution mode to conv2d/3d layers Signed-off-by: Alex Black <blacka101@gmail.com> * More robustly infer nIn for recurrent layers for ambiguous NCW and NWC cases Signed-off-by: Alex Black <blacka101@gmail.com> * Polish and cleanup Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
578a5abb68
commit
9cc8803b8d
|
@ -27,6 +27,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
|
|||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.util.Convolution1DUtils;
|
||||
import org.deeplearning4j.util.ConvolutionUtils;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -442,4 +444,76 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCnn1Causal() {
|
||||
int convNIn = 2;
|
||||
int convNOut1 = 3;
|
||||
int convNOut2 = 4;
|
||||
int finalNOut = 3;
|
||||
|
||||
int[] lengths = {11, 12, 13, 9, 10, 11};
|
||||
int[] kernels = {2, 3, 2, 4, 2, 3};
|
||||
int[] dilations = {1, 1, 2, 1, 2, 1};
|
||||
int[] strides = {1, 2, 1, 2, 1, 1};
|
||||
boolean[] masks = {false, true, false, true, false, true};
|
||||
boolean[] hasB = {true, false, true, false, true, true};
|
||||
|
||||
for (int i = 0; i < lengths.length; i++) {
|
||||
int length = lengths[i];
|
||||
int k = kernels[i];
|
||||
int d = dilations[i];
|
||||
int st = strides[i];
|
||||
boolean mask = masks[i];
|
||||
boolean hasBias = hasB[i];
|
||||
//TODO has bias
|
||||
String s = "k=" + k + ", s=" + st + "d=" + d + ", seqLen=" + length;
|
||||
log.info("Starting test: " + s);
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.DOUBLE)
|
||||
.updater(new NoOp())
|
||||
.activation(Activation.TANH)
|
||||
.weightInit(new NormalDistribution(0, 1))
|
||||
.seed(12345)
|
||||
.list()
|
||||
.layer(new Convolution1DLayer.Builder().kernelSize(k)
|
||||
.dilation(d)
|
||||
.hasBias(hasBias)
|
||||
.convolutionMode(ConvolutionMode.Causal)
|
||||
.stride(st).nIn(convNIn).nOut(convNOut1)
|
||||
.build())
|
||||
.layer(new Convolution1DLayer.Builder().kernelSize(k)
|
||||
.dilation(d)
|
||||
.convolutionMode(ConvolutionMode.Causal)
|
||||
.stride(st).nIn(convNOut1).nOut(convNOut2)
|
||||
.build())
|
||||
.layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
|
||||
.setInputType(InputType.recurrent(convNIn, length)).build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length);
|
||||
INDArray fm = null;
|
||||
if (mask) {
|
||||
fm = Nd4j.create(2, length);
|
||||
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
|
||||
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length-2)).assign(1);
|
||||
}
|
||||
|
||||
long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d);
|
||||
long outSize2 = Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d);
|
||||
|
||||
INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2);
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, label, fm, null);
|
||||
|
||||
assertTrue(s, gradOK);
|
||||
TestUtils.testModelSerialization(net);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -712,4 +712,73 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
|
|||
assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConv1dCausalAllowed(){
|
||||
new Convolution1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
||||
new Subsampling1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConv2dNoCausalAllowed(){
|
||||
|
||||
try{
|
||||
new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
|
||||
fail("Expected exception");
|
||||
} catch (Throwable t){
|
||||
String m = t.getMessage().toLowerCase();
|
||||
assertTrue(m, m.contains("causal") && m.contains("1d"));
|
||||
}
|
||||
|
||||
try{
|
||||
new Deconvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
|
||||
fail("Expected exception");
|
||||
} catch (Throwable t){
|
||||
String m = t.getMessage().toLowerCase();
|
||||
assertTrue(m, m.contains("causal") && m.contains("1d"));
|
||||
}
|
||||
|
||||
try{
|
||||
new DepthwiseConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
|
||||
fail("Expected exception");
|
||||
} catch (Throwable t){
|
||||
String m = t.getMessage().toLowerCase();
|
||||
assertTrue(m, m.contains("causal") && m.contains("1d"));
|
||||
}
|
||||
|
||||
try{
|
||||
new SeparableConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
|
||||
fail("Expected exception");
|
||||
} catch (Throwable t){
|
||||
String m = t.getMessage().toLowerCase();
|
||||
assertTrue(m, m.contains("causal") && m.contains("1d"));
|
||||
}
|
||||
|
||||
try{
|
||||
new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
|
||||
fail("Expected exception");
|
||||
} catch (Throwable t){
|
||||
String m = t.getMessage().toLowerCase();
|
||||
assertTrue(m, m.contains("causal") && m.contains("1d"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConv3dNoCausalAllowed(){
|
||||
try{
|
||||
new Convolution3D.Builder().convolutionMode(ConvolutionMode.Causal).build();
|
||||
fail("Expected exception");
|
||||
} catch (Throwable t){
|
||||
String m = t.getMessage().toLowerCase();
|
||||
assertTrue(m, m.contains("causal") && m.contains("1d"));
|
||||
}
|
||||
|
||||
try{
|
||||
new Subsampling3DLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
|
||||
fail("Expected exception");
|
||||
} catch (Throwable t){
|
||||
String m = t.getMessage().toLowerCase();
|
||||
assertTrue(m, m.contains("causal") && m.contains("1d"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -356,6 +356,10 @@ public class KerasLayer {
|
|||
return this.layer;
|
||||
}
|
||||
|
||||
public void setLayer(Layer layer){
|
||||
this.layer = layer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether this Keras layer maps to a DL4J Vertex.
|
||||
*
|
||||
|
|
|
@ -233,6 +233,7 @@ public class KerasLayerConfiguration {
|
|||
private final String LAYER_BORDER_MODE_SAME = "same";
|
||||
private final String LAYER_BORDER_MODE_VALID = "valid";
|
||||
private final String LAYER_BORDER_MODE_FULL = "full";
|
||||
private final String LAYER_BORDER_MODE_CAUSAL = "causal";
|
||||
|
||||
/* Noise layers */
|
||||
private final String LAYER_FIELD_RATE = "rate";
|
||||
|
|
|
@ -124,7 +124,26 @@ public class KerasInput extends KerasLayer {
|
|||
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]);
|
||||
break;
|
||||
case 2:
|
||||
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
|
||||
if(this.dimOrder != null) {
|
||||
switch (this.dimOrder) {
|
||||
case TENSORFLOW: //NWC == channels_last
|
||||
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
|
||||
break;
|
||||
case THEANO: //NCW == channels_first
|
||||
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
|
||||
break;
|
||||
case NONE:
|
||||
//Assume RNN in [mb, seqLen, size] format
|
||||
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
|
||||
}
|
||||
} else {
|
||||
//Assume RNN in [mb, seqLen, size] format
|
||||
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
|
||||
}
|
||||
|
||||
break;
|
||||
case 3:
|
||||
switch (this.dimOrder) {
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
|
|||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -96,13 +97,13 @@ public class KerasLoss extends KerasLayer {
|
|||
*/
|
||||
public FeedForwardLayer getLossLayer(InputType type) throws UnsupportedKerasConfigurationException {
|
||||
if (type instanceof InputType.InputTypeFeedForward) {
|
||||
this.layer = new LossLayer.Builder(loss).name(this.layerName).build();
|
||||
this.layer = new LossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
|
||||
}
|
||||
else if (type instanceof InputType.InputTypeRecurrent) {
|
||||
this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).build();
|
||||
this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
|
||||
}
|
||||
else if (type instanceof InputType.InputTypeConvolutional) {
|
||||
this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).build();
|
||||
this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
|
||||
} else {
|
||||
throw new UnsupportedKerasConfigurationException("Unsupported output layer type"
|
||||
+ "got : " + type.toString());
|
||||
|
|
|
@ -79,7 +79,6 @@ abstract public class KerasConvolution extends KerasLayer {
|
|||
public KerasConvolution(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
|
||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
super(layerConfig, enforceTrainingConfig);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -185,18 +185,11 @@ public class KerasConvolution1D extends KerasConvolution {
|
|||
break;
|
||||
|
||||
case THEANO:
|
||||
paramValue = kerasParamValue.permute(2, 1, 0);
|
||||
paramValue = paramValue.reshape(
|
||||
paramValue.size(0), paramValue.size(1),
|
||||
paramValue.size(2), 1).dup();
|
||||
for (int i = 0; i < paramValue.tensorsAlongDimension(2, 3); i++) {
|
||||
INDArray copyFilter = paramValue.tensorAlongDimension(i, 2, 3).dup();
|
||||
double[] flattenedFilter = copyFilter.ravel().data().asDouble();
|
||||
ArrayUtils.reverse(flattenedFilter);
|
||||
INDArray newFilter = Nd4j.create(flattenedFilter, copyFilter.shape());
|
||||
INDArray inPlaceFilter = paramValue.tensorAlongDimension(i, 2, 3);
|
||||
inPlaceFilter.muli(0).addi(newFilter.castTo(inPlaceFilter.dataType()));
|
||||
}
|
||||
//Convert from keras [k,nIn,nOut] to DL4J conv2d [nOut, nIn, k, 1]
|
||||
long k = kerasParamValue.size(0);
|
||||
long nIn = kerasParamValue.size(1);
|
||||
long nOut = kerasParamValue.size(2);
|
||||
paramValue = kerasParamValue.permute(2, 1, 0).dup('c').reshape(nOut, nIn, k, 1);
|
||||
break;
|
||||
default:
|
||||
throw new InvalidKerasConfigurationException("Unknown keras backend " + this.getDimOrder());
|
||||
|
|
|
@ -264,7 +264,8 @@ public class KerasConvolutionUtils {
|
|||
} else if (borderMode.equals(conf.getLAYER_BORDER_MODE_VALID()) ||
|
||||
borderMode.equals(conf.getLAYER_BORDER_MODE_FULL())) {
|
||||
convolutionMode = ConvolutionMode.Truncate;
|
||||
|
||||
} else if(borderMode.equals(conf.getLAYER_BORDER_MODE_CAUSAL())) {
|
||||
convolutionMode = ConvolutionMode.Causal;
|
||||
} else {
|
||||
throw new UnsupportedKerasConfigurationException("Unsupported convolution border mode: " + borderMode);
|
||||
}
|
||||
|
|
|
@ -23,11 +23,13 @@ import lombok.val;
|
|||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
|
@ -186,6 +188,9 @@ public class KerasLSTM extends KerasLayer {
|
|||
.biasInit(0.0) // TODO: this is incorrect
|
||||
.l1(this.weightL1Regularization)
|
||||
.l2(this.weightL2Regularization);
|
||||
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
|
||||
if(nIn != null)
|
||||
builder.setNIn(nIn);
|
||||
if (biasConstraint != null)
|
||||
builder.constrainBias(biasConstraint);
|
||||
if (weightConstraint != null)
|
||||
|
@ -436,6 +441,20 @@ public class KerasLSTM extends KerasLayer {
|
|||
log.warn("Attemping to set weights for unknown parameters: "
|
||||
+ unknownParamNames.substring(1, unknownParamNames.length() - 1));
|
||||
}
|
||||
|
||||
|
||||
FeedForwardLayer ffl;
|
||||
if(this.layer instanceof BaseWrapperLayer){
|
||||
BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
|
||||
ffl = (FeedForwardLayer)bwl.getUnderlying();
|
||||
} else {
|
||||
ffl = (FeedForwardLayer) this.layer;
|
||||
}
|
||||
if(ffl.getNIn() != wRows){
|
||||
//Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config)
|
||||
//We can reliably infer nIn from the shape of the weights array however
|
||||
ffl.setNIn(wRows);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -22,11 +22,13 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
|
@ -154,6 +156,9 @@ public class KerasSimpleRnn extends KerasLayer {
|
|||
.biasInit(0.0)
|
||||
.l1(this.weightL1Regularization)
|
||||
.l2(this.weightL2Regularization);
|
||||
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
|
||||
if(nIn != null)
|
||||
builder.setNIn(nIn);
|
||||
if (biasConstraint != null)
|
||||
builder.constrainBias(biasConstraint);
|
||||
if (weightConstraint != null)
|
||||
|
@ -282,6 +287,19 @@ public class KerasSimpleRnn extends KerasLayer {
|
|||
log.warn("Attemping to set weights for unknown parameters: "
|
||||
+ unknownParamNames.substring(1, unknownParamNames.length() - 1));
|
||||
}
|
||||
|
||||
FeedForwardLayer ffl;
|
||||
if(this.layer instanceof BaseWrapperLayer){
|
||||
BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
|
||||
ffl = (FeedForwardLayer)bwl.getUnderlying();
|
||||
} else {
|
||||
ffl = (FeedForwardLayer) this.layer;
|
||||
}
|
||||
if(ffl.getNIn() != W.rows()){
|
||||
//Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config)
|
||||
//We can reliably infer nIn from the shape of the weights array however
|
||||
ffl.setNIn(W.rows());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -229,8 +229,8 @@ public class KerasBidirectional extends KerasLayer {
|
|||
@Override
|
||||
public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
|
||||
|
||||
Map<String, INDArray> forwardWeights = getUnderlyingWeights(weights, "forward");
|
||||
Map<String, INDArray> backwardWeights = getUnderlyingWeights(weights, "backward");
|
||||
Map<String, INDArray> forwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getFwd(), weights, "forward");
|
||||
Map<String, INDArray> backwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getBwd(), weights, "backward");
|
||||
|
||||
this.weights = new HashMap<>();
|
||||
|
||||
|
@ -241,7 +241,7 @@ public class KerasBidirectional extends KerasLayer {
|
|||
}
|
||||
|
||||
|
||||
private Map<String, INDArray> getUnderlyingWeights(Map<String, INDArray> weights, String direction)
|
||||
private Map<String, INDArray> getUnderlyingWeights(Layer l, Map<String, INDArray> weights, String direction)
|
||||
throws InvalidKerasConfigurationException {
|
||||
int keras1SubstringLength;
|
||||
if (kerasRnnlayer instanceof KerasLSTM)
|
||||
|
@ -270,8 +270,12 @@ public class KerasBidirectional extends KerasLayer {
|
|||
weights = newWeights;
|
||||
}
|
||||
|
||||
Layer layerBefore = kerasRnnlayer.getLayer();
|
||||
kerasRnnlayer.setLayer(l);
|
||||
kerasRnnlayer.setWeights(weights);
|
||||
return kerasRnnlayer.getWeights();
|
||||
Map<String,INDArray> ret = kerasRnnlayer.getWeights();
|
||||
kerasRnnlayer.setLayer(layerBefore);
|
||||
return ret;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -505,6 +505,17 @@ public class KerasLayerUtils {
|
|||
return nOut;
|
||||
}
|
||||
|
||||
public static Integer getNInFromInputDim(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
|
||||
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
||||
if(innerConfig.containsKey(conf.getLAYER_FIELD_INPUT_DIM())){
|
||||
Object id = innerConfig.get(conf.getLAYER_FIELD_INPUT_DIM());
|
||||
if(id instanceof Number){
|
||||
return ((Number)id).intValue();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get dropout from Keras layer configuration.
|
||||
*
|
||||
|
|
|
@ -24,6 +24,8 @@ import org.deeplearning4j.eval.ROCMultiClass;
|
|||
import org.deeplearning4j.gradientcheck.GradientCheckUtil;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
||||
|
@ -47,6 +49,8 @@ import org.nd4j.linalg.activations.impl.*;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.function.BiFunction;
|
||||
import org.nd4j.linalg.function.Function;
|
||||
import org.nd4j.linalg.learning.config.NoOp;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
|
||||
|
@ -58,10 +62,7 @@ import java.io.InputStream;
|
|||
import java.net.URL;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
|
@ -86,7 +87,16 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
|||
@Rule
|
||||
public final TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public static final BiFunction<String,INDArray,INDArray> nwc2ncwExpected = new BiFunction<String, INDArray, INDArray>() {
|
||||
@Override
|
||||
public INDArray apply(String s, INDArray array) {
|
||||
if(array.rank() == 3)
|
||||
return array.permute(0, 2, 1); //NWC to NCW
|
||||
return array;
|
||||
}
|
||||
};
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void fileNotFoundEndToEnd() throws Exception {
|
||||
String modelPath = "modelimport/keras/examples/foo/bar.h5";
|
||||
importEndModelTest(modelPath, null, true, true, false, false);
|
||||
|
@ -154,28 +164,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
|||
public void importImdbLstmTfKeras1() throws Exception {
|
||||
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
|
||||
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
|
||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void importImdbLstmThKeras1() throws Exception {
|
||||
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
|
||||
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
|
||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void importImdbLstmTfKeras2() throws Exception {
|
||||
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
|
||||
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
|
||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void importImdbLstmThKeras2() throws Exception {
|
||||
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
|
||||
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
|
||||
importEndModelTest(modelPath, inputsOutputPath, false, true, false, false);
|
||||
importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -247,7 +257,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
|||
String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
|
||||
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
|
||||
"simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
|
||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -598,6 +608,122 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
|||
model.summary();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCausalCon1D() throws Exception {
|
||||
String[] names = new String[]{
|
||||
"causal_conv1d_k2_s1_d1_cl_model.h5",
|
||||
"causal_conv1d_k2_s1_d2_cl_model.h5",
|
||||
"causal_conv1d_k2_s2_d1_cl_model.h5",
|
||||
"causal_conv1d_k2_s3_d1_cl_model.h5",
|
||||
"causal_conv1d_k3_s1_d1_cl_model.h5",
|
||||
"causal_conv1d_k3_s1_d2_cl_model.h5",
|
||||
"causal_conv1d_k3_s2_d1_cl_model.h5",
|
||||
"causal_conv1d_k3_s3_d1_cl_model.h5",
|
||||
"causal_conv1d_k4_s1_d1_cl_model.h5",
|
||||
"causal_conv1d_k4_s1_d2_cl_model.h5",
|
||||
"causal_conv1d_k4_s2_d1_cl_model.h5",
|
||||
"causal_conv1d_k4_s3_d1_cl_model.h5"
|
||||
};
|
||||
|
||||
for(String name : names ){
|
||||
System.out.println("Starting test: " + name);
|
||||
String modelPath = "modelimport/keras/examples/causal_conv1d/" + name;
|
||||
String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
|
||||
Function<INDArray,INDArray> f = new Function<INDArray, INDArray>() {
|
||||
@Override
|
||||
public INDArray apply(INDArray i) {
|
||||
//NWC to NCW
|
||||
return i.permute(0, 2, 1);
|
||||
}
|
||||
};
|
||||
|
||||
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
|
||||
true, true, false, f, nwc2ncwExpected);
|
||||
Layer l = net.getLayer(0);
|
||||
Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig();
|
||||
assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCon1D() throws Exception {
|
||||
String[] names = new String[]{
|
||||
"conv1d_k2_s1_d1_cf_same_model.h5",
|
||||
"conv1d_k2_s1_d1_cf_valid_model.h5",
|
||||
"conv1d_k2_s1_d1_cl_same_model.h5",
|
||||
"conv1d_k2_s1_d1_cl_valid_model.h5",
|
||||
"conv1d_k2_s1_d2_cf_same_model.h5",
|
||||
"conv1d_k2_s1_d2_cf_valid_model.h5",
|
||||
"conv1d_k2_s1_d2_cl_same_model.h5",
|
||||
"conv1d_k2_s1_d2_cl_valid_model.h5",
|
||||
"conv1d_k2_s2_d1_cf_same_model.h5",
|
||||
"conv1d_k2_s2_d1_cf_valid_model.h5",
|
||||
"conv1d_k2_s2_d1_cl_same_model.h5",
|
||||
"conv1d_k2_s2_d1_cl_valid_model.h5",
|
||||
"conv1d_k2_s3_d1_cf_same_model.h5",
|
||||
"conv1d_k2_s3_d1_cf_valid_model.h5",
|
||||
"conv1d_k2_s3_d1_cl_same_model.h5",
|
||||
"conv1d_k2_s3_d1_cl_valid_model.h5",
|
||||
"conv1d_k3_s1_d1_cf_same_model.h5",
|
||||
"conv1d_k3_s1_d1_cf_valid_model.h5",
|
||||
"conv1d_k3_s1_d1_cl_same_model.h5",
|
||||
"conv1d_k3_s1_d1_cl_valid_model.h5",
|
||||
"conv1d_k3_s1_d2_cf_same_model.h5",
|
||||
"conv1d_k3_s1_d2_cf_valid_model.h5",
|
||||
"conv1d_k3_s1_d2_cl_same_model.h5",
|
||||
"conv1d_k3_s1_d2_cl_valid_model.h5",
|
||||
"conv1d_k3_s2_d1_cf_same_model.h5",
|
||||
"conv1d_k3_s2_d1_cf_valid_model.h5",
|
||||
"conv1d_k3_s2_d1_cl_same_model.h5",
|
||||
"conv1d_k3_s2_d1_cl_valid_model.h5",
|
||||
"conv1d_k3_s3_d1_cf_same_model.h5",
|
||||
"conv1d_k3_s3_d1_cf_valid_model.h5",
|
||||
"conv1d_k3_s3_d1_cl_same_model.h5",
|
||||
"conv1d_k3_s3_d1_cl_valid_model.h5",
|
||||
"conv1d_k4_s1_d1_cf_same_model.h5",
|
||||
"conv1d_k4_s1_d1_cf_valid_model.h5",
|
||||
"conv1d_k4_s1_d1_cl_same_model.h5",
|
||||
"conv1d_k4_s1_d1_cl_valid_model.h5",
|
||||
"conv1d_k4_s1_d2_cf_same_model.h5",
|
||||
"conv1d_k4_s1_d2_cf_valid_model.h5",
|
||||
"conv1d_k4_s1_d2_cl_same_model.h5",
|
||||
"conv1d_k4_s1_d2_cl_valid_model.h5",
|
||||
"conv1d_k4_s2_d1_cf_same_model.h5",
|
||||
"conv1d_k4_s2_d1_cf_valid_model.h5",
|
||||
"conv1d_k4_s2_d1_cl_same_model.h5",
|
||||
"conv1d_k4_s2_d1_cl_valid_model.h5",
|
||||
"conv1d_k4_s3_d1_cf_same_model.h5",
|
||||
"conv1d_k4_s3_d1_cf_valid_model.h5",
|
||||
"conv1d_k4_s3_d1_cl_same_model.h5",
|
||||
"conv1d_k4_s3_d1_cl_valid_model.h5",
|
||||
};
|
||||
|
||||
for(String name : names ){
|
||||
System.out.println("Starting test: " + name);
|
||||
String modelPath = "modelimport/keras/examples/conv1d/" + name;
|
||||
String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
|
||||
Function<INDArray,INDArray> f = name.contains("_cf_") ? null : new Function<INDArray, INDArray>() {
|
||||
@Override
|
||||
public INDArray apply(INDArray i) {
|
||||
//NWC to NCW
|
||||
return i.permute(0, 2, 1);
|
||||
}
|
||||
};
|
||||
|
||||
BiFunction<String,INDArray,INDArray> f2 = name.contains("_cf_") ? null : new BiFunction<String, INDArray, INDArray>() {
|
||||
@Override
|
||||
public INDArray apply(String s, INDArray array) {
|
||||
// if("conv".equals(s)){
|
||||
return array.permute(0, 2, 1);
|
||||
// }
|
||||
}
|
||||
};
|
||||
|
||||
importEndModelTest(modelPath, inputsOutputPath, true, true,
|
||||
true, true, false, f, f2);
|
||||
}
|
||||
}
|
||||
|
||||
private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception {
|
||||
return importFunctionalModelH5Test(modelPath, null, false);
|
||||
}
|
||||
|
@ -640,6 +766,12 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
|||
|
||||
public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
|
||||
boolean checkGradients, boolean enforceTrainingConfig) throws Exception {
|
||||
return importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, checkGradients, true, enforceTrainingConfig, null, null);
|
||||
}
|
||||
|
||||
public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
|
||||
boolean checkGradients, boolean enforceTrainingConfig, boolean checkAuc, Function<INDArray,INDArray> inputPreProc,
|
||||
BiFunction<String,INDArray,INDArray> expectedPreProc) throws Exception {
|
||||
MultiLayerNetwork model;
|
||||
try(InputStream is = Resources.asStream(modelPath)) {
|
||||
File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);
|
||||
|
@ -658,20 +790,25 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
|||
|
||||
if (checkPredictions) {
|
||||
INDArray input = getInputs(outputsArchive, tfOrdering)[0];
|
||||
if(inputPreProc != null)
|
||||
input = inputPreProc.apply(input);
|
||||
|
||||
Map<String, INDArray> activationsKeras = getActivations(outputsArchive, tfOrdering);
|
||||
for (int i = 0; i < model.getLayers().length; i++) {
|
||||
String layerName = model.getLayerNames().get(i);
|
||||
if (activationsKeras.containsKey(layerName)) {
|
||||
INDArray activationsDl4j = model.feedForwardToLayer(i, input, false).get(i + 1);
|
||||
if (activationsDl4j.shape().length == 3)
|
||||
activationsDl4j = activationsDl4j.permute(0, 2, 1);
|
||||
compareINDArrays(layerName, activationsKeras.get(layerName), activationsDl4j, EPS);
|
||||
|
||||
INDArray exp = activationsKeras.get(layerName);
|
||||
if(expectedPreProc != null)
|
||||
exp = expectedPreProc.apply(layerName, exp);
|
||||
compareINDArrays(layerName, exp, activationsDl4j, EPS);
|
||||
}
|
||||
}
|
||||
|
||||
INDArray predictionsKeras = getPredictions(outputsArchive, tfOrdering)[0];
|
||||
INDArray predictionsDl4j = model.output(input, false);
|
||||
if(expectedPreProc != null)
|
||||
predictionsKeras = expectedPreProc.apply("output", predictionsKeras);
|
||||
compareINDArrays("predictions", predictionsKeras, predictionsDl4j, EPS);
|
||||
INDArray outputs = getOutputs(outputsArchive, true)[0];
|
||||
|
||||
|
@ -680,7 +817,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
|||
}
|
||||
val nOut = (int) outputs.size(-1);
|
||||
|
||||
compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS);
|
||||
if(checkAuc)
|
||||
compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS);
|
||||
}
|
||||
|
||||
if (checkGradients && ! SKIP_GRAD_CHECKS) {
|
||||
|
@ -760,20 +898,23 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
|||
return predictions;
|
||||
}
|
||||
|
||||
private static void compareINDArrays(String label, INDArray a, INDArray b, double eps) {
|
||||
INDArray diff = a.sub(b.castTo(a.dataType()));
|
||||
private static void compareINDArrays(String label, INDArray expected, INDArray actual, double eps) {
|
||||
if(!expected.equalShapes(actual)){
|
||||
throw new IllegalStateException("Shapes do not match for \"" + label + "\": got " + Arrays.toString(expected.shape()) + " vs " + Arrays.toString(actual.shape()));
|
||||
}
|
||||
INDArray diff = expected.sub(actual.castTo(expected.dataType()));
|
||||
double min = diff.minNumber().doubleValue();
|
||||
double max = diff.maxNumber().doubleValue();
|
||||
log.info(label + ": " + a.equalsWithEps(b, eps) + ", " + min + ", " + max);
|
||||
log.info(label + ": " + expected.equalsWithEps(actual, eps) + ", " + min + ", " + max);
|
||||
double threshold = 1e-7;
|
||||
double aAbsMax = Math.max(Math.abs(a.minNumber().doubleValue()), Math.abs(a.maxNumber().doubleValue()));
|
||||
double bAbsMax = Math.max(Math.abs(b.minNumber().doubleValue()), Math.abs(b.maxNumber().doubleValue()));
|
||||
double aAbsMax = Math.max(Math.abs(expected.minNumber().doubleValue()), Math.abs(expected.maxNumber().doubleValue()));
|
||||
double bAbsMax = Math.max(Math.abs(actual.minNumber().doubleValue()), Math.abs(actual.maxNumber().doubleValue()));
|
||||
|
||||
// skip too small absolute inputs
|
||||
if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) {
|
||||
assertTrue(a.equalsWithEps(b.castTo(a.dataType()), eps));
|
||||
boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps);
|
||||
assertTrue("Output differs: " + label, eq);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses,
|
||||
|
|
|
@ -69,6 +69,18 @@ package org.deeplearning4j.nn.conf;
|
|||
* <br>
|
||||
* <br>
|
||||
* <br>
|
||||
* <b>Causal</b>: Causal padding mode can only be used for 1D convolutional neural networks.<br>
|
||||
* The motivation behind causal padding mode is that the output time steps depend only on current and past time steps.<br>
|
||||
* That is, out[t] (for time t) depends on only on values in[T] for t < T<br>
|
||||
* The output size of 1D convolution/subsampling layers is the same as with SAME convolution mode -
|
||||
* i.e., outSize = ceil( inputSize / stride )<br>
|
||||
* Padding is also the same as SAME mode, but all padding in on the left (start of sequence) instead of being on both
|
||||
* left and right of the input<br>
|
||||
* For more details on causal convolutions, see <a href="https://arxiv.org/abs/1609.03499">WaveNet: A Generative Model For Audio</a>,
|
||||
* section 2.1.
|
||||
* <br>
|
||||
* <br>
|
||||
* <br>
|
||||
* For further information on output sizes for convolutional neural networks, see the "Spatial arrangement" section at
|
||||
* <a href="http://cs231n.github.io/convolutional-networks/">http://cs231n.github.io/convolutional-networks/</a>
|
||||
*
|
||||
|
@ -76,6 +88,6 @@ package org.deeplearning4j.nn.conf;
|
|||
*/
|
||||
public enum ConvolutionMode {
|
||||
|
||||
Strict, Truncate, Same
|
||||
Strict, Truncate, Same, Causal
|
||||
|
||||
}
|
||||
|
|
|
@ -124,6 +124,11 @@ public class Convolution1DLayer extends ConvolutionLayer {
|
|||
this.setKernelSize((int[]) null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean allowCausal() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param kernelSize Kernel size
|
||||
* @param stride Stride
|
||||
|
|
|
@ -163,6 +163,12 @@ public class Convolution3D extends ConvolutionLayer {
|
|||
super(new int[] {2, 2, 2}, new int[] {1, 1, 1}, new int[] {0, 0, 0}, new int[] {1, 1, 1}, 3);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean allowCausal() {
|
||||
//Causal convolution - allowed for 1D only
|
||||
return false;
|
||||
}
|
||||
|
||||
public Builder(int[] kernelSize, int[] stride, int[] padding, int[] dilation) {
|
||||
super(kernelSize, stride, padding, dilation, 3);
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
|||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.deeplearning4j.util.ConvolutionUtils;
|
||||
import org.deeplearning4j.util.ValidationUtils;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
|
@ -283,6 +284,12 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
|||
super();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean allowCausal() {
|
||||
//Causal convolution - allowed for 1D only
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Size of the convolution rows/columns
|
||||
*
|
||||
|
@ -456,6 +463,14 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
|||
|
||||
protected BaseConvBuilder() {}
|
||||
|
||||
protected abstract boolean allowCausal();
|
||||
|
||||
protected void setConvolutionMode(ConvolutionMode convolutionMode){
|
||||
Preconditions.checkState(allowCausal() || convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" +
|
||||
" convolutional neural network layers");
|
||||
this.convolutionMode = convolutionMode;
|
||||
}
|
||||
|
||||
/**
|
||||
* If true (default): include bias parameters in the model. False: no bias.
|
||||
*
|
||||
|
|
|
@ -133,6 +133,12 @@ public class Deconvolution2D extends ConvolutionLayer {
|
|||
super();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean allowCausal() {
|
||||
//Causal convolution - allowed for 1D only
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details
|
||||
*
|
||||
|
|
|
@ -133,6 +133,12 @@ public class DepthwiseConvolution2D extends ConvolutionLayer {
|
|||
super();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean allowCausal() {
|
||||
//Causal convolution - allowed for 1D only
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set channels multiplier for depth-wise convolution
|
||||
*
|
||||
|
|
|
@ -184,6 +184,12 @@ public class SeparableConvolution2D extends ConvolutionLayer {
|
|||
super();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean allowCausal() {
|
||||
//Causal convolution - allowed for 1D only
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set channels multiplier of channels-wise step in separable convolution
|
||||
*
|
||||
|
|
|
@ -167,6 +167,11 @@ public class Subsampling1DLayer extends SubsamplingLayer {
|
|||
this(poolingType, DEFAULT_KERNEL, DEFAULT_STRIDE, DEFAULT_PADDING);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean allowCausal() {
|
||||
return true;
|
||||
}
|
||||
|
||||
public Builder() {
|
||||
this(DEFAULT_POOLING, DEFAULT_KERNEL, DEFAULT_STRIDE, DEFAULT_PADDING);
|
||||
}
|
||||
|
|
|
@ -431,6 +431,12 @@ public class Subsampling3DLayer extends NoParamLayer {
|
|||
this.setPoolingType(poolingType);
|
||||
}
|
||||
|
||||
protected void setConvolutionMode(ConvolutionMode convolutionMode){
|
||||
Preconditions.checkState(convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" +
|
||||
" convolutional neural network layers");
|
||||
this.convolutionMode = convolutionMode;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details
|
||||
*
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer;
|
|||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.deeplearning4j.util.ConvolutionUtils;
|
||||
import org.deeplearning4j.util.ValidationUtils;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
|
@ -270,6 +271,12 @@ public class SubsamplingLayer extends NoParamLayer {
|
|||
super(poolingType);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean allowCausal() {
|
||||
//Only conv1d/subsampling1d can use causal mode
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Kernel size
|
||||
*
|
||||
|
@ -449,6 +456,14 @@ public class SubsamplingLayer extends NoParamLayer {
|
|||
this.eps = eps;
|
||||
}
|
||||
|
||||
protected abstract boolean allowCausal();
|
||||
|
||||
public void setConvolutionMode(ConvolutionMode convolutionMode){
|
||||
Preconditions.checkState(allowCausal() || convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" +
|
||||
" convolutional neural network layers");
|
||||
this.convolutionMode = convolutionMode;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details
|
||||
*
|
||||
|
|
|
@ -18,18 +18,30 @@ package org.deeplearning4j.nn.layers.convolution;
|
|||
|
||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||
import org.deeplearning4j.nn.api.MaskState;
|
||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||
import org.deeplearning4j.util.Convolution1DUtils;
|
||||
import org.deeplearning4j.util.ConvolutionUtils;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.factory.Broadcast;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 1D (temporal) convolutional layer. Currently, we just subclass off the
|
||||
|
@ -70,6 +82,52 @@ public class Convolution1DLayer extends ConvolutionLayer {
|
|||
Broadcast.mul(epsilon, maskOut, epsilon, 0, 2);
|
||||
}
|
||||
|
||||
if(layerConf().getConvolutionMode() == ConvolutionMode.Causal){
|
||||
Pair<INDArray,INDArray> fwd = causalConv1dForward();
|
||||
IActivation afn = layerConf().getActivationFn();
|
||||
INDArray delta = afn.backprop(fwd.getFirst(), epsilon).getFirst(); //TODO handle activation function params
|
||||
|
||||
//TODO eventually we'll use this for all convolution modes - but only after libnd4j has cuDNN support
|
||||
org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf();
|
||||
Conv1DConfig conf = Conv1DConfig.builder()
|
||||
.k(c.getKernelSize()[0])
|
||||
.s(c.getStride()[0])
|
||||
.d(c.getDilation()[0])
|
||||
.p(c.getPadding()[0])
|
||||
.dataFormat(Conv1DConfig.NCW)
|
||||
.paddingMode(PaddingMode.CAUSAL)
|
||||
.build();
|
||||
|
||||
INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY);
|
||||
w = w.reshape(w.ordering(), w.size(0), w.size(1), w.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] to [k, iC, oC]
|
||||
|
||||
INDArray[] inputArrs;
|
||||
INDArray[] outputArrs;
|
||||
INDArray wg = gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY);
|
||||
wg = wg.reshape(wg.ordering(), wg.size(0), wg.size(1), wg.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] -> [kW, iC, oC]
|
||||
INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
|
||||
if(layerConf().hasBias()){
|
||||
INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY);
|
||||
b = b.reshape(b.length());
|
||||
inputArrs = new INDArray[]{input.castTo(w.dataType()), w, b, delta};
|
||||
INDArray bg = gradientViews.get(ConvolutionParamInitializer.BIAS_KEY);
|
||||
bg = bg.reshape(bg.length());
|
||||
outputArrs = new INDArray[]{epsOut, wg, bg};
|
||||
} else {
|
||||
inputArrs = new INDArray[]{input.castTo(w.dataType()), w, delta};
|
||||
outputArrs = new INDArray[]{epsOut, wg};
|
||||
}
|
||||
Conv1DDerivative op = new Conv1DDerivative(inputArrs, outputArrs, conf);
|
||||
Nd4j.exec(op);
|
||||
|
||||
Gradient retGradient = new DefaultGradient();
|
||||
if(layerConf().hasBias()){
|
||||
retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, gradientViews.get(ConvolutionParamInitializer.BIAS_KEY));
|
||||
}
|
||||
retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY), 'c');
|
||||
return new Pair<>(retGradient, epsOut);
|
||||
}
|
||||
|
||||
// add singleton fourth dimension to input and next layer's epsilon
|
||||
epsilon = epsilon.reshape(epsilon.size(0), epsilon.size(1), epsilon.size(2), 1);
|
||||
INDArray origInput = input;
|
||||
|
@ -98,6 +156,12 @@ public class Convolution1DLayer extends ConvolutionLayer {
|
|||
@Override
|
||||
protected Pair<INDArray,INDArray> preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
|
||||
assertInputSet(false);
|
||||
|
||||
if(layerConf().getConvolutionMode() == ConvolutionMode.Causal){
|
||||
return causalConv1dForward();
|
||||
}
|
||||
|
||||
|
||||
INDArray origInput = input;
|
||||
input = input.reshape(input.size(0), input.size(1), input.size(2), 1);
|
||||
|
||||
|
@ -113,6 +177,36 @@ public class Convolution1DLayer extends ConvolutionLayer {
|
|||
return preOutput;
|
||||
}
|
||||
|
||||
protected Pair<INDArray,INDArray> causalConv1dForward(){
|
||||
//TODO eventually we'll use this for all convolution modes - but only after libnd4j has cuDNN support
|
||||
org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf();
|
||||
Conv1DConfig conf = Conv1DConfig.builder()
|
||||
.k(c.getKernelSize()[0])
|
||||
.s(c.getStride()[0])
|
||||
.d(c.getDilation()[0])
|
||||
.p(c.getPadding()[0])
|
||||
.dataFormat(Conv1DConfig.NCW)
|
||||
.paddingMode(PaddingMode.CAUSAL)
|
||||
.build();
|
||||
INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY);
|
||||
w = w.reshape(w.ordering(), w.size(0), w.size(1), w.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] to [k, iC, oC]
|
||||
|
||||
INDArray[] inputs;
|
||||
if(layerConf().hasBias()){
|
||||
INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY);
|
||||
b = b.reshape(b.length());
|
||||
inputs = new INDArray[]{input.castTo(w.dataType()), w, b};
|
||||
} else {
|
||||
inputs = new INDArray[]{input.castTo(w.dataType()), w};
|
||||
}
|
||||
|
||||
Conv1D op = new Conv1D(inputs, null, conf);
|
||||
List<LongShapeDescriptor> outShape = op.calculateOutputShape();
|
||||
op.setOutputArgument(0, Nd4j.create(outShape.get(0), false));
|
||||
Nd4j.exec(op);
|
||||
return new Pair<>(op.getOutputArgument(0), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
|
||||
INDArray act4d = super.activate(training, workspaceMgr);
|
||||
|
|
|
@ -66,7 +66,7 @@ public class Convolution1DUtils {
|
|||
public static long getOutputSize(long inH, int kernel, int strides, int padding,
|
||||
ConvolutionMode convolutionMode, int dilation) {
|
||||
long eKernel = effectiveKernelSize(kernel, dilation);
|
||||
if (convolutionMode == ConvolutionMode.Same) {
|
||||
if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) {
|
||||
return (int) Math.ceil(inH / ((double) strides));
|
||||
}
|
||||
return (inH - eKernel + 2 * padding) / strides + 1;
|
||||
|
@ -92,7 +92,7 @@ public class Convolution1DUtils {
|
|||
boolean atrous = (eKernel == kernel);
|
||||
validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inH, atrous);
|
||||
|
||||
if (convolutionMode == ConvolutionMode.Same) {
|
||||
if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) {
|
||||
int outH = (int) Math.ceil(inH / ((double) strides));
|
||||
return outH;
|
||||
}
|
||||
|
@ -106,8 +106,9 @@ public class Convolution1DUtils {
|
|||
boolean atrous) {
|
||||
|
||||
int inH = inShape;
|
||||
boolean t = convolutionMode == ConvolutionMode.Truncate;
|
||||
|
||||
if (convolutionMode != ConvolutionMode.Same && (eKernel <= 0 || eKernel > inH + 2 * padding)) {
|
||||
if (t && (eKernel <= 0 || eKernel > inH + 2 * padding)) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("Invalid input data or configuration: ");
|
||||
if (atrous) sb.append("effective ");
|
||||
|
|
|
@ -121,7 +121,7 @@ public class ConvolutionUtils {
|
|||
int[] inShape = new int[]{inH, inW};
|
||||
validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inShape, atrous);
|
||||
|
||||
if (convolutionMode == ConvolutionMode.Same) {
|
||||
if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) {
|
||||
|
||||
int outH = (int) Math.ceil(inH / ((double) strides[0]));
|
||||
int outW = (int) Math.ceil(inW / ((double) strides[1]));
|
||||
|
@ -142,7 +142,9 @@ public class ConvolutionUtils {
|
|||
int inH = inShape[0];
|
||||
int inW = inShape[1];
|
||||
|
||||
if (convolutionMode != ConvolutionMode.Same && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) {
|
||||
boolean t = (convolutionMode == ConvolutionMode.Truncate);
|
||||
|
||||
if (t && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("Invalid input data or configuration: ");
|
||||
if (atrous) sb.append("effective ");
|
||||
|
@ -158,7 +160,7 @@ public class ConvolutionUtils {
|
|||
throw new DL4JInvalidInputException(sb.toString());
|
||||
}
|
||||
|
||||
if (convolutionMode != ConvolutionMode.Same && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) {
|
||||
if (t && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("Invalid input data or configuration: ");
|
||||
if (atrous) sb.append("effective ");
|
||||
|
@ -175,8 +177,7 @@ public class ConvolutionUtils {
|
|||
throw new DL4JInvalidInputException(sb.toString());
|
||||
}
|
||||
|
||||
if (eKernel.length == 3 && convolutionMode != ConvolutionMode.Same
|
||||
&& (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) {
|
||||
if (eKernel.length == 3 && t && (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) {
|
||||
int inD = inShape[2];
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("Invalid input data or configuration: ");
|
||||
|
@ -615,7 +616,7 @@ public class ConvolutionUtils {
|
|||
*/
|
||||
public static INDArray cnn1dMaskReduction(INDArray in, int kernel, int stride, int padding, int dilation, ConvolutionMode cm){
|
||||
Preconditions.checkState(in.rank()==2, "Rank must be 2 for cnn1d mask array - shape ", in.shape());
|
||||
if(cm == ConvolutionMode.Same && stride == 1 ){
|
||||
if((cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) && stride == 1 ){
|
||||
return in;
|
||||
}
|
||||
|
||||
|
@ -630,7 +631,7 @@ public class ConvolutionUtils {
|
|||
int[] k = new int[]{kernel,1};
|
||||
int[] s = new int[]{stride, 1};
|
||||
int[] d = new int[]{dilation, 1};
|
||||
if (cm == ConvolutionMode.Same) {
|
||||
if (cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) {
|
||||
outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, null, cm, d); //Also performs validation
|
||||
} else {
|
||||
pad = new int[]{padding, 0};
|
||||
|
@ -645,7 +646,7 @@ public class ConvolutionUtils {
|
|||
.sH(s[0]).sW(s[1])
|
||||
.pH(pad == null ? 0 : pad[0]).pW(pad == null ? 0 : pad[1])
|
||||
.dH(d[0]).dW(d[1])
|
||||
.isSameMode(cm== ConvolutionMode.Same)
|
||||
.isSameMode(cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal)
|
||||
.isNHWC(false)
|
||||
.build());
|
||||
|
||||
|
|
Loading…
Reference in New Issue