diff --git a/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp b/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp index 2961e3bcf..23b2ec172 100644 --- a/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp +++ b/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp @@ -34,7 +34,7 @@ CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) { auto xt = INPUT_VARIABLE(0); // input [bS x inSize], bS - batch size, inSize - number of features auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that is at previous time step t-1 auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize] - auto b = INPUT_VARIABLE(3); // biases [1 x 2*inSize] + auto b = INPUT_VARIABLE(3); // biases [2*inSize] auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x inSize], that is at current time step t auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x inSize], that is at current time step t diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 0b5a4c03f..1821a30a0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -6511,4 +6511,22 @@ public class SameDiff extends SDBaseOps { public String generateNewVarName(String base, int argIndex) { return generateNewVarName(base, argIndex, true); } + + /** + * Returns an unused variable name of the format <base>_#. + * + * Intended to be used for custom variables (like weights), arguments and op outputs should use {@link #generateNewVarName(String, int)}. 
+ */ + public String generateDistinctCustomVariableName(String base){ + if(!variables.containsKey(base)) + return base; + + int inc = 1; + + while(variables.containsKey(base + "_" + inc)){ + inc++; + } + + return base + "_" + inc; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java index f47f32b87..de0114b92 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java @@ -16,6 +16,7 @@ package org.nd4j.autodiff.samediff.ops; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ops.impl.layers.recurrent.*; @@ -23,6 +24,15 @@ import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*; import java.util.Arrays; import java.util.List; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRUCellOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRULayerOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; +import org.nd4j.linalg.primitives.Pair; /** * SameDiff Recurrent Neural Network operations
@@ -39,90 +49,163 @@ public class SDRNN extends SDOps { /** - * The gru cell - * - * @param configuration the configuration to use - * @return + * See {@link #gru(String, SDVariable, SDVariable, GRUWeights)}. */ - public List gru(GRUCellConfiguration configuration) { - GRUCell c = new GRUCell(sd, configuration); - return Arrays.asList(c.outputVariables()); + public GRUCellOutputs gru(@NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) { + GRUCell c = new GRUCell(sd, x, hLast, weights); + return new GRUCellOutputs(c.outputVariables()); } /** - * The gru cell + * The GRU cell. Does a single time step operation. * - * @param baseName the base name for the gru cell - * @param configuration the configuration to use - * @return + * @param baseName The base name for the gru cell + * @param x Input, with shape [batchSize, inSize] + * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] + * @param weights The cell's weights. + * @return The cell's outputs. 
*/ - public List gru(String baseName, GRUCellConfiguration configuration) { - GRUCell c = new GRUCell(sd, configuration); - return Arrays.asList(c.outputVariables(baseName)); - } - - - /** - * LSTM unit - * - * @param baseName the base name for outputs - * @param configuration the configuration to use - * @return - */ - public SDVariable lstmCell(String baseName, LSTMCellConfiguration configuration) { - return new LSTMCell(sd, configuration).outputVariables(baseName)[0]; - } - - public List lstmBlockCell(String name, LSTMBlockCellConfiguration configuration){ - SDVariable[] v = new LSTMBlockCell(sd, configuration).outputVariables(name); - return Arrays.asList(v); - } - - public List lstmLayer(String name, LSTMConfiguration configuration){ - SDVariable[] v = new LSTMLayer(sd, configuration).outputVariables(name); - return Arrays.asList(v); + public GRUCellOutputs gru(String baseName, @NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) { + GRUCell c = new GRUCell(sd, x, hLast, weights); + return new GRUCellOutputs(c.outputVariables(baseName)); } /** - * Simple recurrent unit - * - * @param configuration the configuration for the sru - * @return + * See {@link #lstmCell(String, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}. */ - public SDVariable sru(SRUConfiguration configuration) { - return new SRU(sd, configuration).outputVariables()[0]; + public LSTMCellOutputs lstmCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, + LSTMWeights weights, LSTMConfiguration config){ + LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config); + return new LSTMCellOutputs(c.outputVariables()); } /** - * Simiple recurrent unit + * The LSTM cell. Does a single time step operation. 
* - * @param baseName the base name to use for output variables - * @param configuration the configuration for the sru - * @return + * @param baseName The base name for the lstm cell + * @param x Input, with shape [batchSize, inSize] + * @param cLast Previous cell state, with shape [batchSize, numUnits] + * @param yLast Previous cell output, with shape [batchSize, numUnits] + * @param weights The cell's weights. + * @param config The cell's config. + * @return The cell's outputs. */ - public SDVariable sru(String baseName, SRUConfiguration configuration) { - return new SRU(sd, configuration).outputVariables(baseName)[0]; + public LSTMCellOutputs lstmCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, + @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ + LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config); + return new LSTMCellOutputs(c.outputVariables(baseName)); } /** - * An sru cell - * - * @param configuration the configuration for the sru cell - * @return + * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)} */ - public SDVariable sruCell(SRUCellConfiguration configuration) { - return new SRUCell(sd, configuration).outputVariables()[0]; + public LSTMLayerOutputs lstmLayer(@NonNull SDVariable maxTSLength, + @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, + @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ + LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config); + return new LSTMLayerOutputs(c.outputVariables(), config.getDataFormat()); } /** - * An sru cell - * - * @param baseName the base name to use for the output variables - * @param configuration the configuration for the sru cell - * @return + * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)} */ - public SDVariable sruCell(String 
baseName, SRUCellConfiguration configuration) { - return new SRUCell(sd, configuration).outputVariables(baseName)[0]; + public LSTMLayerOutputs lstmLayer(int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, + @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ + return lstmLayer( + sd.scalar("lstm_max_ts_length", maxTSLength), + x, cLast, yLast, weights, config); + } + + /** + * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)} + */ + public LSTMLayerOutputs lstmLayer(String baseName, int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, + @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ + if(baseName != null) { + return lstmLayer(baseName, + sd.scalar(sd.generateDistinctCustomVariableName(baseName + "_max_ts_length"), maxTSLength), + x, cLast, yLast, weights, config); + } else { + return lstmLayer(maxTSLength, x, cLast, yLast, weights, config); + } + } + + /** + * The LSTM layer. Does multiple time steps. + * + * Input shape depends on data format (in config):
+ * TNS -> [timeSteps, batchSize, inSize]
+ * NST -> [batchSize, inSize, timeSteps]
+ * NTS -> [batchSize, timeSteps, inSize]
+ * + * @param baseName The base name for the lstm layer + * @param x Input, with shape dependent on the data format (in config). + * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] + * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] + * @param weights The layer's weights. + * @param config The layer's config. + * @return The layer's outputs. + */ + public LSTMLayerOutputs lstmLayer(String baseName, @NonNull SDVariable maxTSLength, + @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, + @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ + LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config); + return new LSTMLayerOutputs(c.outputVariables(baseName), config.getDataFormat()); + } + + /** + * See {@link #sruCell(String, SDVariable, SDVariable, SRUWeights)}. + */ + public SRUCellOutputs sruCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) { + return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables()); + } + + /** + * The SRU cell. Does a single time step operation. + * + * @param baseName The base name for the sru cell + * @param x Input, with shape [batchSize, inSize] + * @param cLast Previous cell state, with shape [batchSize, inSize] + * @param weights The cell's weights. + * @return The cell's outputs. 
+ */ + public SRUCellOutputs sruCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) { + return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables(baseName)); + } + + /** + * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)} + */ + public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) { + return sru(x, initialC, null, weights); + } + + /** + * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)} + */ + public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) { + return sru(baseName, x, initialC, null, weights); + } + + /** + * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)} + */ + public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) { + return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables()); + } + + /** + * The SRU layer. Does multiple time steps. + * + * @param baseName The base name for the sru layer + * @param x Input, with shape [batchSize, inSize, timeSeriesLength] + * @param initialC Initial cell state, with shape [batchSize, inSize] + * @param mask An optional dropout mask, with shape [batchSize, inSize] + * @param weights The layer's weights. + * @return The layer's outputs. 
+ */ + public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) { + return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables(baseName)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java index 6c7daca69..2fa99ace5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; +import lombok.Getter; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -39,14 +41,15 @@ import java.util.Map; */ public class GRUCell extends DynamicCustomOp { - private GRUCellConfiguration configuration; + @Getter + private GRUWeights weights; public GRUCell() { } - public GRUCell(SameDiff sameDiff, GRUCellConfiguration configuration) { - super(null, sameDiff, configuration.args()); - this.configuration = configuration; + public GRUCell(SameDiff sameDiff, SDVariable x, SDVariable hLast, GRUWeights weights) { + super(null, sameDiff, weights.argsWithInputs(x, hLast)); + this.weights = weights; } @Override diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java index 36512f610..f88625987 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java @@ -16,12 +16,15 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; +import lombok.Getter; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; +import org.nd4j.linalg.primitives.Pair; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -49,10 +52,12 @@ import java.util.Map; * 6: weights - cell peephole (t) connections to output gate, [numUnits]
* 7: biases, shape [4*numUnits]
*
- * Input integer arguments: set via {@link LSTMBlockCellConfiguration}
+ * Weights are set via {@link LSTMWeights}.
+ *
+ * Input integer arguments: set via {@link LSTMConfiguration}
* 0: if not zero, provide peephole connections
*
- * Input float arguments: set via {@link LSTMBlockCellConfiguration}
+ * Input float arguments: set via {@link LSTMConfiguration}
* 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training
* 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped
*
@@ -69,15 +74,19 @@ import java.util.Map; */ public class LSTMBlockCell extends DynamicCustomOp { - private LSTMBlockCellConfiguration configuration; + private LSTMConfiguration configuration; + + @Getter + private LSTMWeights weights; public LSTMBlockCell() { } - public LSTMBlockCell(SameDiff sameDiff, LSTMBlockCellConfiguration configuration) { - super(null, sameDiff, configuration.args()); + public LSTMBlockCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) { + super(null, sameDiff, weights.argsWithInputs(x, cLast, yLast)); this.configuration = configuration; - addIArgument(configuration.iArgs()); + this.weights = weights; + addIArgument(configuration.iArgs(false)); addTArgument(configuration.tArgs()); } @@ -97,12 +106,12 @@ public class LSTMBlockCell extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - configuration = LSTMBlockCellConfiguration.builder() + configuration = LSTMConfiguration.builder() .forgetBias(attributesForNode.get("forget_bias").getF()) .clippingCellValue(attributesForNode.get("cell_clip").getF()) .peepHole(attributesForNode.get("use_peephole").getB()) .build(); - addIArgument(configuration.iArgs()); + addIArgument(configuration.iArgs(false)); addTArgument(configuration.tArgs()); } @@ -113,7 +122,7 @@ public class LSTMBlockCell extends DynamicCustomOp { @Override public Map propertiesForFunction() { - return configuration.toProperties(); + return configuration.toProperties(false); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java index 527c0c3ca..1e1ae3c47 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; +import lombok.Getter; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -24,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -75,13 +77,17 @@ public class LSTMLayer extends DynamicCustomOp { private LSTMConfiguration configuration; + @Getter + private LSTMWeights weights; + public LSTMLayer() { } - public LSTMLayer(@NonNull SameDiff sameDiff, @NonNull LSTMConfiguration configuration) { - super(null, sameDiff, configuration.args()); + public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) { + super(null, sameDiff, weights.argsWithInputs(maxTSLength, x, cLast, yLast)); this.configuration = configuration; - addIArgument(configuration.iArgs()); + this.weights = weights; + addIArgument(configuration.iArgs(true)); addTArgument(configuration.tArgs()); } @@ -107,7 +113,7 @@ public class LSTMLayer extends DynamicCustomOp { .peepHole(attributesForNode.get("use_peephole").getB()) .dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM .build(); - addIArgument(configuration.iArgs()); + addIArgument(configuration.iArgs(true)); addTArgument(configuration.tArgs()); } @@ 
-118,7 +124,7 @@ public class LSTMLayer extends DynamicCustomOp { @Override public Map propertiesForFunction() { - return configuration.toProperties(); + return configuration.toProperties(true); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java index b916d4961..a2de2beb8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java @@ -16,11 +16,16 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; +import java.util.Arrays; +import java.util.List; +import lombok.Getter; +import lombok.NonNull; import onnx.Onnx; +import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -34,13 +39,18 @@ import java.util.Map; */ public class SRU extends DynamicCustomOp { - private SRUConfiguration configuration; + @Getter + private SRUWeights weights; + + @Getter + private SDVariable mask; public SRU() { } - public SRU(SameDiff sameDiff, SRUConfiguration configuration) { - super(null, sameDiff, configuration.args()); - this.configuration = configuration; + public SRU(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) { + super(null, sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getBias(), initialC, mask)); + 
this.mask = mask; + this.weights = weights; } @Override @@ -68,6 +78,4 @@ public class SRU extends DynamicCustomOp { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java index 4880b90fe..ac3f6c07f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java @@ -16,17 +16,18 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; +import java.util.Map; +import lombok.Getter; import onnx.Onnx; +import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUCellConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.Map; - /** * A simple recurrent unit cell. 
* @@ -34,14 +35,15 @@ import java.util.Map; */ public class SRUCell extends DynamicCustomOp { - private SRUCellConfiguration configuration; + @Getter + private SRUWeights weights; public SRUCell() { } - public SRUCell(SameDiff sameDiff, SRUCellConfiguration configuration) { - super(null, sameDiff, configuration.args()); - this.configuration = configuration; + public SRUCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SRUWeights weights) { + super(null, sameDiff, weights.argsWithInputs(x, cLast)); + this.weights = weights; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMBlockCellConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMBlockCellConfiguration.java deleted file mode 100644 index 3e2591062..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMBlockCellConfiguration.java +++ /dev/null @@ -1,57 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; - -import lombok.Builder; -import lombok.Data; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.linalg.util.ArrayUtil; - -import java.util.LinkedHashMap; -import java.util.Map; - -@Builder -@Data -public class LSTMBlockCellConfiguration { - - private boolean peepHole; //IArg(0) - private double forgetBias; //TArg(0) - private double clippingCellValue; //TArg(1) - - private SDVariable xt, cLast, yLast, W, Wci, Wcf, Wco, b; - - public Map toProperties() { - Map ret = new LinkedHashMap<>(); - ret.put("peepHole",peepHole); - ret.put("clippingCellValue",clippingCellValue); - ret.put("forgetBias",forgetBias); - return ret; - } - - public SDVariable[] args() { - return new SDVariable[] {xt,cLast, yLast, W, Wci, Wcf, Wco, b}; - } - - - public int[] iArgs() { - return new int[] {ArrayUtil.fromBoolean(peepHole)}; - } - - public double[] tArgs() { - return new double[] {forgetBias,clippingCellValue}; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMConfiguration.java index 98dc58876..4cf807765 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMConfiguration.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMConfiguration.java @@ -19,13 +19,15 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; import lombok.Builder; import lombok.Data; import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; +import 
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; import org.nd4j.linalg.util.ArrayUtil; import java.util.LinkedHashMap; import java.util.Map; /** - * LSTM Configuration - for {@link org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer} + * LSTM Configuration - for {@link LSTMLayer} and {@link LSTMBlockCell} * * @author Alex Black */ @@ -33,29 +35,41 @@ import java.util.Map; @Data public class LSTMConfiguration { + /** + * Whether to provide peephole connections. + */ private boolean peepHole; //IArg(0) - @Builder.Default private RnnDataFormat dataFormat = RnnDataFormat.TNS; //IArg(1) + + /** + * The data format of the input. Only used in {@link LSTMLayer}, ignored in {@link LSTMBlockCell}. + */ + @Builder.Default private RnnDataFormat dataFormat = RnnDataFormat.TNS; //IArg(1) (only for lstmBlock, not lstmBlockCell) + + /** + * The bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training. + */ private double forgetBias; //TArg(0) + + /** + * Clipping value for cell state, if it is not equal to zero, then cell state is clipped. 
+ */ private double clippingCellValue; //TArg(1) - private SDVariable xt, cLast, yLast, W, Wci, Wcf, Wco, b; - - public Map toProperties() { + public Map toProperties(boolean includeDataFormat) { Map ret = new LinkedHashMap<>(); ret.put("peepHole",peepHole); ret.put("clippingCellValue",clippingCellValue); ret.put("forgetBias",forgetBias); - ret.put("dataFormat", dataFormat); + if(includeDataFormat) + ret.put("dataFormat", dataFormat); return ret; } - public SDVariable[] args() { - return new SDVariable[] {xt,cLast, yLast, W, Wci, Wcf, Wco, b}; - } - - public int[] iArgs() { - return new int[] {ArrayUtil.fromBoolean(peepHole), dataFormat.ordinal()}; + public int[] iArgs(boolean includeDataFormat) { + if(includeDataFormat) { + return new int[]{ArrayUtil.fromBoolean(peepHole), dataFormat.ordinal()}; + } else return new int[]{ArrayUtil.fromBoolean(peepHole)}; } public double[] tArgs() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUCellConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUCellConfiguration.java deleted file mode 100644 index 4b0a39a80..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUCellConfiguration.java +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; - -import lombok.Builder; -import lombok.Data; -import org.nd4j.autodiff.samediff.SDVariable; - -@Data -@Builder -public class SRUCellConfiguration { - /** - * - NDArray* xt = INPUT_VARIABLE(0); // input [batchSize x inSize], batchSize - batch size, inSize - number of features - NDArray* ct_1 = INPUT_VARIABLE(1); // previous cell state ct [batchSize x inSize], that is at previous time step t-1 - NDArray* w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize] - NDArray* b = INPUT_VARIABLE(3); // biases [1 x 2*inSize] - - NDArray* ht = OUTPUT_VARIABLE(0); // current cell output [batchSize x inSize], that is at current time step t - NDArray* ct = OUTPUT_VARIABLE(1); // current cell state [batchSize x inSize], that is at current time step t - - */ - private SDVariable xt,ct_1,w,b,h1,ct; - - - public SDVariable[] args() { - return new SDVariable[] {xt,ct_1,w,b,h1,ct}; - } - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUConfiguration.java deleted file mode 100644 index 8bfa90330..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUConfiguration.java +++ /dev/null @@ -1,38 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; - -import lombok.Builder; -import lombok.Data; -import org.nd4j.autodiff.samediff.SDVariable; - -@Data -@Builder -public class SRUConfiguration { - /** - * NDArray* input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features - NDArray* weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K] - NDArray* bias = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K] - NDArray* init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0 - - */ - private SDVariable inputs,weights,bias,init; - - public SDVariable[] args() { - return new SDVariable[] {inputs,weights,bias,init}; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java new file mode 100644 index 000000000..a39a5bcc7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java @@ -0,0 +1,62 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs; + +import 
java.util.Arrays; +import java.util.List; +import lombok.Getter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell; + +/** + * The outputs of a GRU cell ({@link GRUCell}. + */ +@Getter +public class GRUCellOutputs { + + /** + * Reset gate output [batchSize, numUnits]. + */ + private SDVariable r; + + /** + * Update gate output [batchSize, numUnits]. + */ + private SDVariable u; + + /** + * Cell gate output [batchSize, numUnits]. + */ + private SDVariable c; + + /** + * Current cell output [batchSize, numUnits]. + */ + private SDVariable h; + + public GRUCellOutputs(SDVariable[] outputs){ + Preconditions.checkArgument(outputs.length == 4, + "Must have 4 GRU cell outputs, got %s", outputs.length); + + r = outputs[0]; + u = outputs[1]; + c = outputs[2]; + h = outputs[3]; + } + + /** + * Get all outputs returned by the cell. + */ + public List getAllOutputs(){ + return Arrays.asList(r, u, c, h); + } + + /** + * Get h, the output of the cell. + * + * Has shape [batchSize, numUnits]. 
+ */ + public SDVariable getOutput(){ + return h; + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java new file mode 100644 index 000000000..4fec87e8b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java @@ -0,0 +1,88 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs; + +import java.util.Arrays; +import java.util.List; +import lombok.Getter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; + +/** + * The outputs of a LSTM cell ({@link LSTMBlockCell}. + */ +@Getter +public class LSTMCellOutputs { + + /** + * Output - input modulation gate activations [batchSize, numUnits]. + */ + private SDVariable i; + + /** + * Activations, cell state (pre tanh) [batchSize, numUnits]. + */ + private SDVariable c; + + /** + * Output - forget gate activations [batchSize, numUnits]. + */ + private SDVariable f; + + /** + * Output - output gate activations [batchSize, numUnits]. + */ + private SDVariable o; + + /** + * Output - input gate activations [batchSize, numUnits]. + */ + private SDVariable z; + + /** + * Cell state, post tanh [batchSize, numUnits]. + */ + private SDVariable h; + + /** + * Current cell output [batchSize, numUnits]. + */ + private SDVariable y; + + public LSTMCellOutputs(SDVariable[] outputs){ + Preconditions.checkArgument(outputs.length == 7, + "Must have 7 LSTM cell outputs, got %s", outputs.length); + + i = outputs[0]; + c = outputs[1]; + f = outputs[2]; + o = outputs[3]; + z = outputs[4]; + h = outputs[5]; + y = outputs[6]; + } + + /** + * Get all outputs returned by the cell. 
+ */ + public List getAllOutputs(){ + return Arrays.asList(i, c, f, o, z, h, y); + } + + /** + * Get y, the output of the cell. + * + * Has shape [batchSize, numUnits]. + */ + public SDVariable getOutput(){ + return y; + } + + /** + * Get c, the cell's state. + * + * Has shape [batchSize, numUnits]. + */ + public SDVariable getState(){ + return c; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java new file mode 100644 index 000000000..a01be219f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java @@ -0,0 +1,180 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs; + +import java.util.Arrays; +import java.util.List; +import lombok.AccessLevel; +import lombok.Getter; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; + +/** + * The outputs of a LSTM layer ({@link LSTMLayer}. + */ +@Getter +public class LSTMLayerOutputs { + + private RnnDataFormat dataFormat; + + /** + * Output - input modulation gate activations. + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable i; + + /** + * Activations, cell state (pre tanh). + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable c; + + /** + * Output - forget gate activations. + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable f; + + /** + * Output - output gate activations. + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable o; + + /** + * Output - input gate activations. + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable z; + + /** + * Cell state, post tanh. + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable h; + + /** + * Current cell output. + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable y; + + public LSTMLayerOutputs(SDVariable[] outputs, RnnDataFormat dataFormat){ + Preconditions.checkArgument(outputs.length == 7, + "Must have 7 LSTM layer outputs, got %s", outputs.length); + + i = outputs[0]; + c = outputs[1]; + f = outputs[2]; + o = outputs[3]; + z = outputs[4]; + h = outputs[5]; + y = outputs[6]; + this.dataFormat = dataFormat; + } + + /** + * Get all outputs returned by the cell. + */ + public List getAllOutputs(){ + return Arrays.asList(i, c, f, o, z, h, y); + } + + /** + * Get y, the output of the cell for all time steps. + * + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + public SDVariable getOutput(){ + return y; + } + + /** + * Get c, the cell's state for all time steps. + * + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + public SDVariable getState(){ + return c; + } + + private SDVariable lastOutput = null; + + /** + * Get y, the output of the cell, for the last time step. + * + * Has shape [batchSize, numUnits]. + */ + public SDVariable getLastOutput(){ + if(lastOutput != null) + return lastOutput; + + switch (dataFormat){ + case TNS: + lastOutput = getOutput().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all()); + break; + case NST: + lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1)); + break; + case NTS: + lastOutput = getOutput().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all()); + break; + } + return lastOutput; + } + + private SDVariable lastState = null; + + /** + * Get c, the state of the cell, for the last time step. + * + * Has shape [batchSize, numUnits]. + */ + public SDVariable getLastState(){ + if(lastState != null) + return lastState; + + switch (dataFormat){ + case TNS: + lastState = getState().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all()); + break; + case NST: + lastState = getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1)); + break; + case NTS: + lastState = getState().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all()); + break; + } + return lastState; + } + + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java new file mode 100644 index 000000000..d82ad63b1 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java @@ -0,0 +1,60 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs; + +import java.util.Arrays; +import java.util.List; +import lombok.Getter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import 
org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell; + +/** + * The outputs of an SRU cell ({@link SRUCell}). + */ +@Getter +public class SRUCellOutputs { + + + /** + * Current cell output [batchSize, numUnits]. + */ + private SDVariable h; + + /** + * Current cell state [batchSize, numUnits]. + */ + private SDVariable c; + + public SRUCellOutputs(SDVariable[] outputs){ + Preconditions.checkArgument(outputs.length == 2, + "Must have 2 SRU cell outputs, got %s", outputs.length); + + h = outputs[0]; + c = outputs[1]; + } + + /** + * Get all outputs returned by the cell. + */ + public List getAllOutputs(){ + return Arrays.asList(h, c); + } + + /** + * Get h, the output of the cell. + * + * Has shape [batchSize, inSize]. + */ + public SDVariable getOutput(){ + return h; + } + + /** + * Get c, the state of the cell. + * + * Has shape [batchSize, inSize]. + */ + public SDVariable getState(){ + return c; + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java new file mode 100644 index 000000000..281d2cc10 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java @@ -0,0 +1,92 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs; + +import java.util.Arrays; +import java.util.List; +import lombok.AccessLevel; +import lombok.Getter; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU; + +/** + * The outputs of an SRU layer ({@link SRU}). + */ +@Getter +public class SRULayerOutputs { + + + /** + * Current cell output [batchSize, inSize, timeSeriesLength]. 
+ */ + private SDVariable h; + + /** + * Current cell state [batchSize, inSize, timeSeriesLength]. + */ + private SDVariable c; + + public SRULayerOutputs(SDVariable[] outputs){ + Preconditions.checkArgument(outputs.length == 2, + "Must have 2 SRU cell outputs, got %s", outputs.length); + + h = outputs[0]; + c = outputs[1]; + } + + /** + * Get all outputs returned by the cell. + */ + public List getAllOutputs(){ + return Arrays.asList(h, c); + } + + /** + * Get h, the output of the cell. + * + * Has shape [batchSize, inSize, timeSeriesLength]. + */ + public SDVariable getOutput(){ + return h; + } + + /** + * Get c, the state of the cell. + * + * Has shape [batchSize, inSize, timeSeriesLength]. + */ + public SDVariable getState(){ + return c; + } + + private SDVariable lastOutput = null; + + /** + * Get y, the output of the cell, for the last time step. + * + * Has shape [batchSize, inSize]. + */ + public SDVariable getLastOutput(){ + if(lastOutput != null) + return lastOutput; + + lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1)); + return lastOutput; + } + + private SDVariable lastState = null; + + /** + * Get c, the state of the cell, for the last time step. + * + * Has shape [batchSize, inSize]. 
 */ + public SDVariable getLastState(){ + if(lastState != null) + return lastState; + + lastState = getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1)); + return lastState; + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java new file mode 100644 index 000000000..f95438ae3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java @@ -0,0 +1,51 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell; + +/** + * The weight configuration of a GRU cell. For {@link GRUCell}. + * + */ +@EqualsAndHashCode(callSuper = true) +@Data +@Builder +public class GRUWeights extends RNNWeights { + + /** + * Reset and Update gate weights, with a shape of [inSize + numUnits, 2*numUnits]. + * + * The reset weights are the [:, 0:numUnits] subset and the update weights are the [:, numUnits:2*numUnits] subset. + */ + @NonNull + private SDVariable ruWeight; + + /** + * Cell gate weights, with a shape of [inSize + numUnits, numUnits] + */ + @NonNull + private SDVariable cWeight; + + /** + * Reset and Update gate bias, with a shape of [2*numUnits]. May be null. + * + * The reset bias is the [0:numUnits] subset and the update bias is the [numUnits:2*numUnits] subset. + */ + @NonNull + private SDVariable ruBias; + + /** + * Cell gate bias, with a shape of [numUnits]. May be null. 
+ */ + @NonNull + private SDVariable cBias; + + @Override + public SDVariable[] args() { + return filterNonNull(ruWeight, cWeight, ruBias, cBias); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java new file mode 100644 index 000000000..bf401d66c --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java @@ -0,0 +1,57 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; + +/** + * The weight configuration of a LSTM layer. For {@link LSTMLayer} and {@link LSTMBlockCell}. + * + */ +@EqualsAndHashCode(callSuper = true) +@Data +@Builder +public class LSTMWeights extends RNNWeights { + + /** + * Input to hidden weights and hidden to hidden weights, with a shape of [inSize + numUnits, 4*numUnits]. + * + * Input to hidden and hidden to hidden are concatenated in dimension 0, + * so the input to hidden weights are [:inSize, :] and the hidden to hidden weights are [inSize:, :]. + */ + @NonNull + private SDVariable weights; + + /** + * Cell peephole (t-1) connections to input modulation gate, with a shape of [numUnits]. + */ + @NonNull + private SDVariable inputPeepholeWeights; + + /** + * Cell peephole (t-1) connections to forget gate, with a shape of [numUnits]. + */ + @NonNull + private SDVariable forgetPeepholeWeights; + + /** + * Cell peephole (t) connections to output gate, with a shape of [numUnits]. 
+ */ + @NonNull + private SDVariable outputPeepholeWeights; + + /** + * Input to hidden and hidden to hidden biases, with shape [1, 4*numUnits]. + */ + @NonNull + private SDVariable bias; + + @Override + public SDVariable[] args() { + return filterNonNull(weights, inputPeepholeWeights, forgetPeepholeWeights, outputPeepholeWeights, bias); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java new file mode 100644 index 000000000..62e295d80 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java @@ -0,0 +1,35 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights; + +import java.util.Arrays; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.linalg.util.ArrayUtil; + +public abstract class RNNWeights { + public abstract SDVariable[] args(); + + protected static SDVariable[] filterNonNull(SDVariable... args){ + int count = 0; + for(SDVariable v : args){ + if(v != null){ + count++; + } + } + + SDVariable[] res = new SDVariable[count]; + + int i = 0; + + for(SDVariable v : args){ + if(v != null){ + res[i] = v; + i++; + } + } + + return res; + } + + public SDVariable[] argsWithInputs(SDVariable... 
inputs){ + return ArrayUtil.combine(inputs, args()); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/SRUWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/SRUWeights.java new file mode 100644 index 000000000..821895f17 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/SRUWeights.java @@ -0,0 +1,37 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell; + +/** + * The weight configuration of a SRU layer. For {@link SRU} and {@link SRUCell}. + * + */ +@EqualsAndHashCode(callSuper = true) +@Data +@Builder +public class SRUWeights extends RNNWeights { + + /** + * Weights, with shape [inSize, 3*inSize]. + */ + @NonNull + private SDVariable weights; + + /** + * Biases, with shape [2*inSize]. 
+ */ + @NonNull + private SDVariable bias; + + @Override + public SDVariable[] args() { + return new SDVariable[]{weights, bias}; + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java index 2aae9eda1..8ecdc4eac 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java @@ -16,14 +16,19 @@ package org.nd4j.autodiff.opvalidation; +import java.util.Arrays; import lombok.extern.slf4j.Slf4j; import org.junit.Test; +import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -59,23 +64,18 @@ public class RnnOpValidation extends BaseOpValidation { SDVariable b = sd.constant(Nd4j.rand(DataType.FLOAT, 4*nOut)); double fb = 1.0; - LSTMBlockCellConfiguration conf = LSTMBlockCellConfiguration.builder() - .xt(x) - .cLast(cLast) - .yLast(yLast) - .W(W) - .Wci(Wci) - .Wcf(Wcf) - .Wco(Wco) - .b(b) + LSTMConfiguration conf = LSTMConfiguration.builder() .peepHole(true) .forgetBias(fb) .clippingCellValue(0.0) .build(); - 
List v = sd.rnn().lstmBlockCell("lstm", conf); //Output order: i, c, f, o, z, h, y + LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b) + .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build(); + + LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y List toExec = new ArrayList<>(); - for(SDVariable sdv : v){ + for(SDVariable sdv : v.getAllOutputs()){ toExec.add(sdv.getVarName()); } @@ -167,23 +167,18 @@ public class RnnOpValidation extends BaseOpValidation { SDVariable b = sd.constant(Nd4j.zeros(DataType.FLOAT, 8)); double fb = 1.0; - LSTMBlockCellConfiguration conf = LSTMBlockCellConfiguration.builder() - .xt(x) - .cLast(cLast) - .yLast(yLast) - .W(W) - .Wci(Wci) - .Wcf(Wcf) - .Wco(Wco) - .b(b) + LSTMConfiguration conf = LSTMConfiguration.builder() .peepHole(false) .forgetBias(fb) .clippingCellValue(0.0) .build(); - List v = sd.rnn().lstmBlockCell("lstm", conf); //Output order: i, c, f, o, z, h, y + LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b) + .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build(); + + LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y List toExec = new ArrayList<>(); - for(SDVariable sdv : v){ + for(SDVariable sdv : v.getAllOutputs()){ toExec.add(sdv.getVarName()); } @@ -228,16 +223,14 @@ public class RnnOpValidation extends BaseOpValidation { SDVariable bc = sd.constant(Nd4j.rand(DataType.FLOAT, nOut)); double fb = 1.0; - GRUCellConfiguration conf = GRUCellConfiguration.builder() - .xt(x) - .hLast(hLast) - .Wru(Wru) - .Wc(Wc) - .bru(bru) - .bc(bc) + GRUWeights weights = GRUWeights.builder() + .ruWeight(Wru) + .cWeight(Wc) + .ruBias(bru) + .cBias(bc) .build(); - List v = sd.rnn().gru("gru", conf); + List v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs(); List toExec = new ArrayList<>(); for(SDVariable sdv : v){ 
toExec.add(sdv.getVarName()); diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/Pair.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/Pair.java index 8f5ca3888..31976080d 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/Pair.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/Pair.java @@ -23,6 +23,7 @@ import lombok.NoArgsConstructor; import java.io.Serializable; import java.util.Arrays; +import org.nd4j.base.Preconditions; /** * Simple pair implementation @@ -86,4 +87,10 @@ public class Pair implements Serializable { public static Pair pairOf(T key, E value) { return new Pair(key, value); } + + public static Pair fromArray(T[] arr){ + Preconditions.checkArgument(arr.length == 2, + "Can only create a pair from an array with two values, got %s", arr.length); + return new Pair<>(arr[0], arr[1]); + } }