DL4J + Keras import: Causal Conv1D support (#107)
* Keras causal conv1d support first steps
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Add tests
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Causal conv mode
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Gradient check and fixes for causal conv1d
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix Conv1D import and testing
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Cleanup
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small keras test fix
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Don't allow setting causal convolution mode to conv2d/3d layers
  Signed-off-by: Alex Black <blacka101@gmail.com>
* More robustly infer nIn for recurrent layers for ambiguous NCW and NWC cases
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Polish and cleanup
  Signed-off-by: Alex Black <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									578a5abb68
								
							
						
					
					
						commit
						9cc8803b8d
					
				| @ -27,6 +27,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType; | ||||
| import org.deeplearning4j.nn.conf.layers.*; | ||||
| import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; | ||||
| import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | ||||
| import org.deeplearning4j.util.Convolution1DUtils; | ||||
| import org.deeplearning4j.util.ConvolutionUtils; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.activations.Activation; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| @ -442,4 +444,76 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Gradient check for a network of two stacked causal {@link Convolution1DLayer}s followed by an | ||||
|      * {@link RnnOutputLayer}, run over several combinations of sequence length, kernel size, dilation, | ||||
|      * stride, feature masking and bias usage. Each network is also passed to | ||||
|      * {@code TestUtils.testModelSerialization} after the gradient check. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testCnn1Causal() { | ||||
|         int convNIn = 2; | ||||
|         int convNOut1 = 3; | ||||
|         int convNOut2 = 4; | ||||
|         int finalNOut = 3; | ||||
| 
 | ||||
|         //Test cases are read column-wise: index i of every array describes one configuration | ||||
|         int[] lengths   =  {11, 12, 13,  9, 10, 11}; | ||||
|         int[] kernels   =  {2,   3,  2,  4,  2,  3}; | ||||
|         int[] dilations =  {1,   1,  2,  1,  2,  1}; | ||||
|         int[] strides   =  {1,   2,  1,  2,  1,  1}; | ||||
|         boolean[] masks =  {false, true, false, true, false, true}; | ||||
|         boolean[] hasB  =  {true, false, true, false, true, true}; | ||||
| 
 | ||||
|         for (int i = 0; i < lengths.length; i++) { | ||||
|             int length = lengths[i]; | ||||
|             int k = kernels[i]; | ||||
|             int d = dilations[i]; | ||||
|             int st = strides[i]; | ||||
|             boolean mask = masks[i]; | ||||
|             boolean hasBias = hasB[i]; | ||||
|             //Note: hasBias is only applied to the first convolution layer below | ||||
|             String s = "k=" + k + ", s=" + st + ", d=" + d + ", seqLen=" + length; | ||||
|             log.info("Starting test: " + s); | ||||
|             Nd4j.getRandom().setSeed(12345); | ||||
| 
 | ||||
|             MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | ||||
|                     .dataType(DataType.DOUBLE) | ||||
|                     .updater(new NoOp()) | ||||
|                     .activation(Activation.TANH) | ||||
|                     .weightInit(new NormalDistribution(0, 1)) | ||||
|                     .seed(12345) | ||||
|                     .list() | ||||
|                     .layer(new Convolution1DLayer.Builder().kernelSize(k) | ||||
|                             .dilation(d) | ||||
|                             .hasBias(hasBias) | ||||
|                             .convolutionMode(ConvolutionMode.Causal) | ||||
|                             .stride(st).nIn(convNIn).nOut(convNOut1) | ||||
|                             .build()) | ||||
|                     .layer(new Convolution1DLayer.Builder().kernelSize(k) | ||||
|                             .dilation(d) | ||||
|                             .convolutionMode(ConvolutionMode.Causal) | ||||
|                             .stride(st).nIn(convNOut1).nOut(convNOut2) | ||||
|                             .build()) | ||||
|                     .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) | ||||
|                             .activation(Activation.SOFTMAX).nOut(finalNOut).build()) | ||||
|                     .setInputType(InputType.recurrent(convNIn, length)).build(); | ||||
| 
 | ||||
|             MultiLayerNetwork net = new MultiLayerNetwork(conf); | ||||
|             net.init(); | ||||
| 
 | ||||
|             //Features: [minibatch=2, convNIn, length] | ||||
|             INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length); | ||||
|             INDArray fm = null; | ||||
|             if (mask) { | ||||
|                 //Feature mask: example 0 has every step set to 1, example 1 only the first (length-2) steps | ||||
|                 fm = Nd4j.create(2, length); | ||||
|                 fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1); | ||||
|                 fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length-2)).assign(1); | ||||
|             } | ||||
| 
 | ||||
|             //Output length after each of the two causal convolutions; the second value sizes the labels | ||||
|             long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d); | ||||
|             long outSize2 = Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d); | ||||
| 
 | ||||
|             INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2); | ||||
| 
 | ||||
|             boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, | ||||
|                     DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, label, fm, null); | ||||
| 
 | ||||
|             assertTrue(s, gradOK); | ||||
|             TestUtils.testModelSerialization(net); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -712,4 +712,73 @@ public class ConvolutionLayerTest extends BaseDL4JTest { | ||||
|             assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels")); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Building the 1D convolution and 1D subsampling layers with {@link ConvolutionMode#Causal} | ||||
|      * must complete without throwing. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testConv1dCausalAllowed(){ | ||||
|         new Convolution1DLayer.Builder() | ||||
|                 .convolutionMode(ConvolutionMode.Causal) | ||||
|                 .kernelSize(2) | ||||
|                 .build(); | ||||
|         new Subsampling1DLayer.Builder() | ||||
|                 .convolutionMode(ConvolutionMode.Causal) | ||||
|                 .kernelSize(2) | ||||
|                 .build(); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testConv2dNoCausalAllowed(){ | ||||
| 
 | ||||
|         try{ | ||||
|             new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Causal).build(); | ||||
|             fail("Expected exception"); | ||||
|         } catch (Throwable t){ | ||||
|             String m = t.getMessage().toLowerCase(); | ||||
|             assertTrue(m, m.contains("causal") && m.contains("1d")); | ||||
|         } | ||||
| 
 | ||||
|         try{ | ||||
|             new Deconvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build(); | ||||
|             fail("Expected exception"); | ||||
|         } catch (Throwable t){ | ||||
|             String m = t.getMessage().toLowerCase(); | ||||
|             assertTrue(m, m.contains("causal") && m.contains("1d")); | ||||
|         } | ||||
| 
 | ||||
|         try{ | ||||
|             new DepthwiseConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build(); | ||||
|             fail("Expected exception"); | ||||
|         } catch (Throwable t){ | ||||
|             String m = t.getMessage().toLowerCase(); | ||||
|             assertTrue(m, m.contains("causal") && m.contains("1d")); | ||||
|         } | ||||
| 
 | ||||
|         try{ | ||||
|             new SeparableConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build(); | ||||
|             fail("Expected exception"); | ||||
|         } catch (Throwable t){ | ||||
|             String m = t.getMessage().toLowerCase(); | ||||
|             assertTrue(m, m.contains("causal") && m.contains("1d")); | ||||
|         } | ||||
| 
 | ||||
|         try{ | ||||
|             new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Causal).build(); | ||||
|             fail("Expected exception"); | ||||
|         } catch (Throwable t){ | ||||
|             String m = t.getMessage().toLowerCase(); | ||||
|             assertTrue(m, m.contains("causal") && m.contains("1d")); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Verifies that the 3D convolution and 3D subsampling layer builders reject | ||||
|      * {@link ConvolutionMode#Causal}, and that the error message mentions both "causal" and "1d". | ||||
|      * <p> | ||||
|      * The {@code AssertionError} raised by {@code fail(...)} is rethrown unchanged so that a missing | ||||
|      * exception is reported as "Expected exception" rather than being caught by the generic handler. | ||||
|      * A null exception message is converted to the string "null" so the message check fails cleanly | ||||
|      * instead of throwing a NullPointerException. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testConv3dNoCausalAllowed(){ | ||||
|         try{ | ||||
|             new Convolution3D.Builder().convolutionMode(ConvolutionMode.Causal).build(); | ||||
|             fail("Expected exception"); | ||||
|         } catch (AssertionError e){ | ||||
|             throw e;    //From fail(): no exception was thrown by the builder | ||||
|         } catch (Throwable t){ | ||||
|             String m = String.valueOf(t.getMessage()).toLowerCase(); | ||||
|             assertTrue(m, m.contains("causal") && m.contains("1d")); | ||||
|         } | ||||
| 
 | ||||
|         try{ | ||||
|             new Subsampling3DLayer.Builder().convolutionMode(ConvolutionMode.Causal).build(); | ||||
|             fail("Expected exception"); | ||||
|         } catch (AssertionError e){ | ||||
|             throw e;    //From fail(): no exception was thrown by the builder | ||||
|         } catch (Throwable t){ | ||||
|             String m = String.valueOf(t.getMessage()).toLowerCase(); | ||||
|             assertTrue(m, m.contains("causal") && m.contains("1d")); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -356,6 +356,10 @@ public class KerasLayer { | ||||
|         return this.layer; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Replace the DL4J layer held by this Keras layer. | ||||
|      * | ||||
|      * @param layer layer to store in this object's {@code layer} field | ||||
|      */ | ||||
|     public void setLayer(Layer layer) { | ||||
|         this.layer = layer; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Whether this Keras layer maps to a DL4J Vertex. | ||||
|      * | ||||
|  | ||||
| @ -233,6 +233,7 @@ public class KerasLayerConfiguration { | ||||
|     private final String LAYER_BORDER_MODE_SAME = "same"; | ||||
|     private final String LAYER_BORDER_MODE_VALID = "valid"; | ||||
|     private final String LAYER_BORDER_MODE_FULL = "full"; | ||||
|     private final String LAYER_BORDER_MODE_CAUSAL = "causal"; | ||||
| 
 | ||||
|     /* Noise layers */ | ||||
|     private final String LAYER_FIELD_RATE = "rate"; | ||||
|  | ||||
| @ -124,7 +124,26 @@ public class KerasInput extends KerasLayer { | ||||
|                 myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]); | ||||
|                 break; | ||||
|             case 2: | ||||
|                 myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]); | ||||
|                 if(this.dimOrder != null) { | ||||
|                     switch (this.dimOrder) { | ||||
|                         case TENSORFLOW:    //NWC == channels_last | ||||
|                             myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]); | ||||
|                             break; | ||||
|                         case THEANO:        //NCW == channels_first | ||||
|                             myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); | ||||
|                             break; | ||||
|                         case NONE: | ||||
|                             //Assume RNN in [mb, seqLen, size] format | ||||
|                             myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); | ||||
|                             break; | ||||
|                         default: | ||||
|                             throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder); | ||||
|                     } | ||||
|                 } else { | ||||
|                     //Assume RNN in [mb, seqLen, size] format | ||||
|                     myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); | ||||
|                 } | ||||
| 
 | ||||
|                 break; | ||||
|             case 3: | ||||
|                 switch (this.dimOrder) { | ||||
|  | ||||
| @ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.RnnLossLayer; | ||||
| import org.deeplearning4j.nn.modelimport.keras.KerasLayer; | ||||
| import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; | ||||
| import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; | ||||
| import org.nd4j.linalg.activations.Activation; | ||||
| import org.nd4j.linalg.lossfunctions.LossFunctions; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| @ -96,13 +97,13 @@ public class KerasLoss extends KerasLayer { | ||||
|      */ | ||||
|     public FeedForwardLayer getLossLayer(InputType type) throws UnsupportedKerasConfigurationException { | ||||
|         if (type instanceof InputType.InputTypeFeedForward) { | ||||
|             this.layer = new LossLayer.Builder(loss).name(this.layerName).build(); | ||||
|             this.layer = new LossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build(); | ||||
|         } | ||||
|         else if (type instanceof  InputType.InputTypeRecurrent) { | ||||
|             this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).build(); | ||||
|             this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build(); | ||||
|         } | ||||
|         else if (type instanceof InputType.InputTypeConvolutional) { | ||||
|             this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).build(); | ||||
|             this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build(); | ||||
|         } else { | ||||
|             throw new UnsupportedKerasConfigurationException("Unsupported output layer type" | ||||
|                     + "got : " + type.toString()); | ||||
|  | ||||
| @ -79,7 +79,6 @@ abstract public class KerasConvolution extends KerasLayer { | ||||
|     public KerasConvolution(Map<String, Object> layerConfig, boolean enforceTrainingConfig) | ||||
|             throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { | ||||
|         super(layerConfig, enforceTrainingConfig); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|  | ||||
| @ -185,18 +185,11 @@ public class KerasConvolution1D extends KerasConvolution { | ||||
|                     break; | ||||
| 
 | ||||
|                 case THEANO: | ||||
|                     paramValue = kerasParamValue.permute(2, 1, 0); | ||||
|                     paramValue = paramValue.reshape( | ||||
|                             paramValue.size(0), paramValue.size(1), | ||||
|                             paramValue.size(2), 1).dup(); | ||||
|                     for (int i = 0; i < paramValue.tensorsAlongDimension(2, 3); i++) { | ||||
|                         INDArray copyFilter = paramValue.tensorAlongDimension(i, 2, 3).dup(); | ||||
|                         double[] flattenedFilter = copyFilter.ravel().data().asDouble(); | ||||
|                         ArrayUtils.reverse(flattenedFilter); | ||||
|                         INDArray newFilter = Nd4j.create(flattenedFilter, copyFilter.shape()); | ||||
|                         INDArray inPlaceFilter = paramValue.tensorAlongDimension(i, 2, 3); | ||||
|                         inPlaceFilter.muli(0).addi(newFilter.castTo(inPlaceFilter.dataType())); | ||||
|                     } | ||||
|                     //Convert from keras [k,nIn,nOut] to DL4J conv2d [nOut, nIn, k, 1] | ||||
|                     long k = kerasParamValue.size(0); | ||||
|                     long nIn = kerasParamValue.size(1); | ||||
|                     long nOut = kerasParamValue.size(2); | ||||
|                     paramValue = kerasParamValue.permute(2, 1, 0).dup('c').reshape(nOut, nIn, k, 1); | ||||
|                     break; | ||||
|                 default: | ||||
|                     throw new InvalidKerasConfigurationException("Unknown keras backend " + this.getDimOrder()); | ||||
|  | ||||
| @ -264,7 +264,8 @@ public class KerasConvolutionUtils { | ||||
|         } else if (borderMode.equals(conf.getLAYER_BORDER_MODE_VALID()) || | ||||
|                 borderMode.equals(conf.getLAYER_BORDER_MODE_FULL())) { | ||||
|             convolutionMode = ConvolutionMode.Truncate; | ||||
| 
 | ||||
|         } else if(borderMode.equals(conf.getLAYER_BORDER_MODE_CAUSAL())) { | ||||
|             convolutionMode = ConvolutionMode.Causal; | ||||
|         } else { | ||||
|             throw new UnsupportedKerasConfigurationException("Unsupported convolution border mode: " + borderMode); | ||||
|         } | ||||
|  | ||||
| @ -23,11 +23,13 @@ import lombok.val; | ||||
| import org.deeplearning4j.nn.api.layers.LayerConstraint; | ||||
| import org.deeplearning4j.nn.conf.InputPreProcessor; | ||||
| import org.deeplearning4j.nn.conf.inputs.InputType; | ||||
| import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; | ||||
| import org.deeplearning4j.nn.conf.layers.InputTypeUtil; | ||||
| import org.deeplearning4j.nn.conf.layers.LSTM; | ||||
| import org.deeplearning4j.nn.conf.layers.Layer; | ||||
| import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; | ||||
| import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; | ||||
| import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; | ||||
| import org.deeplearning4j.nn.modelimport.keras.KerasLayer; | ||||
| import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; | ||||
| import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; | ||||
| @ -186,6 +188,9 @@ public class KerasLSTM extends KerasLayer { | ||||
|                 .biasInit(0.0) // TODO: this is incorrect | ||||
|                 .l1(this.weightL1Regularization) | ||||
|                 .l2(this.weightL2Regularization); | ||||
|         Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf); | ||||
|         if(nIn != null) | ||||
|             builder.setNIn(nIn); | ||||
|         if (biasConstraint != null) | ||||
|             builder.constrainBias(biasConstraint); | ||||
|         if (weightConstraint != null) | ||||
| @ -436,6 +441,20 @@ public class KerasLSTM extends KerasLayer { | ||||
|             log.warn("Attemping to set weights for unknown parameters: " | ||||
|                     + unknownParamNames.substring(1, unknownParamNames.length() - 1)); | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         FeedForwardLayer ffl; | ||||
|         if(this.layer instanceof BaseWrapperLayer){ | ||||
|             BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer; | ||||
|             ffl = (FeedForwardLayer)bwl.getUnderlying(); | ||||
|         } else { | ||||
|             ffl = (FeedForwardLayer) this.layer; | ||||
|         } | ||||
|         if(ffl.getNIn() != wRows){ | ||||
|             //Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config) | ||||
|             //We can reliably infer nIn from the shape of the weights array however | ||||
|             ffl.setNIn(wRows); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|  | ||||
| @ -22,11 +22,13 @@ import lombok.extern.slf4j.Slf4j; | ||||
| import org.deeplearning4j.nn.api.layers.LayerConstraint; | ||||
| import org.deeplearning4j.nn.conf.InputPreProcessor; | ||||
| import org.deeplearning4j.nn.conf.inputs.InputType; | ||||
| import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; | ||||
| import org.deeplearning4j.nn.conf.layers.InputTypeUtil; | ||||
| import org.deeplearning4j.nn.conf.layers.Layer; | ||||
| import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; | ||||
| import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; | ||||
| import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; | ||||
| import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; | ||||
| import org.deeplearning4j.nn.modelimport.keras.KerasLayer; | ||||
| import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; | ||||
| import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; | ||||
| @ -154,6 +156,9 @@ public class KerasSimpleRnn extends KerasLayer { | ||||
|                 .biasInit(0.0) | ||||
|                 .l1(this.weightL1Regularization) | ||||
|                 .l2(this.weightL2Regularization); | ||||
|         Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf); | ||||
|         if(nIn != null) | ||||
|             builder.setNIn(nIn); | ||||
|         if (biasConstraint != null) | ||||
|             builder.constrainBias(biasConstraint); | ||||
|         if (weightConstraint != null) | ||||
| @ -282,6 +287,19 @@ public class KerasSimpleRnn extends KerasLayer { | ||||
|             log.warn("Attemping to set weights for unknown parameters: " | ||||
|                     + unknownParamNames.substring(1, unknownParamNames.length() - 1)); | ||||
|         } | ||||
| 
 | ||||
|         FeedForwardLayer ffl; | ||||
|         if(this.layer instanceof BaseWrapperLayer){ | ||||
|             BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer; | ||||
|             ffl = (FeedForwardLayer)bwl.getUnderlying(); | ||||
|         } else { | ||||
|             ffl = (FeedForwardLayer) this.layer; | ||||
|         } | ||||
|         if(ffl.getNIn() != W.rows()){ | ||||
|             //Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config) | ||||
|             //We can reliably infer nIn from the shape of the weights array however | ||||
|             ffl.setNIn(W.rows()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| } | ||||
|  | ||||
| @ -229,8 +229,8 @@ public class KerasBidirectional extends KerasLayer { | ||||
|     @Override | ||||
|     public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException { | ||||
| 
 | ||||
|         Map<String, INDArray> forwardWeights = getUnderlyingWeights(weights, "forward"); | ||||
|         Map<String, INDArray> backwardWeights = getUnderlyingWeights(weights, "backward"); | ||||
|         Map<String, INDArray> forwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getFwd(), weights, "forward"); | ||||
|         Map<String, INDArray> backwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getBwd(), weights, "backward"); | ||||
| 
 | ||||
|         this.weights = new HashMap<>(); | ||||
| 
 | ||||
| @ -241,7 +241,7 @@ public class KerasBidirectional extends KerasLayer { | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     private Map<String, INDArray> getUnderlyingWeights(Map<String, INDArray> weights, String direction) | ||||
|     private Map<String, INDArray> getUnderlyingWeights(Layer l, Map<String, INDArray> weights, String direction) | ||||
|             throws InvalidKerasConfigurationException { | ||||
|         int keras1SubstringLength; | ||||
|         if (kerasRnnlayer instanceof KerasLSTM) | ||||
| @ -270,8 +270,12 @@ public class KerasBidirectional extends KerasLayer { | ||||
|             weights = newWeights; | ||||
|         } | ||||
| 
 | ||||
|         Layer layerBefore = kerasRnnlayer.getLayer(); | ||||
|         kerasRnnlayer.setLayer(l); | ||||
|         kerasRnnlayer.setWeights(weights); | ||||
|         return kerasRnnlayer.getWeights(); | ||||
|         Map<String,INDArray> ret = kerasRnnlayer.getWeights(); | ||||
|         kerasRnnlayer.setLayer(layerBefore); | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
|  | ||||
| @ -505,6 +505,17 @@ public class KerasLayerUtils { | ||||
|         return nOut; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Read the input dimension field from the inner configuration of a Keras layer, if available. | ||||
|      * | ||||
|      * @param layerConfig dictionary containing Keras layer configuration | ||||
|      * @param conf        Keras layer configuration, supplying the name of the input dimension field | ||||
|      * @return the field's value as an Integer, or null when the field is absent or is not a number | ||||
|      * @throws InvalidKerasConfigurationException propagated from the inner layer configuration lookup | ||||
|      */ | ||||
|     public static Integer getNInFromInputDim(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException { | ||||
|         Map<String, Object> inner = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); | ||||
|         if(!inner.containsKey(conf.getLAYER_FIELD_INPUT_DIM())){ | ||||
|             return null; | ||||
|         } | ||||
|         Object inputDim = inner.get(conf.getLAYER_FIELD_INPUT_DIM()); | ||||
|         if(!(inputDim instanceof Number)){ | ||||
|             return null; | ||||
|         } | ||||
|         return ((Number)inputDim).intValue(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get dropout from Keras layer configuration. | ||||
|      * | ||||
|  | ||||
| @ -24,6 +24,8 @@ import org.deeplearning4j.eval.ROCMultiClass; | ||||
| import org.deeplearning4j.gradientcheck.GradientCheckUtil; | ||||
| import org.deeplearning4j.nn.api.Layer; | ||||
| import org.deeplearning4j.nn.api.layers.IOutputLayer; | ||||
| import org.deeplearning4j.nn.conf.ConvolutionMode; | ||||
| import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; | ||||
| import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; | ||||
| import org.deeplearning4j.nn.conf.layers.LossLayer; | ||||
| import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; | ||||
| @ -47,6 +49,8 @@ import org.nd4j.linalg.activations.impl.*; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.function.BiFunction; | ||||
| import org.nd4j.linalg.function.Function; | ||||
| import org.nd4j.linalg.learning.config.NoOp; | ||||
| import org.nd4j.linalg.lossfunctions.LossFunctions; | ||||
| import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT; | ||||
| @ -58,10 +62,7 @@ import java.io.InputStream; | ||||
| import java.net.URL; | ||||
| import java.nio.file.Files; | ||||
| import java.nio.file.StandardCopyOption; | ||||
| import java.util.HashMap; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| import java.util.Random; | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.junit.Assert.*; | ||||
| 
 | ||||
| @ -86,7 +87,16 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|     @Rule | ||||
|     public final TemporaryFolder testDir = new TemporaryFolder(); | ||||
| 
 | ||||
|     @Test(expected = IllegalStateException.class) | ||||
|     /** | ||||
|      * Transformation for expected arrays: rank 3 arrays are permuted with (0, 2, 1), i.e. NWC to NCW; | ||||
|      * arrays of any other rank are returned unchanged. The String argument is not used. | ||||
|      */ | ||||
|     public static final BiFunction<String,INDArray,INDArray> nwc2ncwExpected = new BiFunction<String, INDArray, INDArray>() { | ||||
|         @Override | ||||
|         public INDArray apply(String name, INDArray expected) { | ||||
|             //NWC to NCW for rank 3 only | ||||
|             return expected.rank() == 3 ? expected.permute(0, 2, 1) : expected; | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
|         @Test(expected = IllegalStateException.class) | ||||
|     public void fileNotFoundEndToEnd() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/foo/bar.h5"; | ||||
|         importEndModelTest(modelPath, null, true, true, false, false); | ||||
| @ -154,28 +164,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|     public void importImdbLstmTfKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importImdbLstmThKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importImdbLstmTfKeras2() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importImdbLstmThKeras2() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, false, true, false, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -247,7 +257,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|         String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" + | ||||
|                 "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -598,6 +608,122 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|         model.summary(); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testCausalCon1D() throws Exception { | ||||
|         String[] names = new String[]{ | ||||
|                 "causal_conv1d_k2_s1_d1_cl_model.h5", | ||||
|                 "causal_conv1d_k2_s1_d2_cl_model.h5", | ||||
|                 "causal_conv1d_k2_s2_d1_cl_model.h5", | ||||
|                 "causal_conv1d_k2_s3_d1_cl_model.h5", | ||||
|                 "causal_conv1d_k3_s1_d1_cl_model.h5", | ||||
|                 "causal_conv1d_k3_s1_d2_cl_model.h5", | ||||
|                 "causal_conv1d_k3_s2_d1_cl_model.h5", | ||||
|                 "causal_conv1d_k3_s3_d1_cl_model.h5", | ||||
|                 "causal_conv1d_k4_s1_d1_cl_model.h5", | ||||
|                 "causal_conv1d_k4_s1_d2_cl_model.h5", | ||||
|                 "causal_conv1d_k4_s2_d1_cl_model.h5", | ||||
|                 "causal_conv1d_k4_s3_d1_cl_model.h5" | ||||
|         }; | ||||
| 
 | ||||
|         for(String name : names ){ | ||||
|             System.out.println("Starting test: " + name); | ||||
|             String modelPath = "modelimport/keras/examples/causal_conv1d/" + name; | ||||
|             String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5"); | ||||
|             Function<INDArray,INDArray> f = new Function<INDArray, INDArray>() { | ||||
|                 @Override | ||||
|                 public INDArray apply(INDArray i) { | ||||
|                     //NWC to NCW | ||||
|                     return i.permute(0, 2, 1); | ||||
|                 } | ||||
|             }; | ||||
| 
 | ||||
|             MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, | ||||
|                     true, true, false, f, nwc2ncwExpected); | ||||
|             Layer l = net.getLayer(0); | ||||
|             Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig(); | ||||
|             assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testCon1D() throws Exception { | ||||
|         //Keras Conv1D model fixtures; each one has a matching "..._inputs_and_outputs.h5" file in the same directory | ||||
|         final String[] modelFiles = { | ||||
|                 "conv1d_k2_s1_d1_cf_same_model.h5", | ||||
|                 "conv1d_k2_s1_d1_cf_valid_model.h5", | ||||
|                 "conv1d_k2_s1_d1_cl_same_model.h5", | ||||
|                 "conv1d_k2_s1_d1_cl_valid_model.h5", | ||||
|                 "conv1d_k2_s1_d2_cf_same_model.h5", | ||||
|                 "conv1d_k2_s1_d2_cf_valid_model.h5", | ||||
|                 "conv1d_k2_s1_d2_cl_same_model.h5", | ||||
|                 "conv1d_k2_s1_d2_cl_valid_model.h5", | ||||
|                 "conv1d_k2_s2_d1_cf_same_model.h5", | ||||
|                 "conv1d_k2_s2_d1_cf_valid_model.h5", | ||||
|                 "conv1d_k2_s2_d1_cl_same_model.h5", | ||||
|                 "conv1d_k2_s2_d1_cl_valid_model.h5", | ||||
|                 "conv1d_k2_s3_d1_cf_same_model.h5", | ||||
|                 "conv1d_k2_s3_d1_cf_valid_model.h5", | ||||
|                 "conv1d_k2_s3_d1_cl_same_model.h5", | ||||
|                 "conv1d_k2_s3_d1_cl_valid_model.h5", | ||||
|                 "conv1d_k3_s1_d1_cf_same_model.h5", | ||||
|                 "conv1d_k3_s1_d1_cf_valid_model.h5", | ||||
|                 "conv1d_k3_s1_d1_cl_same_model.h5", | ||||
|                 "conv1d_k3_s1_d1_cl_valid_model.h5", | ||||
|                 "conv1d_k3_s1_d2_cf_same_model.h5", | ||||
|                 "conv1d_k3_s1_d2_cf_valid_model.h5", | ||||
|                 "conv1d_k3_s1_d2_cl_same_model.h5", | ||||
|                 "conv1d_k3_s1_d2_cl_valid_model.h5", | ||||
|                 "conv1d_k3_s2_d1_cf_same_model.h5", | ||||
|                 "conv1d_k3_s2_d1_cf_valid_model.h5", | ||||
|                 "conv1d_k3_s2_d1_cl_same_model.h5", | ||||
|                 "conv1d_k3_s2_d1_cl_valid_model.h5", | ||||
|                 "conv1d_k3_s3_d1_cf_same_model.h5", | ||||
|                 "conv1d_k3_s3_d1_cf_valid_model.h5", | ||||
|                 "conv1d_k3_s3_d1_cl_same_model.h5", | ||||
|                 "conv1d_k3_s3_d1_cl_valid_model.h5", | ||||
|                 "conv1d_k4_s1_d1_cf_same_model.h5", | ||||
|                 "conv1d_k4_s1_d1_cf_valid_model.h5", | ||||
|                 "conv1d_k4_s1_d1_cl_same_model.h5", | ||||
|                 "conv1d_k4_s1_d1_cl_valid_model.h5", | ||||
|                 "conv1d_k4_s1_d2_cf_same_model.h5", | ||||
|                 "conv1d_k4_s1_d2_cf_valid_model.h5", | ||||
|                 "conv1d_k4_s1_d2_cl_same_model.h5", | ||||
|                 "conv1d_k4_s1_d2_cl_valid_model.h5", | ||||
|                 "conv1d_k4_s2_d1_cf_same_model.h5", | ||||
|                 "conv1d_k4_s2_d1_cf_valid_model.h5", | ||||
|                 "conv1d_k4_s2_d1_cl_same_model.h5", | ||||
|                 "conv1d_k4_s2_d1_cl_valid_model.h5", | ||||
|                 "conv1d_k4_s3_d1_cf_same_model.h5", | ||||
|                 "conv1d_k4_s3_d1_cf_valid_model.h5", | ||||
|                 "conv1d_k4_s3_d1_cl_same_model.h5", | ||||
|                 "conv1d_k4_s3_d1_cl_valid_model.h5", | ||||
|         }; | ||||
|         final String resourceDir = "modelimport/keras/examples/conv1d/"; | ||||
|         final String modelSuffix = "model.h5"; | ||||
|         //Stateless adapters, shared by every model whose name does not contain "_cf_": NWC to NCW | ||||
|         final Function<INDArray,INDArray> inputToNcw = new Function<INDArray, INDArray>() { | ||||
|             @Override | ||||
|             public INDArray apply(INDArray nwc) { | ||||
|                 return nwc.permute(0, 2, 1); | ||||
|             } | ||||
|         }; | ||||
|         final BiFunction<String,INDArray,INDArray> expectedToNcw = new BiFunction<String, INDArray, INDArray>() { | ||||
|             @Override | ||||
|             public INDArray apply(String layerName, INDArray nwc) { | ||||
|                 //Applied to every expected array, regardless of the layer name | ||||
|                 return nwc.permute(0, 2, 1); | ||||
|             } | ||||
|         }; | ||||
|         for (String modelFile : modelFiles) { | ||||
|             System.out.println("Starting test: " + modelFile); | ||||
|             //"_cf_" models are passed through unchanged (no input / expected-output preprocessing) | ||||
|             final boolean skipPermute = modelFile.contains("_cf_"); | ||||
|             final String stem = modelFile.substring(0, modelFile.length() - modelSuffix.length()); | ||||
|             final String modelPath = resourceDir + modelFile; | ||||
|             final String inputsOutputPath = resourceDir + stem + "inputs_and_outputs.h5"; | ||||
|             importEndModelTest(modelPath, inputsOutputPath, true, true, | ||||
|                     true, true, false, | ||||
|                     skipPermute ? null : inputToNcw, | ||||
|                     skipPermute ? null : expectedToNcw); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     //Convenience overload: forwards to the three-argument importFunctionalModelH5Test with null and false as the remaining arguments | ||||
|     private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception { | ||||
|         return importFunctionalModelH5Test(modelPath, null, false); | ||||
|     } | ||||
| @ -640,6 +766,12 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     /** | ||||
|      * Convenience overload of the full {@code importEndModelTest}: runs the import test with the AUC comparison | ||||
|      * enabled and without any input / expected-output preprocessing. | ||||
|      * | ||||
|      * @param modelPath             resource path of the Keras model file | ||||
|      * @param inputsOutputsPath     resource path of the file holding the reference inputs and outputs | ||||
|      * @param tfOrdering            forwarded unchanged; used when reading the reference arrays | ||||
|      * @param checkPredictions      whether activations and predictions are compared against the reference values | ||||
|      * @param checkGradients        whether gradient checks are run (unless they are globally skipped) | ||||
|      * @param enforceTrainingConfig forwarded unchanged as the training-config enforcement flag | ||||
|      * @return the imported network | ||||
|      */ | ||||
|     public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, | ||||
|                                     boolean checkGradients, boolean enforceTrainingConfig) throws Exception { | ||||
|         //Full overload order: (..., checkGradients, enforceTrainingConfig, checkAuc, inputPreProc, expectedPreProc). | ||||
|         //enforceTrainingConfig must go in its own slot; checkAuc = true keeps the AUC comparison on for existing callers. | ||||
|         return importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, checkGradients, | ||||
|                 enforceTrainingConfig, true, null, null); | ||||
|     } | ||||
| 
 | ||||
|     public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, | ||||
|                                                 boolean checkGradients, boolean enforceTrainingConfig, boolean checkAuc, Function<INDArray,INDArray> inputPreProc, | ||||
|                                                 BiFunction<String,INDArray,INDArray> expectedPreProc) throws Exception { | ||||
|         MultiLayerNetwork model; | ||||
|         try(InputStream is = Resources.asStream(modelPath)) { | ||||
|             File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION); | ||||
| @ -658,20 +790,25 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
| 
 | ||||
|             if (checkPredictions) { | ||||
|                 INDArray input = getInputs(outputsArchive, tfOrdering)[0]; | ||||
|                 if(inputPreProc != null) | ||||
|                     input = inputPreProc.apply(input); | ||||
| 
 | ||||
|                 Map<String, INDArray> activationsKeras = getActivations(outputsArchive, tfOrdering); | ||||
|                 for (int i = 0; i < model.getLayers().length; i++) { | ||||
|                     String layerName = model.getLayerNames().get(i); | ||||
|                     if (activationsKeras.containsKey(layerName)) { | ||||
|                         INDArray activationsDl4j = model.feedForwardToLayer(i, input, false).get(i + 1); | ||||
|                         if (activationsDl4j.shape().length == 3) | ||||
|                             activationsDl4j = activationsDl4j.permute(0, 2, 1); | ||||
|                         compareINDArrays(layerName, activationsKeras.get(layerName), activationsDl4j, EPS); | ||||
| 
 | ||||
|                         INDArray exp = activationsKeras.get(layerName); | ||||
|                         if(expectedPreProc != null) | ||||
|                             exp = expectedPreProc.apply(layerName, exp); | ||||
|                         compareINDArrays(layerName, exp, activationsDl4j, EPS); | ||||
|                     } | ||||
|                 } | ||||
| 
 | ||||
|                 INDArray predictionsKeras = getPredictions(outputsArchive, tfOrdering)[0]; | ||||
|                 INDArray predictionsDl4j = model.output(input, false); | ||||
|                 if(expectedPreProc != null) | ||||
|                     predictionsKeras = expectedPreProc.apply("output", predictionsKeras); | ||||
|                 compareINDArrays("predictions", predictionsKeras, predictionsDl4j, EPS); | ||||
|                 INDArray outputs = getOutputs(outputsArchive, true)[0]; | ||||
| 
 | ||||
| @ -680,7 +817,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|                 } | ||||
|                 val nOut = (int) outputs.size(-1); | ||||
| 
 | ||||
|                 compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS); | ||||
|                 if(checkAuc) | ||||
|                     compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS); | ||||
|             } | ||||
| 
 | ||||
|             if (checkGradients && ! SKIP_GRAD_CHECKS) { | ||||
| @ -760,20 +898,23 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|         return predictions; | ||||
|     } | ||||
| 
 | ||||
|     private static void compareINDArrays(String label, INDArray a, INDArray b, double eps) { | ||||
|         INDArray diff = a.sub(b.castTo(a.dataType())); | ||||
|     private static void compareINDArrays(String label, INDArray expected, INDArray actual, double eps) { | ||||
|         if(!expected.equalShapes(actual)){ | ||||
|             throw new IllegalStateException("Shapes do not match for \"" + label + "\": got " + Arrays.toString(expected.shape()) + " vs " + Arrays.toString(actual.shape())); | ||||
|         } | ||||
|         INDArray diff = expected.sub(actual.castTo(expected.dataType())); | ||||
|         double min = diff.minNumber().doubleValue(); | ||||
|         double max = diff.maxNumber().doubleValue(); | ||||
|         log.info(label + ": " + a.equalsWithEps(b, eps) + ", " + min + ", " + max); | ||||
|         log.info(label + ": " + expected.equalsWithEps(actual, eps) + ", " + min + ", " + max); | ||||
|         double threshold = 1e-7; | ||||
|         double aAbsMax = Math.max(Math.abs(a.minNumber().doubleValue()), Math.abs(a.maxNumber().doubleValue())); | ||||
|         double bAbsMax = Math.max(Math.abs(b.minNumber().doubleValue()), Math.abs(b.maxNumber().doubleValue())); | ||||
|         double aAbsMax = Math.max(Math.abs(expected.minNumber().doubleValue()), Math.abs(expected.maxNumber().doubleValue())); | ||||
|         double bAbsMax = Math.max(Math.abs(actual.minNumber().doubleValue()), Math.abs(actual.maxNumber().doubleValue())); | ||||
| 
 | ||||
|         // skip too small absolute inputs | ||||
|         if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) { | ||||
|             assertTrue(a.equalsWithEps(b.castTo(a.dataType()), eps)); | ||||
|             boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps); | ||||
|             assertTrue("Output differs: " + label, eq); | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     private static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses, | ||||
|  | ||||
| @ -69,6 +69,18 @@ package org.deeplearning4j.nn.conf; | ||||
|  * <br> | ||||
|  * <br> | ||||
|  * <br> | ||||
|  * <b>Causal</b>: Causal padding mode can only be used for 1D convolutional neural networks.<br> | ||||
|  * The motivation behind causal padding mode is that the output time steps depend only on current and past time steps.<br> | ||||
|  * That is, out[t] (for time t) depends only on values in[T] for T &lt;= t<br> | ||||
|  * The output size of 1D convolution/subsampling layers is the same as with SAME convolution mode - | ||||
|  * i.e., outSize = ceil( inputSize / stride )<br> | ||||
|  * Padding is also the same as SAME mode, but all padding is on the left (start of sequence) instead of being on both | ||||
|  * left and right of the input<br> | ||||
|  * For more details on causal convolutions, see <a href="https://arxiv.org/abs/1609.03499">WaveNet: A Generative Model For Audio</a>, | ||||
|  * section 2.1. | ||||
|  * <br> | ||||
|  * <br> | ||||
|  * <br> | ||||
|  * For further information on output sizes for convolutional neural networks, see the "Spatial arrangement" section at | ||||
|  * <a href="http://cs231n.github.io/convolutional-networks/">http://cs231n.github.io/convolutional-networks/</a> | ||||
|  * | ||||
| @ -76,6 +88,6 @@ package org.deeplearning4j.nn.conf; | ||||
|  */ | ||||
| public enum ConvolutionMode { | ||||
| 
 | ||||
|     Strict, Truncate, Same | ||||
|     Strict, Truncate, Same, Causal | ||||
| 
 | ||||
| } | ||||
|  | ||||
| @ -124,6 +124,11 @@ public class Convolution1DLayer extends ConvolutionLayer { | ||||
|             this.setKernelSize((int[]) null); | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         protected boolean allowCausal() { | ||||
|             //Causal convolution mode is accepted by this (1D) builder | ||||
|             return true; | ||||
|         } | ||||
| 
 | ||||
|         /** | ||||
|          * @param kernelSize Kernel size | ||||
|          * @param stride Stride | ||||
|  | ||||
| @ -163,6 +163,12 @@ public class Convolution3D extends ConvolutionLayer { | ||||
|             super(new int[] {2, 2, 2}, new int[] {1, 1, 1}, new int[] {0, 0, 0}, new int[] {1, 1, 1}, 3); | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         protected boolean allowCausal() { | ||||
|             //Causal convolution mode is not accepted by this builder - it is allowed for 1D layers only | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         public Builder(int[] kernelSize, int[] stride, int[] padding, int[] dilation) { | ||||
|             super(kernelSize, stride, padding, dilation, 3); | ||||
|         } | ||||
|  | ||||
| @ -30,6 +30,7 @@ import org.deeplearning4j.nn.params.ConvolutionParamInitializer; | ||||
| import org.deeplearning4j.optimize.api.TrainingListener; | ||||
| import org.deeplearning4j.util.ConvolutionUtils; | ||||
| import org.deeplearning4j.util.ValidationUtils; | ||||
| import org.nd4j.base.Preconditions; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| @ -283,6 +284,12 @@ public class ConvolutionLayer extends FeedForwardLayer { | ||||
|             super(); | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         protected boolean allowCausal() { | ||||
|             //Causal convolution mode is not accepted by this builder - it is allowed for 1D layers only | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         /** | ||||
|          * Size of the convolution rows/columns | ||||
|          * | ||||
| @ -456,6 +463,14 @@ public class ConvolutionLayer extends FeedForwardLayer { | ||||
| 
 | ||||
|         protected BaseConvBuilder() {} | ||||
| 
 | ||||
|         protected abstract boolean allowCausal(); | ||||
| 
 | ||||
|         protected void setConvolutionMode(ConvolutionMode convolutionMode){ | ||||
|             Preconditions.checkState(allowCausal() || convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" + | ||||
|                     " convolutional neural network layers"); | ||||
|             this.convolutionMode = convolutionMode; | ||||
|         } | ||||
| 
 | ||||
|         /** | ||||
|          * If true (default): include bias parameters in the model. False: no bias. | ||||
|          * | ||||
|  | ||||
| @ -133,6 +133,12 @@ public class Deconvolution2D extends ConvolutionLayer { | ||||
|             super(); | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         protected boolean allowCausal() { | ||||
|             //Causal convolution mode is not accepted by this builder - it is allowed for 1D layers only | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         /** | ||||
|          * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details | ||||
|          * | ||||
|  | ||||
| @ -133,6 +133,12 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { | ||||
|             super(); | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         protected boolean allowCausal() { | ||||
|             //Causal convolution mode is not accepted by this builder - it is allowed for 1D layers only | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         /** | ||||
|          * Set channels multiplier for depth-wise convolution | ||||
|          * | ||||
|  | ||||
| @ -184,6 +184,12 @@ public class SeparableConvolution2D extends ConvolutionLayer { | ||||
|             super(); | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         protected boolean allowCausal() { | ||||
|             //Causal convolution mode is not accepted by this builder - it is allowed for 1D layers only | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         /** | ||||
|          * Set channels multiplier of channels-wise step in separable convolution | ||||
|          * | ||||
|  | ||||
| @ -167,6 +167,11 @@ public class Subsampling1DLayer extends SubsamplingLayer { | ||||
|             this(poolingType, DEFAULT_KERNEL, DEFAULT_STRIDE, DEFAULT_PADDING); | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         protected boolean allowCausal() { | ||||
|             //Causal convolution mode is accepted by this (1D) builder | ||||
|             return true; | ||||
|         } | ||||
| 
 | ||||
|         public Builder() { | ||||
|             this(DEFAULT_POOLING, DEFAULT_KERNEL, DEFAULT_STRIDE, DEFAULT_PADDING); | ||||
|         } | ||||
|  | ||||
| @ -431,6 +431,12 @@ public class Subsampling3DLayer extends NoParamLayer { | ||||
|             this.setPoolingType(poolingType); | ||||
|         } | ||||
| 
 | ||||
|         protected void setConvolutionMode(ConvolutionMode convolutionMode){ | ||||
|             Preconditions.checkState(convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" + | ||||
|                     " convolutional neural network layers"); | ||||
|             this.convolutionMode = convolutionMode; | ||||
|         } | ||||
| 
 | ||||
|         /** | ||||
|          * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details | ||||
|          * | ||||
|  | ||||
| @ -28,6 +28,7 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer; | ||||
| import org.deeplearning4j.optimize.api.TrainingListener; | ||||
| import org.deeplearning4j.util.ConvolutionUtils; | ||||
| import org.deeplearning4j.util.ValidationUtils; | ||||
| import org.nd4j.base.Preconditions; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| @ -270,6 +271,12 @@ public class SubsamplingLayer extends NoParamLayer { | ||||
|             super(poolingType); | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         protected boolean allowCausal() { | ||||
|             //Causal mode is not accepted by this builder: only conv1d/subsampling1d can use it | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         /** | ||||
|          * Kernel size | ||||
|          * | ||||
| @ -449,6 +456,14 @@ public class SubsamplingLayer extends NoParamLayer { | ||||
|             this.eps = eps; | ||||
|         } | ||||
| 
 | ||||
|         protected abstract boolean allowCausal(); | ||||
| 
 | ||||
|         /** | ||||
|          * Stores the convolution mode on this builder. {@link ConvolutionMode#Causal} is rejected by the | ||||
|          * {@code Preconditions.checkState} call unless {@link #allowCausal()} returns true. | ||||
|          * | ||||
|          * @param convolutionMode convolution mode to set | ||||
|          */ | ||||
|         public void setConvolutionMode(ConvolutionMode convolutionMode){ | ||||
|             final boolean causalPermitted = allowCausal(); | ||||
|             final boolean causalRequested = (convolutionMode == ConvolutionMode.Causal); | ||||
|             Preconditions.checkState(causalPermitted || !causalRequested, | ||||
|                     "Causal convolution mode can only be used with 1D convolutional neural network layers"); | ||||
|             this.convolutionMode = convolutionMode; | ||||
|         } | ||||
| 
 | ||||
|         /** | ||||
|          * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details | ||||
|          * | ||||
|  | ||||
| @ -18,18 +18,30 @@ package org.deeplearning4j.nn.layers.convolution; | ||||
| 
 | ||||
| import org.deeplearning4j.exception.DL4JInvalidInputException; | ||||
| import org.deeplearning4j.nn.api.MaskState; | ||||
| import org.deeplearning4j.nn.conf.ConvolutionMode; | ||||
| import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | ||||
| import org.deeplearning4j.nn.gradient.DefaultGradient; | ||||
| import org.deeplearning4j.nn.gradient.Gradient; | ||||
| import org.deeplearning4j.nn.params.ConvolutionParamInitializer; | ||||
| import org.deeplearning4j.util.Convolution1DUtils; | ||||
| import org.deeplearning4j.util.ConvolutionUtils; | ||||
| import org.nd4j.base.Preconditions; | ||||
| import org.nd4j.linalg.activations.IActivation; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D; | ||||
| import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative; | ||||
| import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; | ||||
| import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode; | ||||
| import org.nd4j.linalg.api.shape.LongShapeDescriptor; | ||||
| import org.nd4j.linalg.factory.Broadcast; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.primitives.Pair; | ||||
| import org.deeplearning4j.nn.workspace.ArrayType; | ||||
| import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; | ||||
| 
 | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| /** | ||||
|  * 1D (temporal) convolutional layer. Currently, we just subclass off the | ||||
| @ -70,6 +82,52 @@ public class Convolution1DLayer extends ConvolutionLayer { | ||||
|             Broadcast.mul(epsilon, maskOut, epsilon, 0, 2); | ||||
|         } | ||||
| 
 | ||||
|         if(layerConf().getConvolutionMode() == ConvolutionMode.Causal){ | ||||
|             Pair<INDArray,INDArray> fwd = causalConv1dForward(); | ||||
|             IActivation afn = layerConf().getActivationFn(); | ||||
|             INDArray delta = afn.backprop(fwd.getFirst(), epsilon).getFirst(); //TODO handle activation function params | ||||
| 
 | ||||
|             //TODO eventually we'll use this for all convolution modes - but only after libnd4j has cuDNN support | ||||
|             org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf(); | ||||
|             Conv1DConfig conf = Conv1DConfig.builder() | ||||
|                     .k(c.getKernelSize()[0]) | ||||
|                     .s(c.getStride()[0]) | ||||
|                     .d(c.getDilation()[0]) | ||||
|                     .p(c.getPadding()[0]) | ||||
|                     .dataFormat(Conv1DConfig.NCW) | ||||
|                     .paddingMode(PaddingMode.CAUSAL) | ||||
|                     .build(); | ||||
| 
 | ||||
|             INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY); | ||||
|             w = w.reshape(w.ordering(), w.size(0), w.size(1), w.size(2)).permute(2, 1, 0);   //[oC, iC, k, 1] to [k, iC, oC] | ||||
| 
 | ||||
|             INDArray[] inputArrs; | ||||
|             INDArray[] outputArrs; | ||||
|             INDArray wg = gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY); | ||||
|             wg = wg.reshape(wg.ordering(), wg.size(0), wg.size(1), wg.size(2)).permute(2, 1, 0);    //[oC, iC, k, 1] -> [kW, iC, oC] | ||||
|             INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape()); | ||||
|             if(layerConf().hasBias()){ | ||||
|                 INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY); | ||||
|                 b = b.reshape(b.length()); | ||||
|                 inputArrs = new INDArray[]{input.castTo(w.dataType()), w, b, delta}; | ||||
|                 INDArray bg = gradientViews.get(ConvolutionParamInitializer.BIAS_KEY); | ||||
|                 bg = bg.reshape(bg.length()); | ||||
|                 outputArrs = new INDArray[]{epsOut, wg, bg}; | ||||
|             } else { | ||||
|                 inputArrs = new INDArray[]{input.castTo(w.dataType()), w, delta}; | ||||
|                 outputArrs = new INDArray[]{epsOut, wg}; | ||||
|             } | ||||
|             Conv1DDerivative op = new Conv1DDerivative(inputArrs, outputArrs, conf); | ||||
|             Nd4j.exec(op); | ||||
| 
 | ||||
|             Gradient retGradient = new DefaultGradient(); | ||||
|             if(layerConf().hasBias()){ | ||||
|                 retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, gradientViews.get(ConvolutionParamInitializer.BIAS_KEY)); | ||||
|             } | ||||
|             retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY), 'c'); | ||||
|             return new Pair<>(retGradient, epsOut); | ||||
|         } | ||||
| 
 | ||||
|         // add singleton fourth dimension to input and next layer's epsilon | ||||
|         epsilon = epsilon.reshape(epsilon.size(0), epsilon.size(1), epsilon.size(2), 1); | ||||
|         INDArray origInput = input; | ||||
| @ -98,6 +156,12 @@ public class Convolution1DLayer extends ConvolutionLayer { | ||||
|     @Override | ||||
|     protected Pair<INDArray,INDArray> preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) { | ||||
|         assertInputSet(false); | ||||
| 
 | ||||
|         if(layerConf().getConvolutionMode() == ConvolutionMode.Causal){ | ||||
|             return causalConv1dForward(); | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         INDArray origInput = input; | ||||
|         input = input.reshape(input.size(0), input.size(1), input.size(2), 1); | ||||
| 
 | ||||
| @ -113,6 +177,36 @@ public class Convolution1DLayer extends ConvolutionLayer { | ||||
|         return preOutput; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Forward pass used for {@link ConvolutionMode#Causal}: executes the Conv1D op (NCW data format, CAUSAL | ||||
|      * padding mode) on the current layer input, using the layer's weights and, if configured, its bias. | ||||
|      * | ||||
|      * @return pair whose first element is the output of the Conv1D op; the second element is always null | ||||
|      */ | ||||
|     protected Pair<INDArray,INDArray> causalConv1dForward(){ | ||||
|         //TODO eventually we'll use this for all convolution modes - but only after libnd4j has cuDNN support | ||||
|         org.deeplearning4j.nn.conf.layers.Convolution1DLayer layerConfig = | ||||
|                 (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf(); | ||||
|         Conv1DConfig opConfig = Conv1DConfig.builder() | ||||
|                 .k(layerConfig.getKernelSize()[0]) | ||||
|                 .s(layerConfig.getStride()[0]) | ||||
|                 .d(layerConfig.getDilation()[0]) | ||||
|                 .p(layerConfig.getPadding()[0]) | ||||
|                 .dataFormat(Conv1DConfig.NCW) | ||||
|                 .paddingMode(PaddingMode.CAUSAL) | ||||
|                 .build(); | ||||
|         //Weights: [oC, iC, k, 1] to [k, iC, oC] | ||||
|         INDArray kernel = getParam(ConvolutionParamInitializer.WEIGHT_KEY); | ||||
|         kernel = kernel.reshape(kernel.ordering(), kernel.size(0), kernel.size(1), kernel.size(2)).permute(2, 1, 0); | ||||
|         INDArray[] opInputs; | ||||
|         if (!layerConf().hasBias()) { | ||||
|             opInputs = new INDArray[]{input.castTo(kernel.dataType()), kernel}; | ||||
|         } else { | ||||
|             //Bias is passed to the op as a rank 1 array | ||||
|             INDArray bias = getParam(ConvolutionParamInitializer.BIAS_KEY); | ||||
|             INDArray flatBias = bias.reshape(bias.length()); | ||||
|             opInputs = new INDArray[]{input.castTo(kernel.dataType()), kernel, flatBias}; | ||||
|         } | ||||
|         //No output array is supplied up front: it is created from the shape the op itself reports | ||||
|         Conv1D conv = new Conv1D(opInputs, null, opConfig); | ||||
|         LongShapeDescriptor outDescriptor = conv.calculateOutputShape().get(0); | ||||
|         conv.setOutputArgument(0, Nd4j.create(outDescriptor, false)); | ||||
|         Nd4j.exec(conv); | ||||
|         return new Pair<>(conv.getOutputArgument(0), null); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){ | ||||
|         INDArray act4d = super.activate(training, workspaceMgr); | ||||
|  | ||||
| @ -66,7 +66,7 @@ public class Convolution1DUtils { | ||||
|     public static long getOutputSize(long inH, int kernel, int strides, int padding, | ||||
|                                     ConvolutionMode convolutionMode, int dilation) { | ||||
|         long eKernel = effectiveKernelSize(kernel, dilation); | ||||
|         if (convolutionMode == ConvolutionMode.Same) { | ||||
|         if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) { | ||||
|             return (int) Math.ceil(inH / ((double) strides)); | ||||
|         } | ||||
|         return (inH - eKernel + 2 * padding) / strides + 1; | ||||
| @ -92,7 +92,7 @@ public class Convolution1DUtils { | ||||
|         boolean atrous = (eKernel == kernel); | ||||
|         validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inH, atrous); | ||||
| 
 | ||||
|         if (convolutionMode == ConvolutionMode.Same) { | ||||
|         if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) { | ||||
|             int outH = (int) Math.ceil(inH / ((double) strides)); | ||||
|             return outH; | ||||
|         } | ||||
| @ -106,8 +106,9 @@ public class Convolution1DUtils { | ||||
|                                       boolean atrous) { | ||||
| 
 | ||||
|         int inH = inShape; | ||||
|         boolean t = convolutionMode == ConvolutionMode.Truncate; | ||||
| 
 | ||||
|         if (convolutionMode != ConvolutionMode.Same && (eKernel <= 0 || eKernel > inH + 2 * padding)) { | ||||
|         if (t && (eKernel <= 0 || eKernel > inH + 2 * padding)) { | ||||
|             StringBuilder sb = new StringBuilder(); | ||||
|             sb.append("Invalid input data or configuration: "); | ||||
|             if (atrous) sb.append("effective "); | ||||
|  | ||||
| @ -121,7 +121,7 @@ public class ConvolutionUtils { | ||||
|         int[] inShape = new int[]{inH, inW}; | ||||
|         validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inShape, atrous); | ||||
| 
 | ||||
|         if (convolutionMode == ConvolutionMode.Same) { | ||||
|         if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) { | ||||
| 
 | ||||
|             int outH = (int) Math.ceil(inH / ((double) strides[0])); | ||||
|             int outW = (int) Math.ceil(inW / ((double) strides[1])); | ||||
| @ -142,7 +142,9 @@ public class ConvolutionUtils { | ||||
|         int inH = inShape[0]; | ||||
|         int inW = inShape[1]; | ||||
| 
 | ||||
|         if (convolutionMode != ConvolutionMode.Same && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) { | ||||
|         boolean t = (convolutionMode == ConvolutionMode.Truncate); | ||||
| 
 | ||||
|         if (t && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) { | ||||
|             StringBuilder sb = new StringBuilder(); | ||||
|             sb.append("Invalid input data or configuration: "); | ||||
|             if (atrous) sb.append("effective "); | ||||
| @ -158,7 +160,7 @@ public class ConvolutionUtils { | ||||
|             throw new DL4JInvalidInputException(sb.toString()); | ||||
|         } | ||||
| 
 | ||||
|         if (convolutionMode != ConvolutionMode.Same && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) { | ||||
|         if (t && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) { | ||||
|             StringBuilder sb = new StringBuilder(); | ||||
|             sb.append("Invalid input data or configuration: "); | ||||
|             if (atrous) sb.append("effective "); | ||||
| @ -175,8 +177,7 @@ public class ConvolutionUtils { | ||||
|             throw new DL4JInvalidInputException(sb.toString()); | ||||
|         } | ||||
| 
 | ||||
|         if (eKernel.length == 3 && convolutionMode != ConvolutionMode.Same | ||||
|                 && (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) { | ||||
|         if (eKernel.length == 3 && t && (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) { | ||||
|             int inD = inShape[2]; | ||||
|             StringBuilder sb = new StringBuilder(); | ||||
|             sb.append("Invalid input data or configuration: "); | ||||
| @ -615,7 +616,7 @@ public class ConvolutionUtils { | ||||
|      */ | ||||
|     public static INDArray cnn1dMaskReduction(INDArray in, int kernel, int stride, int padding, int dilation, ConvolutionMode cm){ | ||||
|         Preconditions.checkState(in.rank()==2, "Rank must be 2 for cnn1d mask array - shape ", in.shape()); | ||||
|         if(cm == ConvolutionMode.Same && stride == 1 ){ | ||||
|         if((cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) && stride == 1 ){ | ||||
|             return in; | ||||
|         } | ||||
| 
 | ||||
| @ -630,7 +631,7 @@ public class ConvolutionUtils { | ||||
|         int[] k = new int[]{kernel,1}; | ||||
|         int[] s = new int[]{stride, 1}; | ||||
|         int[] d = new int[]{dilation, 1}; | ||||
|         if (cm == ConvolutionMode.Same) { | ||||
|         if (cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) { | ||||
|             outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, null, cm, d); //Also performs validation | ||||
|         } else { | ||||
|             pad = new int[]{padding, 0}; | ||||
| @ -645,7 +646,7 @@ public class ConvolutionUtils { | ||||
|                 .sH(s[0]).sW(s[1]) | ||||
|                 .pH(pad == null ? 0 : pad[0]).pW(pad == null ? 0 : pad[1]) | ||||
|                 .dH(d[0]).dW(d[1]) | ||||
|                 .isSameMode(cm== ConvolutionMode.Same) | ||||
|                 .isSameMode(cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) | ||||
|                 .isNHWC(false) | ||||
|                 .build()); | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user