DL4J + Keras import: Causal Conv1D support (#107)

* Keras causal conv1d support first steps

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add tests

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Causal conv mode

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Gradient check and fixes for causal conv1d

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix Conv1D import and testing

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small keras test fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Don't allow setting causal convolution mode to conv2d/3d layers

Signed-off-by: Alex Black <blacka101@gmail.com>

* More robustly infer nIn for recurrent layers for ambiguous NCW and NWC cases

Signed-off-by: Alex Black <blacka101@gmail.com>

* Polish and cleanup

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-12-04 22:52:06 +11:00 committed by GitHub
parent 578a5abb68
commit 9cc8803b8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 588 additions and 56 deletions

View File

@ -27,6 +27,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.Convolution1DUtils;
import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@ -442,4 +444,76 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
}
}
}
@Test
public void testCnn1Causal() {
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 3;
int[] lengths = {11, 12, 13, 9, 10, 11};
int[] kernels = {2, 3, 2, 4, 2, 3};
int[] dilations = {1, 1, 2, 1, 2, 1};
int[] strides = {1, 2, 1, 2, 1, 1};
boolean[] masks = {false, true, false, true, false, true};
boolean[] hasB = {true, false, true, false, true, true};
for (int i = 0; i < lengths.length; i++) {
int length = lengths[i];
int k = kernels[i];
int d = dilations[i];
int st = strides[i];
boolean mask = masks[i];
boolean hasBias = hasB[i];
//TODO has bias
String s = "k=" + k + ", s=" + st + "d=" + d + ", seqLen=" + length;
log.info("Starting test: " + s);
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.activation(Activation.TANH)
.weightInit(new NormalDistribution(0, 1))
.seed(12345)
.list()
.layer(new Convolution1DLayer.Builder().kernelSize(k)
.dilation(d)
.hasBias(hasBias)
.convolutionMode(ConvolutionMode.Causal)
.stride(st).nIn(convNIn).nOut(convNOut1)
.build())
.layer(new Convolution1DLayer.Builder().kernelSize(k)
.dilation(d)
.convolutionMode(ConvolutionMode.Causal)
.stride(st).nIn(convNOut1).nOut(convNOut2)
.build())
.layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.recurrent(convNIn, length)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length);
INDArray fm = null;
if (mask) {
fm = Nd4j.create(2, length);
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length-2)).assign(1);
}
long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d);
long outSize2 = Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d);
INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, label, fm, null);
assertTrue(s, gradOK);
TestUtils.testModelSerialization(net);
}
}
}

View File

@ -712,4 +712,73 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels"));
}
}
@Test
public void testConv1dCausalAllowed(){
new Convolution1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
new Subsampling1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
}
@Test
public void testConv2dNoCausalAllowed(){
try{
new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
try{
new Deconvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
try{
new DepthwiseConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
try{
new SeparableConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
try{
new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
}
@Test
public void testConv3dNoCausalAllowed(){
try{
new Convolution3D.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
try{
new Subsampling3DLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
}
}

View File

@ -356,6 +356,10 @@ public class KerasLayer {
return this.layer;
}
public void setLayer(Layer layer){
this.layer = layer;
}
/**
* Whether this Keras layer maps to a DL4J Vertex.
*

View File

@ -233,6 +233,7 @@ public class KerasLayerConfiguration {
private final String LAYER_BORDER_MODE_SAME = "same";
private final String LAYER_BORDER_MODE_VALID = "valid";
private final String LAYER_BORDER_MODE_FULL = "full";
private final String LAYER_BORDER_MODE_CAUSAL = "causal";
/* Noise layers */
private final String LAYER_FIELD_RATE = "rate";

View File

@ -124,7 +124,26 @@ public class KerasInput extends KerasLayer {
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]);
break;
case 2:
if(this.dimOrder != null) {
switch (this.dimOrder) {
case TENSORFLOW: //NWC == channels_last
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
break;
case THEANO: //NCW == channels_first
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
break;
case NONE:
//Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
break;
default:
throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
}
} else {
//Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
}
break;
case 3:
switch (this.dimOrder) {

View File

@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.ArrayList;
@ -96,13 +97,13 @@ public class KerasLoss extends KerasLayer {
*/
public FeedForwardLayer getLossLayer(InputType type) throws UnsupportedKerasConfigurationException {
if (type instanceof InputType.InputTypeFeedForward) {
this.layer = new LossLayer.Builder(loss).name(this.layerName).build();
this.layer = new LossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
}
else if (type instanceof InputType.InputTypeRecurrent) {
this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).build();
this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
}
else if (type instanceof InputType.InputTypeConvolutional) {
this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).build();
this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
} else {
throw new UnsupportedKerasConfigurationException("Unsupported output layer type"
+ "got : " + type.toString());

View File

@ -79,7 +79,6 @@ abstract public class KerasConvolution extends KerasLayer {
public KerasConvolution(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
super(layerConfig, enforceTrainingConfig);
}
/**

View File

@ -185,18 +185,11 @@ public class KerasConvolution1D extends KerasConvolution {
break;
case THEANO:
paramValue = kerasParamValue.permute(2, 1, 0);
paramValue = paramValue.reshape(
paramValue.size(0), paramValue.size(1),
paramValue.size(2), 1).dup();
for (int i = 0; i < paramValue.tensorsAlongDimension(2, 3); i++) {
INDArray copyFilter = paramValue.tensorAlongDimension(i, 2, 3).dup();
double[] flattenedFilter = copyFilter.ravel().data().asDouble();
ArrayUtils.reverse(flattenedFilter);
INDArray newFilter = Nd4j.create(flattenedFilter, copyFilter.shape());
INDArray inPlaceFilter = paramValue.tensorAlongDimension(i, 2, 3);
inPlaceFilter.muli(0).addi(newFilter.castTo(inPlaceFilter.dataType()));
}
//Convert from keras [k,nIn,nOut] to DL4J conv2d [nOut, nIn, k, 1]
long k = kerasParamValue.size(0);
long nIn = kerasParamValue.size(1);
long nOut = kerasParamValue.size(2);
paramValue = kerasParamValue.permute(2, 1, 0).dup('c').reshape(nOut, nIn, k, 1);
break;
default:
throw new InvalidKerasConfigurationException("Unknown keras backend " + this.getDimOrder());

View File

@ -264,7 +264,8 @@ public class KerasConvolutionUtils {
} else if (borderMode.equals(conf.getLAYER_BORDER_MODE_VALID()) ||
borderMode.equals(conf.getLAYER_BORDER_MODE_FULL())) {
convolutionMode = ConvolutionMode.Truncate;
} else if(borderMode.equals(conf.getLAYER_BORDER_MODE_CAUSAL())) {
convolutionMode = ConvolutionMode.Causal;
} else {
throw new UnsupportedKerasConfigurationException("Unsupported convolution border mode: " + borderMode);
}

View File

@ -23,11 +23,13 @@ import lombok.val;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@ -186,6 +188,9 @@ public class KerasLSTM extends KerasLayer {
.biasInit(0.0) // TODO: this is incorrect
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null)
builder.setNIn(nIn);
if (biasConstraint != null)
builder.constrainBias(biasConstraint);
if (weightConstraint != null)
@ -436,6 +441,20 @@ public class KerasLSTM extends KerasLayer {
log.warn("Attemping to set weights for unknown parameters: "
+ unknownParamNames.substring(1, unknownParamNames.length() - 1));
}
FeedForwardLayer ffl;
if(this.layer instanceof BaseWrapperLayer){
BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
ffl = (FeedForwardLayer)bwl.getUnderlying();
} else {
ffl = (FeedForwardLayer) this.layer;
}
if(ffl.getNIn() != wRows){
//Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config)
//We can reliably infer nIn from the shape of the weights array however
ffl.setNIn(wRows);
}
}
/**

View File

@ -22,11 +22,13 @@ import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@ -154,6 +156,9 @@ public class KerasSimpleRnn extends KerasLayer {
.biasInit(0.0)
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null)
builder.setNIn(nIn);
if (biasConstraint != null)
builder.constrainBias(biasConstraint);
if (weightConstraint != null)
@ -282,6 +287,19 @@ public class KerasSimpleRnn extends KerasLayer {
log.warn("Attemping to set weights for unknown parameters: "
+ unknownParamNames.substring(1, unknownParamNames.length() - 1));
}
FeedForwardLayer ffl;
if(this.layer instanceof BaseWrapperLayer){
BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
ffl = (FeedForwardLayer)bwl.getUnderlying();
} else {
ffl = (FeedForwardLayer) this.layer;
}
if(ffl.getNIn() != W.rows()){
//Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config)
//We can reliably infer nIn from the shape of the weights array however
ffl.setNIn(W.rows());
}
}
}

View File

@ -229,8 +229,8 @@ public class KerasBidirectional extends KerasLayer {
@Override
public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
Map<String, INDArray> forwardWeights = getUnderlyingWeights(weights, "forward");
Map<String, INDArray> backwardWeights = getUnderlyingWeights(weights, "backward");
Map<String, INDArray> forwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getFwd(), weights, "forward");
Map<String, INDArray> backwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getBwd(), weights, "backward");
this.weights = new HashMap<>();
@ -241,7 +241,7 @@ public class KerasBidirectional extends KerasLayer {
}
private Map<String, INDArray> getUnderlyingWeights(Map<String, INDArray> weights, String direction)
private Map<String, INDArray> getUnderlyingWeights(Layer l, Map<String, INDArray> weights, String direction)
throws InvalidKerasConfigurationException {
int keras1SubstringLength;
if (kerasRnnlayer instanceof KerasLSTM)
@ -270,8 +270,12 @@ public class KerasBidirectional extends KerasLayer {
weights = newWeights;
}
Layer layerBefore = kerasRnnlayer.getLayer();
kerasRnnlayer.setLayer(l);
kerasRnnlayer.setWeights(weights);
return kerasRnnlayer.getWeights();
Map<String,INDArray> ret = kerasRnnlayer.getWeights();
kerasRnnlayer.setLayer(layerBefore);
return ret;
}
}

View File

@ -505,6 +505,17 @@ public class KerasLayerUtils {
return nOut;
}
public static Integer getNInFromInputDim(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
if(innerConfig.containsKey(conf.getLAYER_FIELD_INPUT_DIM())){
Object id = innerConfig.get(conf.getLAYER_FIELD_INPUT_DIM());
if(id instanceof Number){
return ((Number)id).intValue();
}
}
return null;
}
/**
* Get dropout from Keras layer configuration.
*

View File

@ -24,6 +24,8 @@ import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.gradientcheck.GradientCheckUtil;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
@ -47,6 +49,8 @@ import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.function.BiFunction;
import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
@ -58,10 +62,7 @@ import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.*;
import static org.junit.Assert.*;
@ -86,6 +87,15 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
@Rule
public final TemporaryFolder testDir = new TemporaryFolder();
public static final BiFunction<String,INDArray,INDArray> nwc2ncwExpected = new BiFunction<String, INDArray, INDArray>() {
@Override
public INDArray apply(String s, INDArray array) {
if(array.rank() == 3)
return array.permute(0, 2, 1); //NWC to NCW
return array;
}
};
@Test(expected = IllegalStateException.class)
public void fileNotFoundEndToEnd() throws Exception {
String modelPath = "modelimport/keras/examples/foo/bar.h5";
@ -154,28 +164,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importImdbLstmTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
}
@Test
public void importImdbLstmThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
}
@Test
public void importImdbLstmTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
}
@Test
public void importImdbLstmThKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, true, false, false);
importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected);
}
/**
@ -247,7 +257,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
"simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
}
/**
@ -598,6 +608,122 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
model.summary();
}
@Test
public void testCausalCon1D() throws Exception {
String[] names = new String[]{
"causal_conv1d_k2_s1_d1_cl_model.h5",
"causal_conv1d_k2_s1_d2_cl_model.h5",
"causal_conv1d_k2_s2_d1_cl_model.h5",
"causal_conv1d_k2_s3_d1_cl_model.h5",
"causal_conv1d_k3_s1_d1_cl_model.h5",
"causal_conv1d_k3_s1_d2_cl_model.h5",
"causal_conv1d_k3_s2_d1_cl_model.h5",
"causal_conv1d_k3_s3_d1_cl_model.h5",
"causal_conv1d_k4_s1_d1_cl_model.h5",
"causal_conv1d_k4_s1_d2_cl_model.h5",
"causal_conv1d_k4_s2_d1_cl_model.h5",
"causal_conv1d_k4_s3_d1_cl_model.h5"
};
for(String name : names ){
System.out.println("Starting test: " + name);
String modelPath = "modelimport/keras/examples/causal_conv1d/" + name;
String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
Function<INDArray,INDArray> f = new Function<INDArray, INDArray>() {
@Override
public INDArray apply(INDArray i) {
//NWC to NCW
return i.permute(0, 2, 1);
}
};
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
true, true, false, f, nwc2ncwExpected);
Layer l = net.getLayer(0);
Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig();
assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
}
}
@Test
public void testCon1D() throws Exception {
String[] names = new String[]{
"conv1d_k2_s1_d1_cf_same_model.h5",
"conv1d_k2_s1_d1_cf_valid_model.h5",
"conv1d_k2_s1_d1_cl_same_model.h5",
"conv1d_k2_s1_d1_cl_valid_model.h5",
"conv1d_k2_s1_d2_cf_same_model.h5",
"conv1d_k2_s1_d2_cf_valid_model.h5",
"conv1d_k2_s1_d2_cl_same_model.h5",
"conv1d_k2_s1_d2_cl_valid_model.h5",
"conv1d_k2_s2_d1_cf_same_model.h5",
"conv1d_k2_s2_d1_cf_valid_model.h5",
"conv1d_k2_s2_d1_cl_same_model.h5",
"conv1d_k2_s2_d1_cl_valid_model.h5",
"conv1d_k2_s3_d1_cf_same_model.h5",
"conv1d_k2_s3_d1_cf_valid_model.h5",
"conv1d_k2_s3_d1_cl_same_model.h5",
"conv1d_k2_s3_d1_cl_valid_model.h5",
"conv1d_k3_s1_d1_cf_same_model.h5",
"conv1d_k3_s1_d1_cf_valid_model.h5",
"conv1d_k3_s1_d1_cl_same_model.h5",
"conv1d_k3_s1_d1_cl_valid_model.h5",
"conv1d_k3_s1_d2_cf_same_model.h5",
"conv1d_k3_s1_d2_cf_valid_model.h5",
"conv1d_k3_s1_d2_cl_same_model.h5",
"conv1d_k3_s1_d2_cl_valid_model.h5",
"conv1d_k3_s2_d1_cf_same_model.h5",
"conv1d_k3_s2_d1_cf_valid_model.h5",
"conv1d_k3_s2_d1_cl_same_model.h5",
"conv1d_k3_s2_d1_cl_valid_model.h5",
"conv1d_k3_s3_d1_cf_same_model.h5",
"conv1d_k3_s3_d1_cf_valid_model.h5",
"conv1d_k3_s3_d1_cl_same_model.h5",
"conv1d_k3_s3_d1_cl_valid_model.h5",
"conv1d_k4_s1_d1_cf_same_model.h5",
"conv1d_k4_s1_d1_cf_valid_model.h5",
"conv1d_k4_s1_d1_cl_same_model.h5",
"conv1d_k4_s1_d1_cl_valid_model.h5",
"conv1d_k4_s1_d2_cf_same_model.h5",
"conv1d_k4_s1_d2_cf_valid_model.h5",
"conv1d_k4_s1_d2_cl_same_model.h5",
"conv1d_k4_s1_d2_cl_valid_model.h5",
"conv1d_k4_s2_d1_cf_same_model.h5",
"conv1d_k4_s2_d1_cf_valid_model.h5",
"conv1d_k4_s2_d1_cl_same_model.h5",
"conv1d_k4_s2_d1_cl_valid_model.h5",
"conv1d_k4_s3_d1_cf_same_model.h5",
"conv1d_k4_s3_d1_cf_valid_model.h5",
"conv1d_k4_s3_d1_cl_same_model.h5",
"conv1d_k4_s3_d1_cl_valid_model.h5",
};
for(String name : names ){
System.out.println("Starting test: " + name);
String modelPath = "modelimport/keras/examples/conv1d/" + name;
String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
Function<INDArray,INDArray> f = name.contains("_cf_") ? null : new Function<INDArray, INDArray>() {
@Override
public INDArray apply(INDArray i) {
//NWC to NCW
return i.permute(0, 2, 1);
}
};
BiFunction<String,INDArray,INDArray> f2 = name.contains("_cf_") ? null : new BiFunction<String, INDArray, INDArray>() {
@Override
public INDArray apply(String s, INDArray array) {
// if("conv".equals(s)){
return array.permute(0, 2, 1);
// }
}
};
importEndModelTest(modelPath, inputsOutputPath, true, true,
true, true, false, f, f2);
}
}
private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception {
return importFunctionalModelH5Test(modelPath, null, false);
}
@ -640,6 +766,12 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
boolean checkGradients, boolean enforceTrainingConfig) throws Exception {
return importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, checkGradients, true, enforceTrainingConfig, null, null);
}
public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
boolean checkGradients, boolean enforceTrainingConfig, boolean checkAuc, Function<INDArray,INDArray> inputPreProc,
BiFunction<String,INDArray,INDArray> expectedPreProc) throws Exception {
MultiLayerNetwork model;
try(InputStream is = Resources.asStream(modelPath)) {
File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);
@ -658,20 +790,25 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
if (checkPredictions) {
INDArray input = getInputs(outputsArchive, tfOrdering)[0];
if(inputPreProc != null)
input = inputPreProc.apply(input);
Map<String, INDArray> activationsKeras = getActivations(outputsArchive, tfOrdering);
for (int i = 0; i < model.getLayers().length; i++) {
String layerName = model.getLayerNames().get(i);
if (activationsKeras.containsKey(layerName)) {
INDArray activationsDl4j = model.feedForwardToLayer(i, input, false).get(i + 1);
if (activationsDl4j.shape().length == 3)
activationsDl4j = activationsDl4j.permute(0, 2, 1);
compareINDArrays(layerName, activationsKeras.get(layerName), activationsDl4j, EPS);
INDArray exp = activationsKeras.get(layerName);
if(expectedPreProc != null)
exp = expectedPreProc.apply(layerName, exp);
compareINDArrays(layerName, exp, activationsDl4j, EPS);
}
}
INDArray predictionsKeras = getPredictions(outputsArchive, tfOrdering)[0];
INDArray predictionsDl4j = model.output(input, false);
if(expectedPreProc != null)
predictionsKeras = expectedPreProc.apply("output", predictionsKeras);
compareINDArrays("predictions", predictionsKeras, predictionsDl4j, EPS);
INDArray outputs = getOutputs(outputsArchive, true)[0];
@ -680,6 +817,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
}
val nOut = (int) outputs.size(-1);
if(checkAuc)
compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS);
}
@ -760,20 +898,23 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
return predictions;
}
private static void compareINDArrays(String label, INDArray a, INDArray b, double eps) {
INDArray diff = a.sub(b.castTo(a.dataType()));
private static void compareINDArrays(String label, INDArray expected, INDArray actual, double eps) {
if(!expected.equalShapes(actual)){
throw new IllegalStateException("Shapes do not match for \"" + label + "\": got " + Arrays.toString(expected.shape()) + " vs " + Arrays.toString(actual.shape()));
}
INDArray diff = expected.sub(actual.castTo(expected.dataType()));
double min = diff.minNumber().doubleValue();
double max = diff.maxNumber().doubleValue();
log.info(label + ": " + a.equalsWithEps(b, eps) + ", " + min + ", " + max);
log.info(label + ": " + expected.equalsWithEps(actual, eps) + ", " + min + ", " + max);
double threshold = 1e-7;
double aAbsMax = Math.max(Math.abs(a.minNumber().doubleValue()), Math.abs(a.maxNumber().doubleValue()));
double bAbsMax = Math.max(Math.abs(b.minNumber().doubleValue()), Math.abs(b.maxNumber().doubleValue()));
double aAbsMax = Math.max(Math.abs(expected.minNumber().doubleValue()), Math.abs(expected.maxNumber().doubleValue()));
double bAbsMax = Math.max(Math.abs(actual.minNumber().doubleValue()), Math.abs(actual.maxNumber().doubleValue()));
// skip too small absolute inputs
if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) {
assertTrue(a.equalsWithEps(b.castTo(a.dataType()), eps));
boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps);
assertTrue("Output differs: " + label, eq);
}
}
private static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses,

View File

@ -69,6 +69,18 @@ package org.deeplearning4j.nn.conf;
* <br>
* <br>
* <br>
* <b>Causal</b>: Causal padding mode can only be used for 1D convolutional neural networks.<br>
* The motivation behind causal padding mode is that the output time steps depend only on current and past time steps.<br>
* That is, out[t] (for time t) depends on only on values in[T] for t < T<br>
* The output size of 1D convolution/subsampling layers is the same as with SAME convolution mode -
* i.e., outSize = ceil( inputSize / stride )<br>
* Padding is also the same as SAME mode, but all padding in on the left (start of sequence) instead of being on both
* left and right of the input<br>
* For more details on causal convolutions, see <a href="https://arxiv.org/abs/1609.03499">WaveNet: A Generative Model For Audio</a>,
* section 2.1.
* <br>
* <br>
* <br>
* For further information on output sizes for convolutional neural networks, see the "Spatial arrangement" section at
* <a href="http://cs231n.github.io/convolutional-networks/">http://cs231n.github.io/convolutional-networks/</a>
*
@ -76,6 +88,6 @@ package org.deeplearning4j.nn.conf;
*/
public enum ConvolutionMode {
Strict, Truncate, Same
Strict, Truncate, Same, Causal
}

View File

@ -124,6 +124,11 @@ public class Convolution1DLayer extends ConvolutionLayer {
this.setKernelSize((int[]) null);
}
@Override
protected boolean allowCausal() {
return true;
}
/**
* @param kernelSize Kernel size
* @param stride Stride

View File

@ -163,6 +163,12 @@ public class Convolution3D extends ConvolutionLayer {
super(new int[] {2, 2, 2}, new int[] {1, 1, 1}, new int[] {0, 0, 0}, new int[] {1, 1, 1}, 3);
}
@Override
protected boolean allowCausal() {
//Causal convolution - allowed for 1D only
return false;
}
public Builder(int[] kernelSize, int[] stride, int[] padding, int[] dilation) {
super(kernelSize, stride, padding, dilation, 3);
}

View File

@ -30,6 +30,7 @@ import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -283,6 +284,12 @@ public class ConvolutionLayer extends FeedForwardLayer {
super();
}
@Override
protected boolean allowCausal() {
//Causal convolution - allowed for 1D only
return false;
}
/**
* Size of the convolution rows/columns
*
@ -456,6 +463,14 @@ public class ConvolutionLayer extends FeedForwardLayer {
protected BaseConvBuilder() {}
protected abstract boolean allowCausal();
protected void setConvolutionMode(ConvolutionMode convolutionMode){
Preconditions.checkState(allowCausal() || convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" +
" convolutional neural network layers");
this.convolutionMode = convolutionMode;
}
/**
* If true (default): include bias parameters in the model. False: no bias.
*

View File

@ -133,6 +133,12 @@ public class Deconvolution2D extends ConvolutionLayer {
super();
}
@Override
protected boolean allowCausal() {
//Causal convolution - allowed for 1D only
return false;
}
/**
* Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details
*

View File

@ -133,6 +133,12 @@ public class DepthwiseConvolution2D extends ConvolutionLayer {
super();
}
@Override
protected boolean allowCausal() {
//Causal convolution - allowed for 1D only
return false;
}
/**
* Set channels multiplier for depth-wise convolution
*

View File

@ -184,6 +184,12 @@ public class SeparableConvolution2D extends ConvolutionLayer {
super();
}
@Override
protected boolean allowCausal() {
//Causal convolution - allowed for 1D only
return false;
}
/**
* Set channels multiplier of channels-wise step in separable convolution
*

View File

@ -167,6 +167,11 @@ public class Subsampling1DLayer extends SubsamplingLayer {
this(poolingType, DEFAULT_KERNEL, DEFAULT_STRIDE, DEFAULT_PADDING);
}
@Override
protected boolean allowCausal() {
return true;
}
public Builder() {
this(DEFAULT_POOLING, DEFAULT_KERNEL, DEFAULT_STRIDE, DEFAULT_PADDING);
}

View File

@ -431,6 +431,12 @@ public class Subsampling3DLayer extends NoParamLayer {
this.setPoolingType(poolingType);
}
protected void setConvolutionMode(ConvolutionMode convolutionMode){
Preconditions.checkState(convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" +
" convolutional neural network layers");
this.convolutionMode = convolutionMode;
}
/**
* Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details
*

View File

@ -28,6 +28,7 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -270,6 +271,12 @@ public class SubsamplingLayer extends NoParamLayer {
super(poolingType);
}
@Override
protected boolean allowCausal() {
//Only conv1d/subsampling1d can use causal mode
return false;
}
/**
* Kernel size
*
@ -449,6 +456,14 @@ public class SubsamplingLayer extends NoParamLayer {
this.eps = eps;
}
protected abstract boolean allowCausal();
public void setConvolutionMode(ConvolutionMode convolutionMode){
Preconditions.checkState(allowCausal() || convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" +
" convolutional neural network layers");
this.convolutionMode = convolutionMode;
}
/**
* Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details
*

View File

@ -18,18 +18,30 @@ package org.deeplearning4j.nn.layers.convolution;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.util.Convolution1DUtils;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.util.Arrays;
import java.util.List;
/**
* 1D (temporal) convolutional layer. Currently, we just subclass off the
@ -70,6 +82,52 @@ public class Convolution1DLayer extends ConvolutionLayer {
Broadcast.mul(epsilon, maskOut, epsilon, 0, 2);
}
if(layerConf().getConvolutionMode() == ConvolutionMode.Causal){
Pair<INDArray,INDArray> fwd = causalConv1dForward();
IActivation afn = layerConf().getActivationFn();
INDArray delta = afn.backprop(fwd.getFirst(), epsilon).getFirst(); //TODO handle activation function params
//TODO eventually we'll use this for all convolution modes - but only after libnd4j has cuDNN support
org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf();
Conv1DConfig conf = Conv1DConfig.builder()
.k(c.getKernelSize()[0])
.s(c.getStride()[0])
.d(c.getDilation()[0])
.p(c.getPadding()[0])
.dataFormat(Conv1DConfig.NCW)
.paddingMode(PaddingMode.CAUSAL)
.build();
INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY);
w = w.reshape(w.ordering(), w.size(0), w.size(1), w.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] to [k, iC, oC]
INDArray[] inputArrs;
INDArray[] outputArrs;
INDArray wg = gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY);
wg = wg.reshape(wg.ordering(), wg.size(0), wg.size(1), wg.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] -> [kW, iC, oC]
INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
if(layerConf().hasBias()){
INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY);
b = b.reshape(b.length());
inputArrs = new INDArray[]{input.castTo(w.dataType()), w, b, delta};
INDArray bg = gradientViews.get(ConvolutionParamInitializer.BIAS_KEY);
bg = bg.reshape(bg.length());
outputArrs = new INDArray[]{epsOut, wg, bg};
} else {
inputArrs = new INDArray[]{input.castTo(w.dataType()), w, delta};
outputArrs = new INDArray[]{epsOut, wg};
}
Conv1DDerivative op = new Conv1DDerivative(inputArrs, outputArrs, conf);
Nd4j.exec(op);
Gradient retGradient = new DefaultGradient();
if(layerConf().hasBias()){
retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, gradientViews.get(ConvolutionParamInitializer.BIAS_KEY));
}
retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY), 'c');
return new Pair<>(retGradient, epsOut);
}
// add singleton fourth dimension to input and next layer's epsilon
epsilon = epsilon.reshape(epsilon.size(0), epsilon.size(1), epsilon.size(2), 1);
INDArray origInput = input;
@ -98,6 +156,12 @@ public class Convolution1DLayer extends ConvolutionLayer {
@Override
protected Pair<INDArray,INDArray> preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(false);
if(layerConf().getConvolutionMode() == ConvolutionMode.Causal){
return causalConv1dForward();
}
INDArray origInput = input;
input = input.reshape(input.size(0), input.size(1), input.size(2), 1);
@ -113,6 +177,36 @@ public class Convolution1DLayer extends ConvolutionLayer {
return preOutput;
}
protected Pair<INDArray,INDArray> causalConv1dForward(){
//TODO eventually we'll use this for all convolution modes - but only after libnd4j has cuDNN support
org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf();
Conv1DConfig conf = Conv1DConfig.builder()
.k(c.getKernelSize()[0])
.s(c.getStride()[0])
.d(c.getDilation()[0])
.p(c.getPadding()[0])
.dataFormat(Conv1DConfig.NCW)
.paddingMode(PaddingMode.CAUSAL)
.build();
INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY);
w = w.reshape(w.ordering(), w.size(0), w.size(1), w.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] to [k, iC, oC]
INDArray[] inputs;
if(layerConf().hasBias()){
INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY);
b = b.reshape(b.length());
inputs = new INDArray[]{input.castTo(w.dataType()), w, b};
} else {
inputs = new INDArray[]{input.castTo(w.dataType()), w};
}
Conv1D op = new Conv1D(inputs, null, conf);
List<LongShapeDescriptor> outShape = op.calculateOutputShape();
op.setOutputArgument(0, Nd4j.create(outShape.get(0), false));
Nd4j.exec(op);
return new Pair<>(op.getOutputArgument(0), null);
}
@Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
INDArray act4d = super.activate(training, workspaceMgr);

View File

@ -66,7 +66,7 @@ public class Convolution1DUtils {
public static long getOutputSize(long inH, int kernel, int strides, int padding,
ConvolutionMode convolutionMode, int dilation) {
long eKernel = effectiveKernelSize(kernel, dilation);
if (convolutionMode == ConvolutionMode.Same) {
if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) {
return (int) Math.ceil(inH / ((double) strides));
}
return (inH - eKernel + 2 * padding) / strides + 1;
@ -92,7 +92,7 @@ public class Convolution1DUtils {
boolean atrous = (eKernel == kernel);
validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inH, atrous);
if (convolutionMode == ConvolutionMode.Same) {
if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) {
int outH = (int) Math.ceil(inH / ((double) strides));
return outH;
}
@ -106,8 +106,9 @@ public class Convolution1DUtils {
boolean atrous) {
int inH = inShape;
boolean t = convolutionMode == ConvolutionMode.Truncate;
if (convolutionMode != ConvolutionMode.Same && (eKernel <= 0 || eKernel > inH + 2 * padding)) {
if (t && (eKernel <= 0 || eKernel > inH + 2 * padding)) {
StringBuilder sb = new StringBuilder();
sb.append("Invalid input data or configuration: ");
if (atrous) sb.append("effective ");

View File

@ -121,7 +121,7 @@ public class ConvolutionUtils {
int[] inShape = new int[]{inH, inW};
validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inShape, atrous);
if (convolutionMode == ConvolutionMode.Same) {
if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) {
int outH = (int) Math.ceil(inH / ((double) strides[0]));
int outW = (int) Math.ceil(inW / ((double) strides[1]));
@ -142,7 +142,9 @@ public class ConvolutionUtils {
int inH = inShape[0];
int inW = inShape[1];
if (convolutionMode != ConvolutionMode.Same && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) {
boolean t = (convolutionMode == ConvolutionMode.Truncate);
if (t && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) {
StringBuilder sb = new StringBuilder();
sb.append("Invalid input data or configuration: ");
if (atrous) sb.append("effective ");
@ -158,7 +160,7 @@ public class ConvolutionUtils {
throw new DL4JInvalidInputException(sb.toString());
}
if (convolutionMode != ConvolutionMode.Same && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) {
if (t && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) {
StringBuilder sb = new StringBuilder();
sb.append("Invalid input data or configuration: ");
if (atrous) sb.append("effective ");
@ -175,8 +177,7 @@ public class ConvolutionUtils {
throw new DL4JInvalidInputException(sb.toString());
}
if (eKernel.length == 3 && convolutionMode != ConvolutionMode.Same
&& (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) {
if (eKernel.length == 3 && t && (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) {
int inD = inShape[2];
StringBuilder sb = new StringBuilder();
sb.append("Invalid input data or configuration: ");
@ -615,7 +616,7 @@ public class ConvolutionUtils {
*/
public static INDArray cnn1dMaskReduction(INDArray in, int kernel, int stride, int padding, int dilation, ConvolutionMode cm){
Preconditions.checkState(in.rank()==2, "Rank must be 2 for cnn1d mask array - shape ", in.shape());
if(cm == ConvolutionMode.Same && stride == 1 ){
if((cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) && stride == 1 ){
return in;
}
@ -630,7 +631,7 @@ public class ConvolutionUtils {
int[] k = new int[]{kernel,1};
int[] s = new int[]{stride, 1};
int[] d = new int[]{dilation, 1};
if (cm == ConvolutionMode.Same) {
if (cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) {
outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, null, cm, d); //Also performs validation
} else {
pad = new int[]{padding, 0};
@ -645,7 +646,7 @@ public class ConvolutionUtils {
.sH(s[0]).sW(s[1])
.pH(pad == null ? 0 : pad[0]).pW(pad == null ? 0 : pad[1])
.dH(d[0]).dW(d[1])
.isSameMode(cm== ConvolutionMode.Same)
.isSameMode(cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal)
.isNHWC(false)
.build());