diff --git a/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp b/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp
index 2961e3bcf..23b2ec172 100644
--- a/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp
+++ b/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp
@@ -34,7 +34,7 @@ CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) {
auto xt = INPUT_VARIABLE(0); // input [bS x inSize], bS - batch size, inSize - number of features
auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that is at previous time step t-1
auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize]
- auto b = INPUT_VARIABLE(3); // biases [1 x 2*inSize]
+ auto b = INPUT_VARIABLE(3); // biases [2*inSize]
auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x inSize], that is at current time step t
auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x inSize], that is at current time step t
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
index 0b5a4c03f..1821a30a0 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
@@ -6511,4 +6511,22 @@ public class SameDiff extends SDBaseOps {
public String generateNewVarName(String base, int argIndex) {
return generateNewVarName(base, argIndex, true);
}
+
+ /**
+ * Returns an unused variable name of the format <base>_#.
+ *
+ * Intended to be used for custom variables (like weights), arguments and op outputs should use {@link #generateNewVarName(String, int)}.
+ */
+ public String generateDistinctCustomVariableName(String base){
+ if(!variables.containsKey(base))
+ return base;
+
+ int inc = 1;
+
+ while(variables.containsKey(base + "_" + inc)){
+ inc++;
+ }
+
+ return base + "_" + inc;
+ }
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java
index f47f32b87..de0114b92 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java
@@ -16,6 +16,7 @@
package org.nd4j.autodiff.samediff.ops;
+import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.*;
@@ -23,6 +24,15 @@ import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*;
import java.util.Arrays;
import java.util.List;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRUCellOutputs;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRULayerOutputs;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
+import org.nd4j.linalg.primitives.Pair;
/**
* SameDiff Recurrent Neural Network operations
@@ -39,90 +49,163 @@ public class SDRNN extends SDOps {
/**
- * The gru cell
- *
- * @param configuration the configuration to use
- * @return
+ * See {@link #gru(String, SDVariable, SDVariable, GRUWeights)}.
*/
- public List gru(GRUCellConfiguration configuration) {
- GRUCell c = new GRUCell(sd, configuration);
- return Arrays.asList(c.outputVariables());
+ public GRUCellOutputs gru(@NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
+ GRUCell c = new GRUCell(sd, x, hLast, weights);
+ return new GRUCellOutputs(c.outputVariables());
}
/**
- * The gru cell
+ * The GRU cell. Does a single time step operation.
*
- * @param baseName the base name for the gru cell
- * @param configuration the configuration to use
- * @return
+ * @param baseName The base name for the gru cell
+ * @param x Input, with shape [batchSize, inSize]
+ * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits]
+ * @param weights The cell's weights.
+ * @return The cell's outputs.
*/
- public List gru(String baseName, GRUCellConfiguration configuration) {
- GRUCell c = new GRUCell(sd, configuration);
- return Arrays.asList(c.outputVariables(baseName));
- }
-
-
- /**
- * LSTM unit
- *
- * @param baseName the base name for outputs
- * @param configuration the configuration to use
- * @return
- */
- public SDVariable lstmCell(String baseName, LSTMCellConfiguration configuration) {
- return new LSTMCell(sd, configuration).outputVariables(baseName)[0];
- }
-
- public List lstmBlockCell(String name, LSTMBlockCellConfiguration configuration){
- SDVariable[] v = new LSTMBlockCell(sd, configuration).outputVariables(name);
- return Arrays.asList(v);
- }
-
- public List lstmLayer(String name, LSTMConfiguration configuration){
- SDVariable[] v = new LSTMLayer(sd, configuration).outputVariables(name);
- return Arrays.asList(v);
+ public GRUCellOutputs gru(String baseName, @NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
+ GRUCell c = new GRUCell(sd, x, hLast, weights);
+ return new GRUCellOutputs(c.outputVariables(baseName));
}
/**
- * Simple recurrent unit
- *
- * @param configuration the configuration for the sru
- * @return
+ * See {@link #lstmCell(String, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}.
*/
- public SDVariable sru(SRUConfiguration configuration) {
- return new SRU(sd, configuration).outputVariables()[0];
+ public LSTMCellOutputs lstmCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
+ LSTMWeights weights, LSTMConfiguration config){
+ LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
+ return new LSTMCellOutputs(c.outputVariables());
}
/**
- * Simiple recurrent unit
+ * The LSTM cell. Does a single time step operation.
*
- * @param baseName the base name to use for output variables
- * @param configuration the configuration for the sru
- * @return
+ * @param baseName The base name for the lstm cell
+ * @param x Input, with shape [batchSize, inSize]
+ * @param cLast Previous cell state, with shape [batchSize, numUnits]
+ * @param yLast Previous cell output, with shape [batchSize, numUnits]
+ * @param weights The cell's weights.
+ * @param config The cell's config.
+ * @return The cell's outputs.
*/
- public SDVariable sru(String baseName, SRUConfiguration configuration) {
- return new SRU(sd, configuration).outputVariables(baseName)[0];
+ public LSTMCellOutputs lstmCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
+ @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
+ LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
+ return new LSTMCellOutputs(c.outputVariables(baseName));
}
/**
- * An sru cell
- *
- * @param configuration the configuration for the sru cell
- * @return
+ * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
*/
- public SDVariable sruCell(SRUCellConfiguration configuration) {
- return new SRUCell(sd, configuration).outputVariables()[0];
+ public LSTMLayerOutputs lstmLayer(@NonNull SDVariable maxTSLength,
+ @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
+ @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
+ LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
+ return new LSTMLayerOutputs(c.outputVariables(), config.getDataFormat());
}
/**
- * An sru cell
- *
- * @param baseName the base name to use for the output variables
- * @param configuration the configuration for the sru cell
- * @return
+ * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
*/
- public SDVariable sruCell(String baseName, SRUCellConfiguration configuration) {
- return new SRUCell(sd, configuration).outputVariables(baseName)[0];
+ public LSTMLayerOutputs lstmLayer(int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
+ @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
+ return lstmLayer(
+ sd.scalar("lstm_max_ts_length", maxTSLength),
+ x, cLast, yLast, weights, config);
+ }
+
+ /**
+ * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
+ */
+ public LSTMLayerOutputs lstmLayer(String baseName, int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
+ @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
+ if(baseName != null) {
+ return lstmLayer(baseName,
+ sd.scalar(sd.generateDistinctCustomVariableName(baseName + "_max_ts_length"), maxTSLength),
+ x, cLast, yLast, weights, config);
+ } else {
+ return lstmLayer(maxTSLength, x, cLast, yLast, weights, config);
+ }
+ }
+
+ /**
+ * The LSTM layer. Does multiple time steps.
+ *
+ * Input shape depends on data format (in config):
+ * TNS -> [timeSteps, batchSize, inSize]
+ * NST -> [batchSize, inSize, timeSteps]
+ * NTS -> [batchSize, timeSteps, inSize]
+ *
+ * @param baseName The base name for the lstm layer
+ * @param x Input, with shape dependent on the data format (in config).
+ * @param cLast Previous/initial cell state, with shape [batchSize, numUnits]
+ * @param yLast Previous/initial cell output, with shape [batchSize, numUnits]
+ * @param weights The layer's weights.
+ * @param config The layer's config.
+ * @return The layer's outputs.
+ */
+ public LSTMLayerOutputs lstmLayer(String baseName, @NonNull SDVariable maxTSLength,
+ @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
+ @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
+ LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
+ return new LSTMLayerOutputs(c.outputVariables(baseName), config.getDataFormat());
+ }
+
+ /**
+ * See {@link #sruCell(String, SDVariable, SDVariable, SRUWeights)}.
+ */
+ public SRUCellOutputs sruCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
+ return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables());
+ }
+
+ /**
+ * The SRU cell. Does a single time step operation.
+ *
+ * @param baseName The base name for the sru cell
+ * @param x Input, with shape [batchSize, inSize]
+ * @param cLast Previous cell state, with shape [batchSize, inSize]
+ * @param weights The cell's weights.
+ * @return The cell's outputs.
+ */
+ public SRUCellOutputs sruCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
+ return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables(baseName));
+ }
+
+ /**
+ * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
+ */
+ public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
+ return sru(x, initialC, null, weights);
+ }
+
+ /**
+ * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
+ */
+ public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
+ return sru(baseName, x, initialC, null, weights);
+ }
+
+ /**
+ * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
+ */
+ public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
+ return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables());
+ }
+
+ /**
+ * The SRU layer. Does a single time step operation.
+ *
+ * @param baseName The base name for the sru layer
+ * @param x Input, with shape [batchSize, inSize, timeSeriesLength]
+ * @param initialC Initial cell state, with shape [batchSize, inSize]
+ * @param mask An optional dropout mask, with shape [batchSize, inSize]
+ * @param weights The layer's weights.
+ * @return The layer's outputs.
+ */
+ public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
+ return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables(baseName));
}
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java
index 6c7daca69..2fa99ace5 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java
@@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
+import lombok.Getter;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
@@ -39,14 +41,15 @@ import java.util.Map;
*/
public class GRUCell extends DynamicCustomOp {
- private GRUCellConfiguration configuration;
+ @Getter
+ private GRUWeights weights;
public GRUCell() {
}
- public GRUCell(SameDiff sameDiff, GRUCellConfiguration configuration) {
- super(null, sameDiff, configuration.args());
- this.configuration = configuration;
+ public GRUCell(SameDiff sameDiff, SDVariable x, SDVariable hLast, GRUWeights weights) {
+ super(null, sameDiff, weights.argsWithInputs(x, hLast));
+ this.weights = weights;
}
@Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java
index 36512f610..f88625987 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java
@@ -16,12 +16,15 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
+import lombok.Getter;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
+import org.nd4j.linalg.primitives.Pair;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
@@ -49,10 +52,12 @@ import java.util.Map;
* 6: weights - cell peephole (t) connections to output gate, [numUnits]
* 7: biases, shape [4*numUnits]
*
- * Input integer arguments: set via {@link LSTMBlockCellConfiguration}
+ * Weights are set via {@link LSTMWeights}.
+ *
+ * Input integer arguments: set via {@link LSTMConfiguration}
* 0: if not zero, provide peephole connections
*
- * Input float arguments: set via {@link LSTMBlockCellConfiguration}
+ * Input float arguments: set via {@link LSTMConfiguration}
* 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training
* 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped
*
@@ -69,15 +74,19 @@ import java.util.Map;
*/
public class LSTMBlockCell extends DynamicCustomOp {
- private LSTMBlockCellConfiguration configuration;
+ private LSTMConfiguration configuration;
+
+ @Getter
+ private LSTMWeights weights;
public LSTMBlockCell() {
}
- public LSTMBlockCell(SameDiff sameDiff, LSTMBlockCellConfiguration configuration) {
- super(null, sameDiff, configuration.args());
+ public LSTMBlockCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) {
+ super(null, sameDiff, weights.argsWithInputs(x, cLast, yLast));
this.configuration = configuration;
- addIArgument(configuration.iArgs());
+ this.weights = weights;
+ addIArgument(configuration.iArgs(false));
addTArgument(configuration.tArgs());
}
@@ -97,12 +106,12 @@ public class LSTMBlockCell extends DynamicCustomOp {
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) {
- configuration = LSTMBlockCellConfiguration.builder()
+ configuration = LSTMConfiguration.builder()
.forgetBias(attributesForNode.get("forget_bias").getF())
.clippingCellValue(attributesForNode.get("cell_clip").getF())
.peepHole(attributesForNode.get("use_peephole").getB())
.build();
- addIArgument(configuration.iArgs());
+ addIArgument(configuration.iArgs(false));
addTArgument(configuration.tArgs());
}
@@ -113,7 +122,7 @@ public class LSTMBlockCell extends DynamicCustomOp {
@Override
public Map propertiesForFunction() {
- return configuration.toProperties();
+ return configuration.toProperties(false);
}
@Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java
index 527c0c3ca..1e1ae3c47 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java
@@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
+import lombok.Getter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@@ -24,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
@@ -75,13 +77,17 @@ public class LSTMLayer extends DynamicCustomOp {
private LSTMConfiguration configuration;
+ @Getter
+ private LSTMWeights weights;
+
public LSTMLayer() {
}
- public LSTMLayer(@NonNull SameDiff sameDiff, @NonNull LSTMConfiguration configuration) {
- super(null, sameDiff, configuration.args());
+ public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) {
+ super(null, sameDiff, weights.argsWithInputs(maxTSLength, x, cLast, yLast));
this.configuration = configuration;
- addIArgument(configuration.iArgs());
+ this.weights = weights;
+ addIArgument(configuration.iArgs(true));
addTArgument(configuration.tArgs());
}
@@ -107,7 +113,7 @@ public class LSTMLayer extends DynamicCustomOp {
.peepHole(attributesForNode.get("use_peephole").getB())
.dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM
.build();
- addIArgument(configuration.iArgs());
+ addIArgument(configuration.iArgs(true));
addTArgument(configuration.tArgs());
}
@@ -118,7 +124,7 @@ public class LSTMLayer extends DynamicCustomOp {
@Override
public Map propertiesForFunction() {
- return configuration.toProperties();
+ return configuration.toProperties(true);
}
@Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java
index b916d4961..a2de2beb8 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java
@@ -16,11 +16,16 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
+import java.util.Arrays;
+import java.util.List;
+import lombok.Getter;
+import lombok.NonNull;
import onnx.Onnx;
+import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUConfiguration;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
@@ -34,13 +39,18 @@ import java.util.Map;
*/
public class SRU extends DynamicCustomOp {
- private SRUConfiguration configuration;
+ @Getter
+ private SRUWeights weights;
+
+ @Getter
+ private SDVariable mask;
public SRU() { }
- public SRU(SameDiff sameDiff, SRUConfiguration configuration) {
- super(null, sameDiff, configuration.args());
- this.configuration = configuration;
+ public SRU(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
+ super(null, sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getBias(), initialC, mask));
+ this.mask = mask;
+ this.weights = weights;
}
@Override
@@ -68,6 +78,4 @@ public class SRU extends DynamicCustomOp {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}
-
-
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java
index 4880b90fe..ac3f6c07f 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java
@@ -16,17 +16,18 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
+import java.util.Map;
+import lombok.Getter;
import onnx.Onnx;
+import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUCellConfiguration;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
-import java.util.Map;
-
/**
* A simple recurrent unit cell.
*
@@ -34,14 +35,15 @@ import java.util.Map;
*/
public class SRUCell extends DynamicCustomOp {
- private SRUCellConfiguration configuration;
+ @Getter
+ private SRUWeights weights;
public SRUCell() {
}
- public SRUCell(SameDiff sameDiff, SRUCellConfiguration configuration) {
- super(null, sameDiff, configuration.args());
- this.configuration = configuration;
+ public SRUCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SRUWeights weights) {
+ super(null, sameDiff, weights.argsWithInputs(x, cLast));
+ this.weights = weights;
}
@Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMBlockCellConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMBlockCellConfiguration.java
deleted file mode 100644
index 3e2591062..000000000
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMBlockCellConfiguration.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
-
-import lombok.Builder;
-import lombok.Data;
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.linalg.util.ArrayUtil;
-
-import java.util.LinkedHashMap;
-import java.util.Map;
-
-@Builder
-@Data
-public class LSTMBlockCellConfiguration {
-
- private boolean peepHole; //IArg(0)
- private double forgetBias; //TArg(0)
- private double clippingCellValue; //TArg(1)
-
- private SDVariable xt, cLast, yLast, W, Wci, Wcf, Wco, b;
-
- public Map toProperties() {
- Map ret = new LinkedHashMap<>();
- ret.put("peepHole",peepHole);
- ret.put("clippingCellValue",clippingCellValue);
- ret.put("forgetBias",forgetBias);
- return ret;
- }
-
- public SDVariable[] args() {
- return new SDVariable[] {xt,cLast, yLast, W, Wci, Wcf, Wco, b};
- }
-
-
- public int[] iArgs() {
- return new int[] {ArrayUtil.fromBoolean(peepHole)};
- }
-
- public double[] tArgs() {
- return new double[] {forgetBias,clippingCellValue};
- }
-}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMConfiguration.java
index 98dc58876..4cf807765 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMConfiguration.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMConfiguration.java
@@ -19,13 +19,15 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
import lombok.Builder;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.LinkedHashMap;
import java.util.Map;
/**
- * LSTM Configuration - for {@link org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer}
+ * LSTM Configuration - for {@link LSTMLayer} and {@link LSTMBlockCell}
*
* @author Alex Black
*/
@@ -33,29 +35,41 @@ import java.util.Map;
@Data
public class LSTMConfiguration {
+ /**
+ * Whether to provide peephole connections.
+ */
private boolean peepHole; //IArg(0)
- @Builder.Default private RnnDataFormat dataFormat = RnnDataFormat.TNS; //IArg(1)
+
+ /**
+ * The data format of the input. Only used in {@link LSTMLayer}, ignored in {@link LSTMBlockCell}.
+ */
+ @Builder.Default private RnnDataFormat dataFormat = RnnDataFormat.TNS; //IArg(1) (only for lstmBlock, not lstmBlockCell)
+
+ /**
+ * The bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training.
+ */
private double forgetBias; //TArg(0)
+
+ /**
+ * Clipping value for cell state, if it is not equal to zero, then cell state is clipped.
+ */
private double clippingCellValue; //TArg(1)
- private SDVariable xt, cLast, yLast, W, Wci, Wcf, Wco, b;
-
- public Map toProperties() {
+ public Map toProperties(boolean includeDataFormat) {
Map ret = new LinkedHashMap<>();
ret.put("peepHole",peepHole);
ret.put("clippingCellValue",clippingCellValue);
ret.put("forgetBias",forgetBias);
- ret.put("dataFormat", dataFormat);
+ if(includeDataFormat)
+ ret.put("dataFormat", dataFormat);
return ret;
}
- public SDVariable[] args() {
- return new SDVariable[] {xt,cLast, yLast, W, Wci, Wcf, Wco, b};
- }
-
- public int[] iArgs() {
- return new int[] {ArrayUtil.fromBoolean(peepHole), dataFormat.ordinal()};
+ public int[] iArgs(boolean includeDataFormat) {
+ if(includeDataFormat) {
+ return new int[]{ArrayUtil.fromBoolean(peepHole), dataFormat.ordinal()};
+ } else return new int[]{ArrayUtil.fromBoolean(peepHole)};
}
public double[] tArgs() {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUCellConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUCellConfiguration.java
deleted file mode 100644
index 4b0a39a80..000000000
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUCellConfiguration.java
+++ /dev/null
@@ -1,44 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
-
-import lombok.Builder;
-import lombok.Data;
-import org.nd4j.autodiff.samediff.SDVariable;
-
-@Data
-@Builder
-public class SRUCellConfiguration {
- /**
- *
- NDArray* xt = INPUT_VARIABLE(0); // input [batchSize x inSize], batchSize - batch size, inSize - number of features
- NDArray* ct_1 = INPUT_VARIABLE(1); // previous cell state ct [batchSize x inSize], that is at previous time step t-1
- NDArray* w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize]
- NDArray* b = INPUT_VARIABLE(3); // biases [1 x 2*inSize]
-
- NDArray* ht = OUTPUT_VARIABLE(0); // current cell output [batchSize x inSize], that is at current time step t
- NDArray* ct = OUTPUT_VARIABLE(1); // current cell state [batchSize x inSize], that is at current time step t
-
- */
- private SDVariable xt,ct_1,w,b,h1,ct;
-
-
- public SDVariable[] args() {
- return new SDVariable[] {xt,ct_1,w,b,h1,ct};
- }
-
-}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUConfiguration.java
deleted file mode 100644
index 8bfa90330..000000000
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUConfiguration.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
-
-import lombok.Builder;
-import lombok.Data;
-import org.nd4j.autodiff.samediff.SDVariable;
-
-@Data
-@Builder
-public class SRUConfiguration {
- /**
- * NDArray* input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
- NDArray* weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K]
- NDArray* bias = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K]
- NDArray* init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0
-
- */
- private SDVariable inputs,weights,bias,init;
-
- public SDVariable[] args() {
- return new SDVariable[] {inputs,weights,bias,init};
- }
-}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java
new file mode 100644
index 000000000..a39a5bcc7
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java
@@ -0,0 +1,62 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
+
+import java.util.Arrays;
+import java.util.List;
+import lombok.Getter;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
+
+/**
+ * The outputs of a GRU cell ({@link GRUCell}.
+ */
+@Getter
+public class GRUCellOutputs {
+
+ /**
+ * Reset gate output [batchSize, numUnits].
+ */
+ private SDVariable r;
+
+ /**
+ * Update gate output [batchSize, numUnits].
+ */
+ private SDVariable u;
+
+ /**
+ * Cell gate output [batchSize, numUnits].
+ */
+ private SDVariable c;
+
+ /**
+ * Current cell output [batchSize, numUnits].
+ */
+ private SDVariable h;
+
+ public GRUCellOutputs(SDVariable[] outputs){
+ Preconditions.checkArgument(outputs.length == 4,
+ "Must have 4 GRU cell outputs, got %s", outputs.length);
+
+ r = outputs[0];
+ u = outputs[1];
+ c = outputs[2];
+ h = outputs[3];
+ }
+
+ /**
+ * Get all outputs returned by the cell.
+ */
+ public List getAllOutputs(){
+ return Arrays.asList(r, u, c, h);
+ }
+
+ /**
+ * Get h, the output of the cell.
+ *
+ * Has shape [batchSize, numUnits].
+ */
+ public SDVariable getOutput(){
+ return h;
+ }
+
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java
new file mode 100644
index 000000000..4fec87e8b
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java
@@ -0,0 +1,88 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
+
+import java.util.Arrays;
+import java.util.List;
+import lombok.Getter;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
+
+/**
+ * The outputs of a LSTM cell ({@link LSTMBlockCell}.
+ */
+@Getter
+public class LSTMCellOutputs {
+
+ /**
+ * Output - input modulation gate activations [batchSize, numUnits].
+ */
+ private SDVariable i;
+
+ /**
+ * Activations, cell state (pre tanh) [batchSize, numUnits].
+ */
+ private SDVariable c;
+
+ /**
+ * Output - forget gate activations [batchSize, numUnits].
+ */
+ private SDVariable f;
+
+ /**
+ * Output - output gate activations [batchSize, numUnits].
+ */
+ private SDVariable o;
+
+ /**
+ * Output - input gate activations [batchSize, numUnits].
+ */
+ private SDVariable z;
+
+ /**
+ * Cell state, post tanh [batchSize, numUnits].
+ */
+ private SDVariable h;
+
+ /**
+ * Current cell output [batchSize, numUnits].
+ */
+ private SDVariable y;
+
+ public LSTMCellOutputs(SDVariable[] outputs){
+ Preconditions.checkArgument(outputs.length == 7,
+ "Must have 7 LSTM cell outputs, got %s", outputs.length);
+
+ i = outputs[0];
+ c = outputs[1];
+ f = outputs[2];
+ o = outputs[3];
+ z = outputs[4];
+ h = outputs[5];
+ y = outputs[6];
+ }
+
+ /**
+ * Get all outputs returned by the cell.
+ */
+ public List getAllOutputs(){
+ return Arrays.asList(i, c, f, o, z, h, y);
+ }
+
+ /**
+ * Get y, the output of the cell.
+ *
+ * Has shape [batchSize, numUnits].
+ */
+ public SDVariable getOutput(){
+ return y;
+ }
+
+ /**
+ * Get c, the cell's state.
+ *
+ * Has shape [batchSize, numUnits].
+ */
+ public SDVariable getState(){
+ return c;
+ }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java
new file mode 100644
index 000000000..a01be219f
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java
@@ -0,0 +1,180 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
+
+import java.util.Arrays;
+import java.util.List;
+import lombok.AccessLevel;
+import lombok.Getter;
+import org.nd4j.autodiff.samediff.SDIndex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
+
+/**
+ * The outputs of a LSTM layer ({@link LSTMLayer}.
+ */
+@Getter
+public class LSTMLayerOutputs {
+
+ private RnnDataFormat dataFormat;
+
+ /**
+ * Output - input modulation gate activations.
+ * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */
+ private SDVariable i;
+
+ /**
+ * Activations, cell state (pre tanh).
+ * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */
+ private SDVariable c;
+
+ /**
+ * Output - forget gate activations.
+ * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */
+ private SDVariable f;
+
+ /**
+ * Output - output gate activations.
+ * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */
+ private SDVariable o;
+
+ /**
+ * Output - input gate activations.
+ * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */
+ private SDVariable z;
+
+ /**
+ * Cell state, post tanh.
+ * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */
+ private SDVariable h;
+
+ /**
+ * Current cell output.
+ * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */
+ private SDVariable y;
+
+ public LSTMLayerOutputs(SDVariable[] outputs, RnnDataFormat dataFormat){
+ Preconditions.checkArgument(outputs.length == 7,
+ "Must have 7 LSTM layer outputs, got %s", outputs.length);
+
+ i = outputs[0];
+ c = outputs[1];
+ f = outputs[2];
+ o = outputs[3];
+ z = outputs[4];
+ h = outputs[5];
+ y = outputs[6];
+ this.dataFormat = dataFormat;
+ }
+
+ /**
+ * Get all outputs returned by the cell.
+ */
+ public List getAllOutputs(){
+ return Arrays.asList(i, c, f, o, z, h, y);
+ }
+
+ /**
+ * Get y, the output of the cell for all time steps.
+ *
+ * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */
+ public SDVariable getOutput(){
+ return y;
+ }
+
+ /**
+ * Get c, the cell's state for all time steps.
+ *
+ * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */
+ public SDVariable getState(){
+ return c;
+ }
+
+ private SDVariable lastOutput = null;
+
+ /**
+ * Get y, the output of the cell, for the last time step.
+ *
+ * Has shape [batchSize, numUnits].
+ */
+ public SDVariable getLastOutput(){
+ if(lastOutput != null)
+ return lastOutput;
+
+ switch (dataFormat){
+ case TNS:
+ lastOutput = getOutput().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
+ break;
+ case NST:
+ lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
+ break;
+ case NTS:
+ lastOutput = getOutput().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
+ break;
+ }
+ return lastOutput;
+ }
+
+ private SDVariable lastState = null;
+
+ /**
+ * Get c, the state of the cell, for the last time step.
+ *
+ * Has shape [batchSize, numUnits].
+ */
+ public SDVariable getLastState(){
+ if(lastState != null)
+ return lastState;
+
+ switch (dataFormat){
+ case TNS:
+ lastState = getState().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
+ break;
+ case NST:
+ lastState = getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
+ break;
+ case NTS:
+ lastState = getState().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
+ break;
+ }
+ return lastState;
+ }
+
+
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java
new file mode 100644
index 000000000..d82ad63b1
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java
@@ -0,0 +1,60 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
+
+import java.util.Arrays;
+import java.util.List;
+import lombok.Getter;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
+
+/**
+ * The outputs of a GRU cell ({@link GRUCell}.
+ */
+@Getter
+public class SRUCellOutputs {
+
+
+ /**
+ * Current cell output [batchSize, numUnits].
+ */
+ private SDVariable h;
+
+ /**
+ * Current cell state [batchSize, numUnits].
+ */
+ private SDVariable c;
+
+ public SRUCellOutputs(SDVariable[] outputs){
+ Preconditions.checkArgument(outputs.length == 2,
+ "Must have 2 SRU cell outputs, got %s", outputs.length);
+
+ h = outputs[0];
+ c = outputs[1];
+ }
+
+ /**
+ * Get all outputs returned by the cell.
+ */
+ public List getAllOutputs(){
+ return Arrays.asList(h, c);
+ }
+
+ /**
+ * Get h, the output of the cell.
+ *
+ * Has shape [batchSize, inSize].
+ */
+ public SDVariable getOutput(){
+ return h;
+ }
+
+ /**
+ * Get c, the state of the cell.
+ *
+ * Has shape [batchSize, inSize].
+ */
+ public SDVariable getState(){
+ return c;
+ }
+
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java
new file mode 100644
index 000000000..281d2cc10
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java
@@ -0,0 +1,92 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
+
+import java.util.Arrays;
+import java.util.List;
+import lombok.AccessLevel;
+import lombok.Getter;
+import org.nd4j.autodiff.samediff.SDIndex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
+
+/**
+ * The outputs of a GRU cell ({@link GRUCell}.
+ */
+@Getter
+public class SRULayerOutputs {
+
+
+ /**
+ * Current cell output [batchSize, inSize, timeSeriesLength].
+ */
+ private SDVariable h;
+
+ /**
+ * Current cell state [batchSize, inSize, timeSeriesLength].
+ */
+ private SDVariable c;
+
+ public SRULayerOutputs(SDVariable[] outputs){
+ Preconditions.checkArgument(outputs.length == 2,
+ "Must have 2 SRU cell outputs, got %s", outputs.length);
+
+ h = outputs[0];
+ c = outputs[1];
+ }
+
+ /**
+ * Get all outputs returned by the cell.
+ */
+ public List getAllOutputs(){
+ return Arrays.asList(h, c);
+ }
+
+ /**
+ * Get h, the output of the cell.
+ *
+ * Has shape [batchSize, inSize, timeSeriesLength].
+ */
+ public SDVariable getOutput(){
+ return h;
+ }
+
+ /**
+ * Get c, the state of the cell.
+ *
+ * Has shape [batchSize, inSize, timeSeriesLength].
+ */
+ public SDVariable getState(){
+ return c;
+ }
+
+ private SDVariable lastOutput = null;
+
+ /**
+ * Get y, the output of the cell, for the last time step.
+ *
+ * Has shape [batchSize, inSize].
+ */
+ public SDVariable getLastOutput(){
+ if(lastOutput != null)
+ return lastOutput;
+
+ lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
+ return lastOutput;
+ }
+
+ private SDVariable lastState = null;
+
+ /**
+ * Get c, the state of the cell, for the last time step.
+ *
+ * Has shape [batchSize, inSize].
+ */
+ public SDVariable getLastState(){
+ if(lastState != null)
+ return lastState;
+
+ lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
+ return lastState;
+ }
+
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java
new file mode 100644
index 000000000..f95438ae3
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java
@@ -0,0 +1,51 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
+
+/**
+ * The weight configuration of a GRU cell. For {@link GRUCell}.
+ *
+ */
+@EqualsAndHashCode(callSuper = true)
+@Data
+@Builder
+public class GRUWeights extends RNNWeights {
+
+ /**
+ * Reset and Update gate weights, with a shape of [inSize + numUnits, 2*numUnits].
+ *
+ * The reset weights are the [:, 0:numUnits] subset and the update weights are the [:, numUnits:2*numUnits] subset.
+ */
+ @NonNull
+ private SDVariable ruWeight;
+
+ /**
+ * Cell gate weights, with a shape of [inSize + numUnits, numUnits]
+ */
+ @NonNull
+ private SDVariable cWeight;
+
+ /**
+ * Reset and Update gate bias, with a shape of [2*numUnits]. May be null.
+ *
+ * The reset bias is the [0:numUnits] subset and the update bias is the [numUnits:2*numUnits] subset.
+ */
+ @NonNull
+ private SDVariable ruBias;
+
+ /**
+ * Cell gate bias, with a shape of [numUnits]. May be null.
+ */
+ @NonNull
+ private SDVariable cBias;
+
+ @Override
+ public SDVariable[] args() {
+ return filterNonNull(ruWeight, cWeight, ruBias, cBias);
+ }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java
new file mode 100644
index 000000000..bf401d66c
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java
@@ -0,0 +1,57 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
+
+/**
+ * The weight configuration of a LSTM layer. For {@link LSTMLayer} and {@link LSTMBlockCell}.
+ *
+ */
+@EqualsAndHashCode(callSuper = true)
+@Data
+@Builder
+public class LSTMWeights extends RNNWeights {
+
+ /**
+ * Input to hidden weights and hidden to hidden weights, with a shape of [inSize + numUnits, 4*numUnits].
+ *
+ * Input to hidden and hidden to hidden are concatenated in dimension 0,
+ * so the input to hidden weights are [:inSize, :] and the hidden to hidden weights are [inSize:, :].
+ */
+ @NonNull
+ private SDVariable weights;
+
+ /**
+ * Cell peephole (t-1) connections to input modulation gate, with a shape of [numUnits].
+ */
+ @NonNull
+ private SDVariable inputPeepholeWeights;
+
+ /**
+ * Cell peephole (t-1) connections to forget gate, with a shape of [numUnits].
+ */
+ @NonNull
+ private SDVariable forgetPeepholeWeights;
+
+ /**
+ * Cell peephole (t) connections to output gate, with a shape of [numUnits].
+ */
+ @NonNull
+ private SDVariable outputPeepholeWeights;
+
+ /**
+ * Input to hidden and hidden to hidden biases, with shape [1, 4*numUnits].
+ */
+ @NonNull
+ private SDVariable bias;
+
+ @Override
+ public SDVariable[] args() {
+ return filterNonNull(weights, inputPeepholeWeights, forgetPeepholeWeights, outputPeepholeWeights, bias);
+ }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java
new file mode 100644
index 000000000..62e295d80
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java
@@ -0,0 +1,35 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
+
+import java.util.Arrays;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.linalg.util.ArrayUtil;
+
+public abstract class RNNWeights {
+ public abstract SDVariable[] args();
+
+ protected static SDVariable[] filterNonNull(SDVariable... args){
+ int count = 0;
+ for(SDVariable v : args){
+ if(v != null){
+ count++;
+ }
+ }
+
+ SDVariable[] res = new SDVariable[count];
+
+ int i = 0;
+
+ for(SDVariable v : args){
+ if(v != null){
+ res[i] = v;
+ i++;
+ }
+ }
+
+ return res;
+ }
+
+ public SDVariable[] argsWithInputs(SDVariable... inputs){
+ return ArrayUtil.combine(inputs, args());
+ }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/SRUWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/SRUWeights.java
new file mode 100644
index 000000000..821895f17
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/SRUWeights.java
@@ -0,0 +1,37 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
+
+/**
+ * The weight configuration of a SRU layer. For {@link SRU} and {@link SRUCell}.
+ *
+ */
+@EqualsAndHashCode(callSuper = true)
+@Data
+@Builder
+public class SRUWeights extends RNNWeights {
+
+ /**
+ * Weights, with shape [inSize, 3*inSize].
+ */
+ @NonNull
+ private SDVariable weights;
+
+ /**
+ * Biases, with shape [2*inSize].
+ */
+ @NonNull
+ private SDVariable bias;
+
+ @Override
+ public SDVariable[] args() {
+ return new SDVariable[]{weights, bias};
+ }
+}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java
index 2aae9eda1..8ecdc4eac 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java
@@ -16,14 +16,19 @@
package org.nd4j.autodiff.opvalidation;
+import java.util.Arrays;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
+import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
-import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
@@ -59,23 +64,18 @@ public class RnnOpValidation extends BaseOpValidation {
SDVariable b = sd.constant(Nd4j.rand(DataType.FLOAT, 4*nOut));
double fb = 1.0;
- LSTMBlockCellConfiguration conf = LSTMBlockCellConfiguration.builder()
- .xt(x)
- .cLast(cLast)
- .yLast(yLast)
- .W(W)
- .Wci(Wci)
- .Wcf(Wcf)
- .Wco(Wco)
- .b(b)
+ LSTMConfiguration conf = LSTMConfiguration.builder()
.peepHole(true)
.forgetBias(fb)
.clippingCellValue(0.0)
.build();
- List v = sd.rnn().lstmBlockCell("lstm", conf); //Output order: i, c, f, o, z, h, y
+ LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
+ .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();
+
+ LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y
List toExec = new ArrayList<>();
- for(SDVariable sdv : v){
+ for(SDVariable sdv : v.getAllOutputs()){
toExec.add(sdv.getVarName());
}
@@ -167,23 +167,18 @@ public class RnnOpValidation extends BaseOpValidation {
SDVariable b = sd.constant(Nd4j.zeros(DataType.FLOAT, 8));
double fb = 1.0;
- LSTMBlockCellConfiguration conf = LSTMBlockCellConfiguration.builder()
- .xt(x)
- .cLast(cLast)
- .yLast(yLast)
- .W(W)
- .Wci(Wci)
- .Wcf(Wcf)
- .Wco(Wco)
- .b(b)
+ LSTMConfiguration conf = LSTMConfiguration.builder()
.peepHole(false)
.forgetBias(fb)
.clippingCellValue(0.0)
.build();
- List v = sd.rnn().lstmBlockCell("lstm", conf); //Output order: i, c, f, o, z, h, y
+ LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
+ .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();
+
+ LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y
List toExec = new ArrayList<>();
- for(SDVariable sdv : v){
+ for(SDVariable sdv : v.getAllOutputs()){
toExec.add(sdv.getVarName());
}
@@ -228,16 +223,14 @@ public class RnnOpValidation extends BaseOpValidation {
SDVariable bc = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
double fb = 1.0;
- GRUCellConfiguration conf = GRUCellConfiguration.builder()
- .xt(x)
- .hLast(hLast)
- .Wru(Wru)
- .Wc(Wc)
- .bru(bru)
- .bc(bc)
+ GRUWeights weights = GRUWeights.builder()
+ .ruWeight(Wru)
+ .cWeight(Wc)
+ .ruBias(bru)
+ .cBias(bc)
.build();
- List v = sd.rnn().gru("gru", conf);
+ List v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs();
List toExec = new ArrayList<>();
for(SDVariable sdv : v){
toExec.add(sdv.getVarName());
diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/Pair.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/Pair.java
index 8f5ca3888..31976080d 100644
--- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/Pair.java
+++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/Pair.java
@@ -23,6 +23,7 @@ import lombok.NoArgsConstructor;
import java.io.Serializable;
import java.util.Arrays;
+import org.nd4j.base.Preconditions;
/**
* Simple pair implementation
@@ -86,4 +87,10 @@ public class Pair implements Serializable {
public static Pair pairOf(T key, E value) {
return new Pair(key, value);
}
+
+ public static Pair fromArray(T[] arr){
+ Preconditions.checkArgument(arr.length == 2,
+ "Can only create a pair from an array with two values, got %s", arr.length);
+ return new Pair<>(arr[0], arr[1]);
+ }
}