DL4J + Keras import: Causal Conv1D support (#107)
Squashed commits:

* Keras causal conv1d support first steps
* Add tests
* Causal conv mode
* Gradient check and fixes for causal conv1d
* Fix Conv1D import and testing
* Cleanup
* Small keras test fix
* Don't allow setting causal convolution mode to conv2d/3d layers
* More robustly infer nIn for recurrent layers for ambiguous NCW and NWC cases
* Polish and cleanup

All commits Signed-off-by: Alex Black (AlexDBlack) <blacka101@gmail.com>
Branch: master
parent 578a5abb68
commit 9cc8803b8d
@@ -27,6 +27,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.util.Convolution1DUtils;
+import org.deeplearning4j.util.ConvolutionUtils;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.buffer.DataType;
@@ -442,4 +444,76 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
             }
         }
     }
+
+    @Test
+    public void testCnn1Causal() {
+        int convNIn = 2;
+        int convNOut1 = 3;
+        int convNOut2 = 4;
+        int finalNOut = 3;
+
+        int[] lengths = {11, 12, 13, 9, 10, 11};
+        int[] kernels = {2, 3, 2, 4, 2, 3};
+        int[] dilations = {1, 1, 2, 1, 2, 1};
+        int[] strides = {1, 2, 1, 2, 1, 1};
+        boolean[] masks = {false, true, false, true, false, true};
+        boolean[] hasB = {true, false, true, false, true, true};
+
+        for (int i = 0; i < lengths.length; i++) {
+            int length = lengths[i];
+            int k = kernels[i];
+            int d = dilations[i];
+            int st = strides[i];
+            boolean mask = masks[i];
+            boolean hasBias = hasB[i];
+            //TODO has bias
+            String s = "k=" + k + ", s=" + st + ", d=" + d + ", seqLen=" + length;
+            log.info("Starting test: " + s);
+            Nd4j.getRandom().setSeed(12345);
+
+            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                    .dataType(DataType.DOUBLE)
+                    .updater(new NoOp())
+                    .activation(Activation.TANH)
+                    .weightInit(new NormalDistribution(0, 1))
+                    .seed(12345)
+                    .list()
+                    .layer(new Convolution1DLayer.Builder().kernelSize(k)
+                            .dilation(d)
+                            .hasBias(hasBias)
+                            .convolutionMode(ConvolutionMode.Causal)
+                            .stride(st).nIn(convNIn).nOut(convNOut1)
+                            .build())
+                    .layer(new Convolution1DLayer.Builder().kernelSize(k)
+                            .dilation(d)
+                            .convolutionMode(ConvolutionMode.Causal)
+                            .stride(st).nIn(convNOut1).nOut(convNOut2)
+                            .build())
+                    .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
+                            .activation(Activation.SOFTMAX).nOut(finalNOut).build())
+                    .setInputType(InputType.recurrent(convNIn, length)).build();
+
+            MultiLayerNetwork net = new MultiLayerNetwork(conf);
+            net.init();
+
+            INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length);
+            INDArray fm = null;
+            if (mask) {
+                fm = Nd4j.create(2, length);
+                fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
+                fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length-2)).assign(1);
+            }
+
+            long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d);
+            long outSize2 = Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d);
+
+            INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2);
+
+            boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
+                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, label, fm, null);
+
+            assertTrue(s, gradOK);
+            TestUtils.testModelSerialization(net);
+        }
+    }
 }
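For reference, the output lengths computed in this test follow the causal sizing rule outSize = ceil(inputSize / stride), with (k - 1) * d zeros implicitly padded on the left. A minimal sketch (not part of the commit; the helper name is hypothetical, but it matches what the Convolution1DUtils.getOutputSize calls above compute for ConvolutionMode.Causal):

```java
// Hypothetical standalone check of the causal sizing rule used in the test.
// Causal mode implicitly pads (k - 1) * d zeros on the left, so the output
// length reduces to ceil(inputSize / stride), independent of kernel/dilation.
static long causalOutputSize(long inputSize, long k, long stride, long dilation) {
    long leftPad = (k - 1) * dilation;            // implicit left padding
    long effKernel = (k - 1) * dilation + 1;      // dilated kernel extent
    return (inputSize + leftPad - effKernel) / stride + 1; // == ceil(inputSize / stride)
}

// e.g. causalOutputSize(11, 2, 1, 1) == 11, causalOutputSize(12, 3, 2, 1) == 6
```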
@@ -712,4 +712,73 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
             assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels"));
         }
     }
+
+    @Test
+    public void testConv1dCausalAllowed(){
+        new Convolution1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
+        new Subsampling1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
+    }
+
+    @Test
+    public void testConv2dNoCausalAllowed(){
+
+        try{
+            new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+
+        try{
+            new Deconvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+
+        try{
+            new DepthwiseConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+
+        try{
+            new SeparableConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+
+        try{
+            new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+    }
+
+    @Test
+    public void testConv3dNoCausalAllowed(){
+        try{
+            new Convolution3D.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+
+        try{
+            new Subsampling3DLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+    }
 }
@@ -356,6 +356,10 @@ public class KerasLayer {
         return this.layer;
     }

+    public void setLayer(Layer layer){
+        this.layer = layer;
+    }
+
     /**
      * Whether this Keras layer maps to a DL4J Vertex.
      *
@@ -233,6 +233,7 @@ public class KerasLayerConfiguration {
     private final String LAYER_BORDER_MODE_SAME = "same";
     private final String LAYER_BORDER_MODE_VALID = "valid";
     private final String LAYER_BORDER_MODE_FULL = "full";
+    private final String LAYER_BORDER_MODE_CAUSAL = "causal";

     /* Noise layers */
     private final String LAYER_FIELD_RATE = "rate";
@@ -124,7 +124,26 @@ public class KerasInput extends KerasLayer {
                 myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]);
                 break;
             case 2:
-                myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
+                if(this.dimOrder != null) {
+                    switch (this.dimOrder) {
+                        case TENSORFLOW: //NWC == channels_last
+                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
+                            break;
+                        case THEANO: //NCW == channels_first
+                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
+                            break;
+                        case NONE:
+                            //Assume RNN in [mb, seqLen, size] format
+                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
+                            break;
+                        default:
+                            throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
+                    }
+                } else {
+                    //Assume RNN in [mb, seqLen, size] format
+                    myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
+                }
+
                 break;
             case 3:
                 switch (this.dimOrder) {
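To make the NWC/NCW ambiguity concrete, a hedged illustration of how a rank-2 Keras input_shape is mapped by the branch above. The concrete shape (10, 64) is made up for the example; InputTypeRecurrent takes (size, timeSeriesLength), as in the calls in the diff:

```java
// Assume a Keras model declared input_shape = (10, 64), i.e. inputShape = {10, 64}.
InputType tf   = new InputType.InputTypeRecurrent(64, 10); // TENSORFLOW / NWC: seqLen=10, channels=64
InputType th   = new InputType.InputTypeRecurrent(10, 64); // THEANO / NCW: channels=10, seqLen=64
InputType none = new InputType.InputTypeRecurrent(10, 64); // NONE or null dimOrder: same mapping as the THEANO branch
```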
@@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
+import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.lossfunctions.LossFunctions;

 import java.util.ArrayList;

@@ -96,13 +97,13 @@ public class KerasLoss extends KerasLayer {
      */
     public FeedForwardLayer getLossLayer(InputType type) throws UnsupportedKerasConfigurationException {
         if (type instanceof InputType.InputTypeFeedForward) {
-            this.layer = new LossLayer.Builder(loss).name(this.layerName).build();
+            this.layer = new LossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
         }
         else if (type instanceof InputType.InputTypeRecurrent) {
-            this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).build();
+            this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
         }
         else if (type instanceof InputType.InputTypeConvolutional) {
-            this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).build();
+            this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
         } else {
             throw new UnsupportedKerasConfigurationException("Unsupported output layer type"
                     + "got : " + type.toString());
@@ -79,7 +79,6 @@ abstract public class KerasConvolution extends KerasLayer {
     public KerasConvolution(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
             throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
         super(layerConfig, enforceTrainingConfig);
-
     }

     /**
@@ -185,18 +185,11 @@ public class KerasConvolution1D extends KerasConvolution {
                 break;

             case THEANO:
-                paramValue = kerasParamValue.permute(2, 1, 0);
-                paramValue = paramValue.reshape(
-                        paramValue.size(0), paramValue.size(1),
-                        paramValue.size(2), 1).dup();
-                for (int i = 0; i < paramValue.tensorsAlongDimension(2, 3); i++) {
-                    INDArray copyFilter = paramValue.tensorAlongDimension(i, 2, 3).dup();
-                    double[] flattenedFilter = copyFilter.ravel().data().asDouble();
-                    ArrayUtils.reverse(flattenedFilter);
-                    INDArray newFilter = Nd4j.create(flattenedFilter, copyFilter.shape());
-                    INDArray inPlaceFilter = paramValue.tensorAlongDimension(i, 2, 3);
-                    inPlaceFilter.muli(0).addi(newFilter.castTo(inPlaceFilter.dataType()));
-                }
+                //Convert from keras [k,nIn,nOut] to DL4J conv2d [nOut, nIn, k, 1]
+                long k = kerasParamValue.size(0);
+                long nIn = kerasParamValue.size(1);
+                long nOut = kerasParamValue.size(2);
+                paramValue = kerasParamValue.permute(2, 1, 0).dup('c').reshape(nOut, nIn, k, 1);
                 break;
             default:
                 throw new InvalidKerasConfigurationException("Unknown keras backend " + this.getDimOrder());
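The replacement logic above, in isolation (array sizes assumed for illustration; note that the old per-filter reversal loop visible in the removed lines is gone — only the layout conversion remains):

```java
// Keras (Theano backend) conv1d kernel: [k, nIn, nOut]; DL4J conv2d weights: [nOut, nIn, k, 1].
INDArray kerasW = Nd4j.rand(DataType.FLOAT, 4, 2, 3);                 // k=4, nIn=2, nOut=3 (assumed)
long k = kerasW.size(0), nIn = kerasW.size(1), nOut = kerasW.size(2);
INDArray dl4jW = kerasW.permute(2, 1, 0).dup('c').reshape(nOut, nIn, k, 1);
System.out.println(java.util.Arrays.toString(dl4jW.shape()));        // [3, 2, 4, 1]
```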
@@ -264,7 +264,8 @@ public class KerasConvolutionUtils {
         } else if (borderMode.equals(conf.getLAYER_BORDER_MODE_VALID()) ||
                 borderMode.equals(conf.getLAYER_BORDER_MODE_FULL())) {
             convolutionMode = ConvolutionMode.Truncate;
+        } else if(borderMode.equals(conf.getLAYER_BORDER_MODE_CAUSAL())) {
+            convolutionMode = ConvolutionMode.Causal;
         } else {
             throw new UnsupportedKerasConfigurationException("Unsupported convolution border mode: " + borderMode);
         }
@@ -23,11 +23,13 @@ import lombok.val;
 import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
 import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
 import org.deeplearning4j.nn.conf.layers.LSTM;
 import org.deeplearning4j.nn.conf.layers.Layer;
 import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
 import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
+import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;

@@ -186,6 +188,9 @@ public class KerasLSTM extends KerasLayer {
                 .biasInit(0.0) // TODO: this is incorrect
                 .l1(this.weightL1Regularization)
                 .l2(this.weightL2Regularization);
+        Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
+        if(nIn != null)
+            builder.setNIn(nIn);
         if (biasConstraint != null)
             builder.constrainBias(biasConstraint);
         if (weightConstraint != null)

@@ -436,6 +441,20 @@ public class KerasLSTM extends KerasLayer {
             log.warn("Attemping to set weights for unknown parameters: "
                     + unknownParamNames.substring(1, unknownParamNames.length() - 1));
         }
+
+
+        FeedForwardLayer ffl;
+        if(this.layer instanceof BaseWrapperLayer){
+            BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
+            ffl = (FeedForwardLayer)bwl.getUnderlying();
+        } else {
+            ffl = (FeedForwardLayer) this.layer;
+        }
+        if(ffl.getNIn() != wRows){
+            //Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config)
+            //We can reliably infer nIn from the shape of the weights array however
+            ffl.setNIn(wRows);
+        }
     }

     /**
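The shape-based repair above works because the input weight array fixes nIn regardless of what the config recorded. A hedged sketch (the [nIn, 4 * units] kernel layout is the usual Keras LSTM convention — one block per gate — and is an assumption here, not taken from this diff):

```java
// Hypothetical illustration: a Keras LSTM kernel of shape [nIn, 4 * units]
// pins down the true nIn as its row count, whether or not the model config
// recorded NWC or NCW input shapes.
INDArray kernel = Nd4j.rand(DataType.FLOAT, 32, 4 * 64); // nIn=32, units=64 (assumed)
long inferredNIn = kernel.rows();                        // 32 -> what ffl.setNIn(wRows) applies
```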
@@ -22,11 +22,13 @@ import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
 import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
 import org.deeplearning4j.nn.conf.layers.Layer;
 import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
 import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
 import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
+import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;

@@ -154,6 +156,9 @@ public class KerasSimpleRnn extends KerasLayer {
                 .biasInit(0.0)
                 .l1(this.weightL1Regularization)
                 .l2(this.weightL2Regularization);
+        Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
+        if(nIn != null)
+            builder.setNIn(nIn);
         if (biasConstraint != null)
             builder.constrainBias(biasConstraint);
         if (weightConstraint != null)

@@ -282,6 +287,19 @@ public class KerasSimpleRnn extends KerasLayer {
             log.warn("Attemping to set weights for unknown parameters: "
                     + unknownParamNames.substring(1, unknownParamNames.length() - 1));
         }
+
+        FeedForwardLayer ffl;
+        if(this.layer instanceof BaseWrapperLayer){
+            BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
+            ffl = (FeedForwardLayer)bwl.getUnderlying();
+        } else {
+            ffl = (FeedForwardLayer) this.layer;
+        }
+        if(ffl.getNIn() != W.rows()){
+            //Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config)
+            //We can reliably infer nIn from the shape of the weights array however
+            ffl.setNIn(W.rows());
+        }
     }

 }
@@ -229,8 +229,8 @@ public class KerasBidirectional extends KerasLayer {
     @Override
     public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {

-        Map<String, INDArray> forwardWeights = getUnderlyingWeights(weights, "forward");
-        Map<String, INDArray> backwardWeights = getUnderlyingWeights(weights, "backward");
+        Map<String, INDArray> forwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getFwd(), weights, "forward");
+        Map<String, INDArray> backwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getBwd(), weights, "backward");

         this.weights = new HashMap<>();

@@ -241,7 +241,7 @@ public class KerasBidirectional extends KerasLayer {
     }


-    private Map<String, INDArray> getUnderlyingWeights(Map<String, INDArray> weights, String direction)
+    private Map<String, INDArray> getUnderlyingWeights(Layer l, Map<String, INDArray> weights, String direction)
             throws InvalidKerasConfigurationException {
         int keras1SubstringLength;
         if (kerasRnnlayer instanceof KerasLSTM)

@@ -270,8 +270,12 @@ public class KerasBidirectional extends KerasLayer {
             weights = newWeights;
         }

+        Layer layerBefore = kerasRnnlayer.getLayer();
+        kerasRnnlayer.setLayer(l);
         kerasRnnlayer.setWeights(weights);
-        return kerasRnnlayer.getWeights();
+        Map<String,INDArray> ret = kerasRnnlayer.getWeights();
+        kerasRnnlayer.setLayer(layerBefore);
+        return ret;
     }

 }
@@ -505,6 +505,17 @@ public class KerasLayerUtils {
         return nOut;
     }

+    public static Integer getNInFromInputDim(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
+        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
+        if(innerConfig.containsKey(conf.getLAYER_FIELD_INPUT_DIM())){
+            Object id = innerConfig.get(conf.getLAYER_FIELD_INPUT_DIM());
+            if(id instanceof Number){
+                return ((Number)id).intValue();
+            }
+        }
+        return null;
+    }
+
     /**
      * Get dropout from Keras layer configuration.
      *
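A hedged usage sketch of the new helper (the "config"/"input_dim" field names below are assumptions about the Keras JSON layout, not taken from this diff):

```java
// Hypothetical Keras-style layer config carrying an explicit input_dim.
Map<String, Object> inner = new HashMap<>();
inner.put("input_dim", 128);                       // assumed field name
Map<String, Object> layerConfig = new HashMap<>();
layerConfig.put("config", inner);                  // assumed nesting
// With a matching KerasLayerConfiguration instance this would return 128;
// configs without the field fall through to null, leaving nIn inference to
// the weight-shape workaround in KerasLSTM/KerasSimpleRnn above.
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
```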
@@ -24,6 +24,8 @@ import org.deeplearning4j.eval.ROCMultiClass;
 import org.deeplearning4j.gradientcheck.GradientCheckUtil;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.layers.IOutputLayer;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
 import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
 import org.deeplearning4j.nn.conf.layers.LossLayer;
 import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
@@ -47,6 +49,8 @@ import org.nd4j.linalg.activations.impl.*;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.function.BiFunction;
+import org.nd4j.linalg.function.Function;
 import org.nd4j.linalg.learning.config.NoOp;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
@@ -58,10 +62,7 @@ import java.io.InputStream;
 import java.net.URL;
 import java.nio.file.Files;
 import java.nio.file.StandardCopyOption;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
+import java.util.*;

 import static org.junit.Assert.*;

@@ -86,6 +87,15 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
     @Rule
     public final TemporaryFolder testDir = new TemporaryFolder();

+    public static final BiFunction<String,INDArray,INDArray> nwc2ncwExpected = new BiFunction<String, INDArray, INDArray>() {
+        @Override
+        public INDArray apply(String s, INDArray array) {
+            if(array.rank() == 3)
+                return array.permute(0, 2, 1); //NWC to NCW
+            return array;
+        }
+    };
+
     @Test(expected = IllegalStateException.class)
     public void fileNotFoundEndToEnd() throws Exception {
         String modelPath = "modelimport/keras/examples/foo/bar.h5";
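A quick illustration of what nwc2ncwExpected does (a sketch with made-up sizes; not part of the commit):

```java
// Rank-3 Keras activations arrive as [minibatch, width, channels] (NWC);
// DL4J produces [minibatch, channels, width] (NCW), so expected arrays are
// permuted before comparison. Lower-rank arrays pass through unchanged.
INDArray kerasAct = Nd4j.rand(DataType.FLOAT, 2, 10, 4);          // NWC, sizes assumed
INDArray expected = nwc2ncwExpected.apply("someLayer", kerasAct); // shape [2, 4, 10] (NCW)
INDArray dense = Nd4j.rand(DataType.FLOAT, 2, 4);
INDArray same = nwc2ncwExpected.apply("someDense", dense);        // rank != 3 -> returned as-is
```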
@@ -154,28 +164,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
     public void importImdbLstmTfKeras1() throws Exception {
         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
+        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
     }

     @Test
     public void importImdbLstmThKeras1() throws Exception {
         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
+        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
     }

     @Test
     public void importImdbLstmTfKeras2() throws Exception {
         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
+        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
     }

     @Test
     public void importImdbLstmThKeras2() throws Exception {
         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, false, true, false, false);
+        importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected);
     }

     /**
@@ -247,7 +257,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
         String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
         String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
                 "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
+        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
     }

     /**
@@ -598,6 +608,122 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
         model.summary();
     }

+    @Test
+    public void testCausalCon1D() throws Exception {
+        String[] names = new String[]{
+                "causal_conv1d_k2_s1_d1_cl_model.h5",
+                "causal_conv1d_k2_s1_d2_cl_model.h5",
+                "causal_conv1d_k2_s2_d1_cl_model.h5",
+                "causal_conv1d_k2_s3_d1_cl_model.h5",
+                "causal_conv1d_k3_s1_d1_cl_model.h5",
+                "causal_conv1d_k3_s1_d2_cl_model.h5",
+                "causal_conv1d_k3_s2_d1_cl_model.h5",
+                "causal_conv1d_k3_s3_d1_cl_model.h5",
+                "causal_conv1d_k4_s1_d1_cl_model.h5",
+                "causal_conv1d_k4_s1_d2_cl_model.h5",
+                "causal_conv1d_k4_s2_d1_cl_model.h5",
+                "causal_conv1d_k4_s3_d1_cl_model.h5"
+        };
+
+        for(String name : names ){
+            System.out.println("Starting test: " + name);
+            String modelPath = "modelimport/keras/examples/causal_conv1d/" + name;
+            String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
+            Function<INDArray,INDArray> f = new Function<INDArray, INDArray>() {
+                @Override
+                public INDArray apply(INDArray i) {
+                    //NWC to NCW
+                    return i.permute(0, 2, 1);
+                }
+            };
+
+            MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
+                    true, true, false, f, nwc2ncwExpected);
+            Layer l = net.getLayer(0);
+            Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig();
+            assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
+        }
+    }
+
+    @Test
+    public void testCon1D() throws Exception {
+        String[] names = new String[]{
+                "conv1d_k2_s1_d1_cf_same_model.h5",
+                "conv1d_k2_s1_d1_cf_valid_model.h5",
+                "conv1d_k2_s1_d1_cl_same_model.h5",
+                "conv1d_k2_s1_d1_cl_valid_model.h5",
+                "conv1d_k2_s1_d2_cf_same_model.h5",
+                "conv1d_k2_s1_d2_cf_valid_model.h5",
+                "conv1d_k2_s1_d2_cl_same_model.h5",
+                "conv1d_k2_s1_d2_cl_valid_model.h5",
+                "conv1d_k2_s2_d1_cf_same_model.h5",
+                "conv1d_k2_s2_d1_cf_valid_model.h5",
+                "conv1d_k2_s2_d1_cl_same_model.h5",
+                "conv1d_k2_s2_d1_cl_valid_model.h5",
+                "conv1d_k2_s3_d1_cf_same_model.h5",
+                "conv1d_k2_s3_d1_cf_valid_model.h5",
+                "conv1d_k2_s3_d1_cl_same_model.h5",
+                "conv1d_k2_s3_d1_cl_valid_model.h5",
+                "conv1d_k3_s1_d1_cf_same_model.h5",
+                "conv1d_k3_s1_d1_cf_valid_model.h5",
+                "conv1d_k3_s1_d1_cl_same_model.h5",
+                "conv1d_k3_s1_d1_cl_valid_model.h5",
+                "conv1d_k3_s1_d2_cf_same_model.h5",
+                "conv1d_k3_s1_d2_cf_valid_model.h5",
+                "conv1d_k3_s1_d2_cl_same_model.h5",
+                "conv1d_k3_s1_d2_cl_valid_model.h5",
+                "conv1d_k3_s2_d1_cf_same_model.h5",
+                "conv1d_k3_s2_d1_cf_valid_model.h5",
+                "conv1d_k3_s2_d1_cl_same_model.h5",
+                "conv1d_k3_s2_d1_cl_valid_model.h5",
+                "conv1d_k3_s3_d1_cf_same_model.h5",
+                "conv1d_k3_s3_d1_cf_valid_model.h5",
+                "conv1d_k3_s3_d1_cl_same_model.h5",
+                "conv1d_k3_s3_d1_cl_valid_model.h5",
+                "conv1d_k4_s1_d1_cf_same_model.h5",
+                "conv1d_k4_s1_d1_cf_valid_model.h5",
+                "conv1d_k4_s1_d1_cl_same_model.h5",
+                "conv1d_k4_s1_d1_cl_valid_model.h5",
+                "conv1d_k4_s1_d2_cf_same_model.h5",
+                "conv1d_k4_s1_d2_cf_valid_model.h5",
+                "conv1d_k4_s1_d2_cl_same_model.h5",
+                "conv1d_k4_s1_d2_cl_valid_model.h5",
+                "conv1d_k4_s2_d1_cf_same_model.h5",
+                "conv1d_k4_s2_d1_cf_valid_model.h5",
+                "conv1d_k4_s2_d1_cl_same_model.h5",
+                "conv1d_k4_s2_d1_cl_valid_model.h5",
+                "conv1d_k4_s3_d1_cf_same_model.h5",
+                "conv1d_k4_s3_d1_cf_valid_model.h5",
+                "conv1d_k4_s3_d1_cl_same_model.h5",
+                "conv1d_k4_s3_d1_cl_valid_model.h5",
+        };
+
+        for(String name : names ){
+            System.out.println("Starting test: " + name);
+            String modelPath = "modelimport/keras/examples/conv1d/" + name;
+            String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
+            Function<INDArray,INDArray> f = name.contains("_cf_") ? null : new Function<INDArray, INDArray>() {
+                @Override
+                public INDArray apply(INDArray i) {
+                    //NWC to NCW
+                    return i.permute(0, 2, 1);
+                }
+            };
+
+            BiFunction<String,INDArray,INDArray> f2 = name.contains("_cf_") ? null : new BiFunction<String, INDArray, INDArray>() {
+                @Override
+                public INDArray apply(String s, INDArray array) {
+//                    if("conv".equals(s)){
+                        return array.permute(0, 2, 1);
+//                    }
+                }
+            };
+
+            importEndModelTest(modelPath, inputsOutputPath, true, true,
+                    true, true, false, f, f2);
+        }
+    }
+
     private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception {
         return importFunctionalModelH5Test(modelPath, null, false);
     }
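The resource names above appear to encode the hyperparameters: k = kernel size, s = stride, d = dilation, and cf/cl = channels_first/channels_last (with same/valid border modes for the non-causal set). A hypothetical decoder, just to document the convention (not part of the commit):

```java
// Hypothetical helper: extract kernel/stride/dilation from a test resource name,
// e.g. "causal_conv1d_k2_s1_d2_cl_model.h5" -> {2, 1, 2}.
static int[] decodeKsd(String name) {
    java.util.regex.Matcher m =
            java.util.regex.Pattern.compile("_k(\\d+)_s(\\d+)_d(\\d+)_").matcher(name);
    if (!m.find()) throw new IllegalArgumentException("Unexpected name: " + name);
    return new int[]{Integer.parseInt(m.group(1)),
                     Integer.parseInt(m.group(2)),
                     Integer.parseInt(m.group(3))};
}
```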
@@ -640,6 +766,12 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {

     public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
                                                 boolean checkGradients, boolean enforceTrainingConfig) throws Exception {
+        return importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, checkGradients, true, enforceTrainingConfig, null, null);
+    }
+
+    public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
+                                                boolean checkGradients, boolean enforceTrainingConfig, boolean checkAuc, Function<INDArray,INDArray> inputPreProc,
+                                                BiFunction<String,INDArray,INDArray> expectedPreProc) throws Exception {
         MultiLayerNetwork model;
         try(InputStream is = Resources.asStream(modelPath)) {
             File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);
@@ -658,20 +790,25 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {

         if (checkPredictions) {
             INDArray input = getInputs(outputsArchive, tfOrdering)[0];
+            if(inputPreProc != null)
+                input = inputPreProc.apply(input);
+
             Map<String, INDArray> activationsKeras = getActivations(outputsArchive, tfOrdering);
             for (int i = 0; i < model.getLayers().length; i++) {
                 String layerName = model.getLayerNames().get(i);
                 if (activationsKeras.containsKey(layerName)) {
                     INDArray activationsDl4j = model.feedForwardToLayer(i, input, false).get(i + 1);
-                    if (activationsDl4j.shape().length == 3)
-                        activationsDl4j = activationsDl4j.permute(0, 2, 1);
-                    compareINDArrays(layerName, activationsKeras.get(layerName), activationsDl4j, EPS);
+                    INDArray exp = activationsKeras.get(layerName);
+                    if(expectedPreProc != null)
+                        exp = expectedPreProc.apply(layerName, exp);
+                    compareINDArrays(layerName, exp, activationsDl4j, EPS);
                 }
             }

             INDArray predictionsKeras = getPredictions(outputsArchive, tfOrdering)[0];
             INDArray predictionsDl4j = model.output(input, false);
+            if(expectedPreProc != null)
+                predictionsKeras = expectedPreProc.apply("output", predictionsKeras);
             compareINDArrays("predictions", predictionsKeras, predictionsDl4j, EPS);
             INDArray outputs = getOutputs(outputsArchive, true)[0];

@@ -680,6 +817,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
         }
         val nOut = (int) outputs.size(-1);

+        if(checkAuc)
             compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS);
     }

@@ -760,20 +898,23 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
         return predictions;
     }

-    private static void compareINDArrays(String label, INDArray a, INDArray b, double eps) {
-        INDArray diff = a.sub(b.castTo(a.dataType()));
+    private static void compareINDArrays(String label, INDArray expected, INDArray actual, double eps) {
+        if(!expected.equalShapes(actual)){
+            throw new IllegalStateException("Shapes do not match for \"" + label + "\": got " + Arrays.toString(expected.shape()) + " vs " + Arrays.toString(actual.shape()));
+        }
+        INDArray diff = expected.sub(actual.castTo(expected.dataType()));
         double min = diff.minNumber().doubleValue();
         double max = diff.maxNumber().doubleValue();
-        log.info(label + ": " + a.equalsWithEps(b, eps) + ", " + min + ", " + max);
+        log.info(label + ": " + expected.equalsWithEps(actual, eps) + ", " + min + ", " + max);
         double threshold = 1e-7;
-        double aAbsMax = Math.max(Math.abs(a.minNumber().doubleValue()), Math.abs(a.maxNumber().doubleValue()));
-        double bAbsMax = Math.max(Math.abs(b.minNumber().doubleValue()), Math.abs(b.maxNumber().doubleValue()));
+        double aAbsMax = Math.max(Math.abs(expected.minNumber().doubleValue()), Math.abs(expected.maxNumber().doubleValue()));
+        double bAbsMax = Math.max(Math.abs(actual.minNumber().doubleValue()), Math.abs(actual.maxNumber().doubleValue()));

         // skip too small absolute inputs
         if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) {
-            assertTrue(a.equalsWithEps(b.castTo(a.dataType()), eps));
+            boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps);
+            assertTrue("Output differs: " + label, eq);
         }
     }

     private static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses,
@@ -69,6 +69,18 @@ package org.deeplearning4j.nn.conf;
  * <br>
  * <br>
  * <br>
+ * <b>Causal</b>: Causal padding mode can only be used for 1D convolutional neural networks.<br>
+ * The motivation behind causal padding mode is that the output time steps depend only on current and past time steps.<br>
+ * That is, out[t] (for time t) depends only on values in[T] for T <= t<br>
+ * The output size of 1D convolution/subsampling layers is the same as with SAME convolution mode -
+ * i.e., outSize = ceil( inputSize / stride )<br>
+ * Padding is also the same as SAME mode, but all padding is on the left (start of sequence) instead of being on both
+ * left and right of the input<br>
+ * For more details on causal convolutions, see <a href="https://arxiv.org/abs/1609.03499">WaveNet: A Generative Model for Raw Audio</a>,
+ * section 2.1.
+ * <br>
+ * <br>
+ * <br>
 * For further information on output sizes for convolutional neural networks, see the "Spatial arrangement" section at
 * <a href="http://cs231n.github.io/convolutional-networks/">http://cs231n.github.io/convolutional-networks/</a>
 *

@@ -76,6 +88,6 @@ package org.deeplearning4j.nn.conf;
 */
 public enum ConvolutionMode {

-    Strict, Truncate, Same
+    Strict, Truncate, Same, Causal

 }
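A worked example of the Causal sizing rule above, using the same Convolution1DUtils.getOutputSize call that the new gradient-check test uses (the concrete numbers are illustrative):

```java
// inputSize=10, k=3, stride=2, dilation=2:
//   effective kernel = (3 - 1) * 2 + 1 = 5, implicit left padding = (3 - 1) * 2 = 4
//   outSize = ceil(10 / 2) = 5, and out[t] depends only on in[0..t]
int length = 10, k = 3, stride = 2, dilation = 2;
long out = Convolution1DUtils.getOutputSize(length, k, stride, 0, ConvolutionMode.Causal, dilation);
System.out.println(out); // 5
```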
@@ -124,6 +124,11 @@ public class Convolution1DLayer extends ConvolutionLayer {
             this.setKernelSize((int[]) null);
         }

+        @Override
+        protected boolean allowCausal() {
+            return true;
+        }
+
         /**
          * @param kernelSize Kernel size
          * @param stride Stride
@@ -163,6 +163,12 @@ public class Convolution3D extends ConvolutionLayer {
             super(new int[] {2, 2, 2}, new int[] {1, 1, 1}, new int[] {0, 0, 0}, new int[] {1, 1, 1}, 3);
         }

+        @Override
+        protected boolean allowCausal() {
+            //Causal convolution - allowed for 1D only
+            return false;
+        }
+
         public Builder(int[] kernelSize, int[] stride, int[] padding, int[] dilation) {
             super(kernelSize, stride, padding, dilation, 3);
         }
@@ -30,6 +30,7 @@ import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
 import org.deeplearning4j.optimize.api.TrainingListener;
 import org.deeplearning4j.util.ConvolutionUtils;
 import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -283,6 +284,12 @@ public class ConvolutionLayer extends FeedForwardLayer {
             super();
         }

+        @Override
+        protected boolean allowCausal() {
+            //Causal convolution - allowed for 1D only
+            return false;
+        }
+
         /**
          * Size of the convolution rows/columns
          *

@@ -456,6 +463,14 @@ public class ConvolutionLayer extends FeedForwardLayer {

         protected BaseConvBuilder() {}

+        protected abstract boolean allowCausal();
+
+        protected void setConvolutionMode(ConvolutionMode convolutionMode){
+            Preconditions.checkState(allowCausal() || convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" +
+                    " convolutional neural network layers");
+            this.convolutionMode = convolutionMode;
+        }
+
         /**
          * If true (default): include bias parameters in the model. False: no bias.
          *
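The effect of the new guard, mirroring the ConvolutionLayerTest cases added above (Preconditions.checkState is expected to throw IllegalStateException on violation):

```java
// 1D layers accept causal mode:
new Convolution1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build(); // ok

// 2D/3D layers reject it at configuration time, before any network is built:
try {
    new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
} catch (IllegalStateException e) {
    // "Causal convolution mode can only be used with 1D convolutional neural network layers"
}
```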
@@ -133,6 +133,12 @@ public class Deconvolution2D extends ConvolutionLayer {
             super();
         }

+        @Override
+        protected boolean allowCausal() {
+            //Causal convolution - allowed for 1D only
+            return false;
+        }
+
         /**
          * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details
          *
@@ -133,6 +133,12 @@ public class DepthwiseConvolution2D extends ConvolutionLayer {
             super();
         }

+        @Override
+        protected boolean allowCausal() {
+            //Causal convolution - allowed for 1D only
+            return false;
+        }
+
         /**
          * Set channels multiplier for depth-wise convolution
          *
@@ -184,6 +184,12 @@ public class SeparableConvolution2D extends ConvolutionLayer {
             super();
         }

+        @Override
+        protected boolean allowCausal() {
+            //Causal convolution - allowed for 1D only
+            return false;
+        }
+
         /**
          * Set channels multiplier of channels-wise step in separable convolution
          *
@@ -167,6 +167,11 @@ public class Subsampling1DLayer extends SubsamplingLayer {
             this(poolingType, DEFAULT_KERNEL, DEFAULT_STRIDE, DEFAULT_PADDING);
         }

+        @Override
+        protected boolean allowCausal() {
+            return true;
+        }
+
         public Builder() {
             this(DEFAULT_POOLING, DEFAULT_KERNEL, DEFAULT_STRIDE, DEFAULT_PADDING);
         }
@@ -431,6 +431,12 @@ public class Subsampling3DLayer extends NoParamLayer {
             this.setPoolingType(poolingType);
         }

+        protected void setConvolutionMode(ConvolutionMode convolutionMode){
+            Preconditions.checkState(convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" +
+                    " convolutional neural network layers");
+            this.convolutionMode = convolutionMode;
+        }
+
         /**
          * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details
          *
@@ -28,6 +28,7 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer;
 import org.deeplearning4j.optimize.api.TrainingListener;
 import org.deeplearning4j.util.ConvolutionUtils;
 import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -270,6 +271,12 @@ public class SubsamplingLayer extends NoParamLayer {
             super(poolingType);
         }

+        @Override
+        protected boolean allowCausal() {
+            //Only conv1d/subsampling1d can use causal mode
+            return false;
+        }
+
         /**
          * Kernel size
          *

@@ -449,6 +456,14 @@ public class SubsamplingLayer extends NoParamLayer {
             this.eps = eps;
         }

+        protected abstract boolean allowCausal();
+
+        public void setConvolutionMode(ConvolutionMode convolutionMode){
+            Preconditions.checkState(allowCausal() || convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" +
+                    " convolutional neural network layers");
+            this.convolutionMode = convolutionMode;
+        }
+
         /**
          * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details
          *
@@ -18,18 +18,30 @@ package org.deeplearning4j.nn.layers.convolution;
 
 import org.deeplearning4j.exception.DL4JInvalidInputException;
 import org.deeplearning4j.nn.api.MaskState;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.gradient.DefaultGradient;
 import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
+import org.deeplearning4j.util.Convolution1DUtils;
 import org.deeplearning4j.util.ConvolutionUtils;
 import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.activations.IActivation;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
+import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.factory.Broadcast;
+import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.Pair;
 import org.deeplearning4j.nn.workspace.ArrayType;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 
 import java.util.Arrays;
+import java.util.List;
 
 /**
  * 1D (temporal) convolutional layer. Currently, we just subclass off the
@@ -70,6 +82,52 @@ public class Convolution1DLayer extends ConvolutionLayer {
             Broadcast.mul(epsilon, maskOut, epsilon, 0, 2);
         }
 
+        if(layerConf().getConvolutionMode() == ConvolutionMode.Causal){
+            Pair<INDArray,INDArray> fwd = causalConv1dForward();
+            IActivation afn = layerConf().getActivationFn();
+            INDArray delta = afn.backprop(fwd.getFirst(), epsilon).getFirst(); //TODO handle activation function params
+
+            //TODO eventually we'll use this for all convolution modes - but only after libnd4j has cuDNN support
+            org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf();
+            Conv1DConfig conf = Conv1DConfig.builder()
+                    .k(c.getKernelSize()[0])
+                    .s(c.getStride()[0])
+                    .d(c.getDilation()[0])
+                    .p(c.getPadding()[0])
+                    .dataFormat(Conv1DConfig.NCW)
+                    .paddingMode(PaddingMode.CAUSAL)
+                    .build();
+
+            INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY);
+            w = w.reshape(w.ordering(), w.size(0), w.size(1), w.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] -> [k, iC, oC]
+
+            INDArray[] inputArrs;
+            INDArray[] outputArrs;
+            INDArray wg = gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY);
+            wg = wg.reshape(wg.ordering(), wg.size(0), wg.size(1), wg.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] -> [k, iC, oC]
+            INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
+            if(layerConf().hasBias()){
+                INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY);
+                b = b.reshape(b.length());
+                inputArrs = new INDArray[]{input.castTo(w.dataType()), w, b, delta};
+                INDArray bg = gradientViews.get(ConvolutionParamInitializer.BIAS_KEY);
+                bg = bg.reshape(bg.length());
+                outputArrs = new INDArray[]{epsOut, wg, bg};
+            } else {
+                inputArrs = new INDArray[]{input.castTo(w.dataType()), w, delta};
+                outputArrs = new INDArray[]{epsOut, wg};
+            }
+            Conv1DDerivative op = new Conv1DDerivative(inputArrs, outputArrs, conf);
+            Nd4j.exec(op);
+
+            Gradient retGradient = new DefaultGradient();
+            if(layerConf().hasBias()){
+                retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, gradientViews.get(ConvolutionParamInitializer.BIAS_KEY));
+            }
+            retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY), 'c');
+            return new Pair<>(retGradient, epsOut);
+        }
+
         // add singleton fourth dimension to input and next layer's epsilon
         epsilon = epsilon.reshape(epsilon.size(0), epsilon.size(1), epsilon.size(2), 1);
         INDArray origInput = input;
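One detail worth calling out in the hunk above: DL4J stores convolution weights as [oC, iC, kH, kW] (here [oC, iC, k, 1]), while the libnd4j conv1d ops expect [k, iC, oC]. The reshape-plus-permute produces a view rather than a copy, which is why Conv1DDerivative can write straight into the reshaped gradient-view arrays. A minimal shape sketch, assuming only ND4J on the classpath (the concrete sizes are hypothetical):

import java.util.Arrays;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class WeightLayoutSketch {
    public static void main(String[] args) {
        // [oC=4, iC=3, k=2, 1] -> drop the trailing singleton, then reorder to [k, iC, oC]
        INDArray w = Nd4j.rand(DataType.FLOAT, 4, 3, 2, 1);
        INDArray w1d = w.reshape(w.ordering(), 4, 3, 2).permute(2, 1, 0);
        System.out.println(Arrays.toString(w1d.shape())); // [2, 3, 4]
    }
}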
@@ -98,6 +156,12 @@ public class Convolution1DLayer extends ConvolutionLayer {
     @Override
     protected Pair<INDArray,INDArray> preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
         assertInputSet(false);
 
+        if(layerConf().getConvolutionMode() == ConvolutionMode.Causal){
+            return causalConv1dForward();
+        }
+
         INDArray origInput = input;
         input = input.reshape(input.size(0), input.size(1), input.size(2), 1);
@@ -113,6 +177,36 @@ public class Convolution1DLayer extends ConvolutionLayer {
         return preOutput;
     }
 
+    protected Pair<INDArray,INDArray> causalConv1dForward(){
+        //TODO eventually we'll use this for all convolution modes - but only after libnd4j has cuDNN support
+        org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf();
+        Conv1DConfig conf = Conv1DConfig.builder()
+                .k(c.getKernelSize()[0])
+                .s(c.getStride()[0])
+                .d(c.getDilation()[0])
+                .p(c.getPadding()[0])
+                .dataFormat(Conv1DConfig.NCW)
+                .paddingMode(PaddingMode.CAUSAL)
+                .build();
+        INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY);
+        w = w.reshape(w.ordering(), w.size(0), w.size(1), w.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] -> [k, iC, oC]
+
+        INDArray[] inputs;
+        if(layerConf().hasBias()){
+            INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY);
+            b = b.reshape(b.length());
+            inputs = new INDArray[]{input.castTo(w.dataType()), w, b};
+        } else {
+            inputs = new INDArray[]{input.castTo(w.dataType()), w};
+        }
+
+        Conv1D op = new Conv1D(inputs, null, conf);
+        List<LongShapeDescriptor> outShape = op.calculateOutputShape();
+        op.setOutputArgument(0, Nd4j.create(outShape.get(0), false));
+        Nd4j.exec(op);
+        return new Pair<>(op.getOutputArgument(0), null);
+    }
+
     @Override
     public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
         INDArray act4d = super.activate(training, workspaceMgr);
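PaddingMode.CAUSAL itself is implemented inside libnd4j; as a reference for the semantics it encodes, here is a hedged plain-Java sketch (not the library implementation) of a stride-1, single-channel causal convolution: the input is left-padded with (k - 1) * d zeros, so each output step depends only on current and past inputs and the output length equals the input length.

public class CausalConvSketch {
    // Stride-1, single-channel causal convolution (cross-correlation form).
    static double[] causalConv1d(double[] x, double[] kernel, int dilation) {
        int k = kernel.length;
        int pad = (k - 1) * dilation;             // left padding only: no future leakage
        double[] out = new double[x.length];      // output length == input length for stride 1
        for (int t = 0; t < x.length; t++) {
            double sum = 0;
            for (int i = 0; i < k; i++) {
                int idx = t - pad + i * dilation; // max idx is t, i.e. strictly non-anticipative
                if (idx >= 0) {                   // idx < 0 falls into the implicit zero padding
                    sum += kernel[i] * x[idx];
                }
            }
            out[t] = sum;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] y = causalConv1d(new double[]{1, 2, 3, 4}, new double[]{0.5, 0.5}, 1);
        System.out.println(java.util.Arrays.toString(y)); // [0.5, 1.5, 2.5, 3.5]
    }
}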
@@ -66,7 +66,7 @@ public class Convolution1DUtils {
     public static long getOutputSize(long inH, int kernel, int strides, int padding,
                                      ConvolutionMode convolutionMode, int dilation) {
         long eKernel = effectiveKernelSize(kernel, dilation);
-        if (convolutionMode == ConvolutionMode.Same) {
+        if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) {
            return (int) Math.ceil(inH / ((double) strides));
         }
         return (inH - eKernel + 2 * padding) / strides + 1;
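The one-line change above means Causal mode shares the Same-mode output length, out = ceil(inH / stride), independent of kernel size and dilation. A quick worked example in plain Java, assuming effectiveKernelSize computes the usual eKernel = k + (k - 1) * (d - 1):

public class OutputSizeSketch {
    public static void main(String[] args) {
        int inH = 13, k = 3, stride = 2, dilation = 2, padding = 0;
        int eKernel = k + (k - 1) * (dilation - 1);              // effective kernel = 5
        long sameOrCausal = (long) Math.ceil(inH / (double) stride); // 7, kernel-independent
        long truncate = (inH - eKernel + 2 * padding) / stride + 1;  // (13 - 5) / 2 + 1 = 5
        System.out.println(sameOrCausal + " vs " + truncate);        // 7 vs 5
    }
}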
@@ -92,7 +92,7 @@ public class Convolution1DUtils {
         boolean atrous = (eKernel == kernel);
         validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inH, atrous);
 
-        if (convolutionMode == ConvolutionMode.Same) {
+        if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) {
             int outH = (int) Math.ceil(inH / ((double) strides));
             return outH;
         }
|
@ -106,8 +106,9 @@ public class Convolution1DUtils {
|
||||||
boolean atrous) {
|
boolean atrous) {
|
||||||
|
|
||||||
int inH = inShape;
|
int inH = inShape;
|
||||||
|
boolean t = convolutionMode == ConvolutionMode.Truncate;
|
||||||
|
|
||||||
if (convolutionMode != ConvolutionMode.Same && (eKernel <= 0 || eKernel > inH + 2 * padding)) {
|
if (t && (eKernel <= 0 || eKernel > inH + 2 * padding)) {
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
sb.append("Invalid input data or configuration: ");
|
sb.append("Invalid input data or configuration: ");
|
||||||
if (atrous) sb.append("effective ");
|
if (atrous) sb.append("effective ");
|
||||||
|
|
|
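The Truncate-only gating here (and in the ConvolutionUtils hunks below) reflects that the "effective kernel larger than padded input" check is only meaningful when no implicit padding is added. Under Same or Causal mode the padding is derived from the kernel itself, so an effective kernel longer than the sequence (for example seqLen = 4, k = 5, d = 2, giving eKernel = 9 > 4) is still valid and simply yields ceil(4 / stride) outputs; with the old condition it would have been rejected.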
@@ -121,7 +121,7 @@ public class ConvolutionUtils {
         int[] inShape = new int[]{inH, inW};
         validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inShape, atrous);
 
-        if (convolutionMode == ConvolutionMode.Same) {
+        if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) {
 
             int outH = (int) Math.ceil(inH / ((double) strides[0]));
             int outW = (int) Math.ceil(inW / ((double) strides[1]));
@@ -142,7 +142,9 @@ public class ConvolutionUtils {
         int inH = inShape[0];
         int inW = inShape[1];
 
-        if (convolutionMode != ConvolutionMode.Same && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) {
+        boolean t = (convolutionMode == ConvolutionMode.Truncate);
+
+        if (t && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) {
             StringBuilder sb = new StringBuilder();
             sb.append("Invalid input data or configuration: ");
             if (atrous) sb.append("effective ");
@@ -158,7 +160,7 @@ public class ConvolutionUtils {
             throw new DL4JInvalidInputException(sb.toString());
         }
 
-        if (convolutionMode != ConvolutionMode.Same && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) {
+        if (t && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) {
             StringBuilder sb = new StringBuilder();
             sb.append("Invalid input data or configuration: ");
             if (atrous) sb.append("effective ");
@@ -175,8 +177,7 @@ public class ConvolutionUtils {
             throw new DL4JInvalidInputException(sb.toString());
         }
 
-        if (eKernel.length == 3 && convolutionMode != ConvolutionMode.Same
-                && (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) {
+        if (eKernel.length == 3 && t && (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) {
             int inD = inShape[2];
             StringBuilder sb = new StringBuilder();
             sb.append("Invalid input data or configuration: ");
@@ -615,7 +616,7 @@ public class ConvolutionUtils {
      */
     public static INDArray cnn1dMaskReduction(INDArray in, int kernel, int stride, int padding, int dilation, ConvolutionMode cm){
         Preconditions.checkState(in.rank()==2, "Rank must be 2 for cnn1d mask array - shape ", in.shape());
-        if(cm == ConvolutionMode.Same && stride == 1 ){
+        if((cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) && stride == 1 ){
             return in;
         }
 
@@ -630,7 +631,7 @@ public class ConvolutionUtils {
         int[] k = new int[]{kernel,1};
         int[] s = new int[]{stride, 1};
         int[] d = new int[]{dilation, 1};
-        if (cm == ConvolutionMode.Same) {
+        if (cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) {
             outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, null, cm, d); //Also performs validation
         } else {
             pad = new int[]{padding, 0};
@@ -645,7 +646,7 @@ public class ConvolutionUtils {
                 .sH(s[0]).sW(s[1])
                 .pH(pad == null ? 0 : pad[0]).pW(pad == null ? 0 : pad[1])
                 .dH(d[0]).dW(d[1])
-                .isSameMode(cm== ConvolutionMode.Same)
+                .isSameMode(cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal)
                 .isNHWC(false)
                 .build());
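To round out the mask handling above, a hedged usage sketch of cnn1dMaskReduction under Causal mode; it assumes a recent ND4J version with Nd4j.createFromArray for 2D primitive arrays, and the mask values are hypothetical:

import java.util.Arrays;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class MaskReductionSketch {
    public static void main(String[] args) {
        // [minibatch=1, seqLen=6] mask: the last two steps are padding
        INDArray mask = Nd4j.createFromArray(new float[][]{{1, 1, 1, 1, 0, 0}});
        // stride 1: Same/Causal now short-circuits and returns the mask unchanged
        INDArray same = ConvolutionUtils.cnn1dMaskReduction(mask, 3, 1, 0, 1, ConvolutionMode.Causal);
        // stride 2: the mask is pooled down to ceil(6/2) = 3 steps, matching the activations
        INDArray reduced = ConvolutionUtils.cnn1dMaskReduction(mask, 3, 2, 0, 1, ConvolutionMode.Causal);
        System.out.println(Arrays.toString(same.shape()) + " / " + Arrays.toString(reduced.shape())); // [1, 6] / [1, 3]
    }
}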