cleanup SDRNN and rnn ops (#238)
Signed-off-by: Ryan Nett <rnett@skymind.io>
This commit is contained in:
		
							parent
							
								
									7d85775934
								
							
						
					
					
						commit
						79867f5c5a
					
				@ -34,7 +34,7 @@ CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) {
 | 
			
		||||
    auto  xt   = INPUT_VARIABLE(0);               // input [bS x inSize], bS - batch size, inSize - number of features
 | 
			
		||||
    auto ct_1 = INPUT_VARIABLE(1);               // previous cell state ct  [bS x inSize], that is at previous time step t-1
 | 
			
		||||
    auto w    = INPUT_VARIABLE(2);               // weights [inSize x 3*inSize]
 | 
			
		||||
    auto b    = INPUT_VARIABLE(3);               // biases [1 x 2*inSize]
 | 
			
		||||
    auto b    = INPUT_VARIABLE(3);               // biases [2*inSize]
 | 
			
		||||
 | 
			
		||||
    auto ht   = OUTPUT_VARIABLE(0);              // current cell output [bS x inSize], that is at current time step t
 | 
			
		||||
    auto ct   = OUTPUT_VARIABLE(1);              // current cell state  [bS x inSize], that is at current time step t
 | 
			
		||||
 | 
			
		||||
@ -6511,4 +6511,22 @@ public class SameDiff extends SDBaseOps {
 | 
			
		||||
    public String generateNewVarName(String base, int argIndex) {
 | 
			
		||||
        return generateNewVarName(base, argIndex, true);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
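     * Returns an unused variable name: {@code base} itself if it is not already in use, otherwise {@code base_#} for the first positive integer # (starting at 1) that is free.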
 | 
			
		||||
     *
 | 
			
		||||
     * Intended to be used for custom variables (like weights); arguments and op outputs should use {@link #generateNewVarName(String, int)} instead.
 | 
			
		||||
     */
 | 
			
		||||
    public String generateDistinctCustomVariableName(String base){
 | 
			
		||||
        if(!variables.containsKey(base))
 | 
			
		||||
            return base;
 | 
			
		||||
 | 
			
		||||
        int inc = 1;
 | 
			
		||||
 | 
			
		||||
        while(variables.containsKey(base + "_" + inc)){
 | 
			
		||||
            inc++;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return base + "_" + inc;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,7 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.autodiff.samediff.ops;
 | 
			
		||||
 | 
			
		||||
import lombok.NonNull;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.*;
 | 
			
		||||
@ -23,6 +24,15 @@ import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*;
 | 
			
		||||
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRUCellOutputs;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRULayerOutputs;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
 | 
			
		||||
import org.nd4j.linalg.primitives.Pair;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * SameDiff Recurrent Neural Network operations<br>
 | 
			
		||||
@ -39,90 +49,163 @@ public class SDRNN extends SDOps {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * The gru cell
 | 
			
		||||
     *
 | 
			
		||||
     * @param configuration the configuration to use
 | 
			
		||||
     * @return
 | 
			
		||||
     * See {@link #gru(String, SDVariable, SDVariable, GRUWeights)}.
 | 
			
		||||
     */
 | 
			
		||||
    public List<SDVariable> gru(GRUCellConfiguration configuration) {
 | 
			
		||||
        GRUCell c = new GRUCell(sd, configuration);
 | 
			
		||||
        return Arrays.asList(c.outputVariables());
 | 
			
		||||
    public GRUCellOutputs gru(@NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
 | 
			
		||||
        GRUCell c = new GRUCell(sd, x, hLast, weights);
 | 
			
		||||
        return new GRUCellOutputs(c.outputVariables());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * The gru cell
 | 
			
		||||
     * The GRU cell.  Does a single time step operation.
 | 
			
		||||
     *
 | 
			
		||||
     * @param baseName      the base name for the gru cell
 | 
			
		||||
     * @param configuration the configuration to use
 | 
			
		||||
     * @return
 | 
			
		||||
     * @param baseName  The base name for the gru cell
 | 
			
		||||
     * @param x         Input, with shape [batchSize, inSize]
 | 
			
		||||
     * @param hLast     Output of the previous cell/time step, with shape [batchSize, numUnits]
 | 
			
		||||
     * @param weights   The cell's weights.
 | 
			
		||||
     * @return          The cell's outputs.
 | 
			
		||||
     */
 | 
			
		||||
    public List<SDVariable> gru(String baseName, GRUCellConfiguration configuration) {
 | 
			
		||||
        GRUCell c = new GRUCell(sd, configuration);
 | 
			
		||||
        return Arrays.asList(c.outputVariables(baseName));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * LSTM unit
 | 
			
		||||
     *
 | 
			
		||||
     * @param baseName      the base name for outputs
 | 
			
		||||
     * @param configuration the configuration to use
 | 
			
		||||
     * @return
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable lstmCell(String baseName, LSTMCellConfiguration configuration) {
 | 
			
		||||
        return new LSTMCell(sd, configuration).outputVariables(baseName)[0];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public List<SDVariable> lstmBlockCell(String name, LSTMBlockCellConfiguration configuration){
 | 
			
		||||
        SDVariable[] v = new LSTMBlockCell(sd, configuration).outputVariables(name);
 | 
			
		||||
        return Arrays.asList(v);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public List<SDVariable> lstmLayer(String name, LSTMConfiguration configuration){
 | 
			
		||||
        SDVariable[] v = new LSTMLayer(sd, configuration).outputVariables(name);
 | 
			
		||||
        return Arrays.asList(v);
 | 
			
		||||
    public GRUCellOutputs gru(String baseName, @NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
 | 
			
		||||
        GRUCell c = new GRUCell(sd, x, hLast, weights);
 | 
			
		||||
        return new GRUCellOutputs(c.outputVariables(baseName));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Simple recurrent unit
 | 
			
		||||
     *
 | 
			
		||||
     * @param configuration the configuration for the sru
 | 
			
		||||
     * @return
 | 
			
		||||
     * See {@link #lstmCell(String, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}.
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable sru(SRUConfiguration configuration) {
 | 
			
		||||
        return new SRU(sd, configuration).outputVariables()[0];
 | 
			
		||||
    public LSTMCellOutputs lstmCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
 | 
			
		||||
            LSTMWeights weights, LSTMConfiguration config){
 | 
			
		||||
        LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
 | 
			
		||||
        return new LSTMCellOutputs(c.outputVariables());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Simiple recurrent unit
 | 
			
		||||
     * The LSTM cell.  Does a single time step operation.
 | 
			
		||||
     *
 | 
			
		||||
     * @param baseName      the base name to use for output variables
 | 
			
		||||
     * @param configuration the configuration for the sru
 | 
			
		||||
     * @return
 | 
			
		||||
     * @param baseName  The base name for the lstm cell
 | 
			
		||||
     * @param x         Input, with shape [batchSize, inSize]
 | 
			
		||||
     * @param cLast     Previous cell state, with shape [batchSize, numUnits]
 | 
			
		||||
     * @param yLast     Previous cell output, with shape [batchSize, numUnits]
 | 
			
		||||
     * @param weights   The cell's weights.
 | 
			
		||||
     * @param config    The cell's config.
 | 
			
		||||
     * @return          The cell's outputs.
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable sru(String baseName, SRUConfiguration configuration) {
 | 
			
		||||
        return new SRU(sd, configuration).outputVariables(baseName)[0];
 | 
			
		||||
    public LSTMCellOutputs lstmCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
 | 
			
		||||
            @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
 | 
			
		||||
        LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
 | 
			
		||||
        return new LSTMCellOutputs(c.outputVariables(baseName));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * An sru cell
 | 
			
		||||
     *
 | 
			
		||||
     * @param configuration the configuration for the sru cell
 | 
			
		||||
     * @return
 | 
			
		||||
     * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable sruCell(SRUCellConfiguration configuration) {
 | 
			
		||||
        return new SRUCell(sd, configuration).outputVariables()[0];
 | 
			
		||||
    public LSTMLayerOutputs lstmLayer(@NonNull SDVariable maxTSLength,
 | 
			
		||||
            @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
 | 
			
		||||
            @NonNull  LSTMWeights weights, @NonNull LSTMConfiguration config){
 | 
			
		||||
        LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
 | 
			
		||||
        return new LSTMLayerOutputs(c.outputVariables(), config.getDataFormat());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * An sru cell
 | 
			
		||||
     *
 | 
			
		||||
     * @param baseName      the base name to  use for the output variables
 | 
			
		||||
     * @param configuration the configuration for the sru cell
 | 
			
		||||
     * @return
 | 
			
		||||
     * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable sruCell(String baseName, SRUCellConfiguration configuration) {
 | 
			
		||||
        return new SRUCell(sd, configuration).outputVariables(baseName)[0];
 | 
			
		||||
    public LSTMLayerOutputs lstmLayer(int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
 | 
			
		||||
            @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
 | 
			
		||||
        return lstmLayer(
 | 
			
		||||
                sd.scalar("lstm_max_ts_length", maxTSLength),
 | 
			
		||||
                x, cLast, yLast, weights, config);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
 | 
			
		||||
     */
 | 
			
		||||
    public LSTMLayerOutputs lstmLayer(String baseName, int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
 | 
			
		||||
            @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
 | 
			
		||||
        if(baseName != null) {
 | 
			
		||||
            return lstmLayer(baseName,
 | 
			
		||||
                    sd.scalar(sd.generateDistinctCustomVariableName(baseName + "_max_ts_length"), maxTSLength),
 | 
			
		||||
                    x, cLast, yLast, weights, config);
 | 
			
		||||
        } else {
 | 
			
		||||
            return lstmLayer(maxTSLength, x, cLast, yLast, weights, config);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * The LSTM layer.  Does multiple time steps.
 | 
			
		||||
     *
 | 
			
		||||
     * Input shape depends on data format (in config):<br>
 | 
			
		||||
     * TNS -> [timeSteps, batchSize, inSize]<br>
 | 
			
		||||
     * NST -> [batchSize, inSize, timeSteps]<br>
 | 
			
		||||
     * NTS -> [batchSize, timeSteps, inSize]<br>
 | 
			
		||||
     *
 | 
			
		||||
     * @param baseName  The base name for the lstm layer
 | 
			
		||||
     * @param x         Input, with shape dependent on the data format (in config).
 | 
			
		||||
     * @param cLast     Previous/initial cell state, with shape [batchSize, numUnits]
 | 
			
		||||
     * @param yLast     Previous/initial cell output, with shape [batchSize, numUnits]
 | 
			
		||||
     * @param weights   The layer's weights.
 | 
			
		||||
     * @param config    The layer's config.
 | 
			
		||||
     * @return          The layer's outputs.
 | 
			
		||||
     */
 | 
			
		||||
    public LSTMLayerOutputs lstmLayer(String baseName, @NonNull SDVariable maxTSLength,
 | 
			
		||||
            @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
 | 
			
		||||
            @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
 | 
			
		||||
        LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
 | 
			
		||||
        return new LSTMLayerOutputs(c.outputVariables(baseName), config.getDataFormat());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * See {@link #sruCell(String, SDVariable, SDVariable, SRUWeights)}.
 | 
			
		||||
     */
 | 
			
		||||
    public SRUCellOutputs sruCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
 | 
			
		||||
        return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * The SRU cell.  Does a single time step operation.
 | 
			
		||||
     *
 | 
			
		||||
     * @param baseName  The base name for the sru cell
 | 
			
		||||
     * @param x         Input, with shape [batchSize, inSize]
 | 
			
		||||
     * @param cLast     Previous cell state, with shape [batchSize, inSize]
 | 
			
		||||
     * @param weights   The cell's weights.
 | 
			
		||||
     * @return          The cell's outputs.
 | 
			
		||||
     */
 | 
			
		||||
    public SRUCellOutputs sruCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
 | 
			
		||||
        return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables(baseName));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
 | 
			
		||||
     */
 | 
			
		||||
    public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
 | 
			
		||||
        return sru(x, initialC, null, weights);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
 | 
			
		||||
     */
 | 
			
		||||
    public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
 | 
			
		||||
        return sru(baseName, x, initialC, null, weights);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
 | 
			
		||||
     */
 | 
			
		||||
    public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
 | 
			
		||||
        return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * The SRU layer.  Does multiple time steps.
 | 
			
		||||
     *
 | 
			
		||||
     * @param baseName  The base name for the sru layer
 | 
			
		||||
     * @param x         Input, with shape [batchSize, inSize, timeSeriesLength]
 | 
			
		||||
     * @param initialC  Initial cell state, with shape [batchSize, inSize]
 | 
			
		||||
     * @param mask      An optional dropout mask, with shape [batchSize, inSize]
 | 
			
		||||
     * @param weights   The layer's weights.
 | 
			
		||||
     * @return          The layer's outputs.
 | 
			
		||||
     */
 | 
			
		||||
    public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
 | 
			
		||||
        return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables(baseName));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,7 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
 | 
			
		||||
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import onnx.Onnx;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
			
		||||
@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
			
		||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
 | 
			
		||||
import org.tensorflow.framework.AttrValue;
 | 
			
		||||
import org.tensorflow.framework.GraphDef;
 | 
			
		||||
import org.tensorflow.framework.NodeDef;
 | 
			
		||||
@ -39,14 +41,15 @@ import java.util.Map;
 | 
			
		||||
 */
 | 
			
		||||
public class GRUCell extends DynamicCustomOp {
 | 
			
		||||
 | 
			
		||||
    private GRUCellConfiguration configuration;
 | 
			
		||||
    @Getter
 | 
			
		||||
    private GRUWeights weights;
 | 
			
		||||
 | 
			
		||||
    public GRUCell() {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public GRUCell(SameDiff sameDiff, GRUCellConfiguration configuration) {
 | 
			
		||||
        super(null, sameDiff, configuration.args());
 | 
			
		||||
        this.configuration = configuration;
 | 
			
		||||
    public GRUCell(SameDiff sameDiff, SDVariable x, SDVariable hLast, GRUWeights weights) {
 | 
			
		||||
        super(null, sameDiff, weights.argsWithInputs(x, hLast));
 | 
			
		||||
        this.weights = weights;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
 | 
			
		||||
@ -16,12 +16,15 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
 | 
			
		||||
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
			
		||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
 | 
			
		||||
import org.nd4j.linalg.primitives.Pair;
 | 
			
		||||
import org.tensorflow.framework.AttrValue;
 | 
			
		||||
import org.tensorflow.framework.GraphDef;
 | 
			
		||||
import org.tensorflow.framework.NodeDef;
 | 
			
		||||
@ -49,10 +52,12 @@ import java.util.Map;
 | 
			
		||||
 * 6: weights - cell peephole (t) connections to output gate, [numUnits]<br>
 | 
			
		||||
 * 7: biases, shape [4*numUnits]<br>
 | 
			
		||||
 * <br>
 | 
			
		||||
 * Input integer arguments: set via {@link LSTMBlockCellConfiguration}<br>
 | 
			
		||||
 * Weights are set via {@link LSTMWeights}.<br>
 | 
			
		||||
 * <br>
 | 
			
		||||
 * Input integer arguments: set via {@link LSTMConfiguration}<br>
 | 
			
		||||
 * 0: if not zero, provide peephole connections<br>
 | 
			
		||||
 * <br>
 | 
			
		||||
 * Input float arguments: set via {@link LSTMBlockCellConfiguration}<br>
 | 
			
		||||
 * Input float arguments: set via {@link LSTMConfiguration}<br>
 | 
			
		||||
 * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training<br>
 | 
			
		||||
 * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped<br>
 | 
			
		||||
 * <br>
 | 
			
		||||
@ -69,15 +74,19 @@ import java.util.Map;
 | 
			
		||||
 */
 | 
			
		||||
public class LSTMBlockCell extends DynamicCustomOp {
 | 
			
		||||
 | 
			
		||||
    private LSTMBlockCellConfiguration configuration;
 | 
			
		||||
    private LSTMConfiguration configuration;
 | 
			
		||||
 | 
			
		||||
    @Getter
 | 
			
		||||
    private LSTMWeights weights;
 | 
			
		||||
 | 
			
		||||
    public LSTMBlockCell() {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public LSTMBlockCell(SameDiff sameDiff, LSTMBlockCellConfiguration configuration) {
 | 
			
		||||
        super(null, sameDiff, configuration.args());
 | 
			
		||||
    public LSTMBlockCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) {
 | 
			
		||||
        super(null, sameDiff, weights.argsWithInputs(x, cLast, yLast));
 | 
			
		||||
        this.configuration = configuration;
 | 
			
		||||
        addIArgument(configuration.iArgs());
 | 
			
		||||
        this.weights = weights;
 | 
			
		||||
        addIArgument(configuration.iArgs(false));
 | 
			
		||||
        addTArgument(configuration.tArgs());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -97,12 +106,12 @@ public class LSTMBlockCell extends DynamicCustomOp {
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
 | 
			
		||||
        configuration = LSTMBlockCellConfiguration.builder()
 | 
			
		||||
        configuration = LSTMConfiguration.builder()
 | 
			
		||||
                .forgetBias(attributesForNode.get("forget_bias").getF())
 | 
			
		||||
                .clippingCellValue(attributesForNode.get("cell_clip").getF())
 | 
			
		||||
                .peepHole(attributesForNode.get("use_peephole").getB())
 | 
			
		||||
                .build();
 | 
			
		||||
        addIArgument(configuration.iArgs());
 | 
			
		||||
        addIArgument(configuration.iArgs(false));
 | 
			
		||||
        addTArgument(configuration.tArgs());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -113,7 +122,7 @@ public class LSTMBlockCell extends DynamicCustomOp {
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Object> propertiesForFunction() {
 | 
			
		||||
        return configuration.toProperties();
 | 
			
		||||
        return configuration.toProperties(false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,7 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
 | 
			
		||||
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import lombok.NonNull;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
			
		||||
@ -24,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType;
 | 
			
		||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
 | 
			
		||||
import org.tensorflow.framework.AttrValue;
 | 
			
		||||
import org.tensorflow.framework.GraphDef;
 | 
			
		||||
import org.tensorflow.framework.NodeDef;
 | 
			
		||||
@ -75,13 +77,17 @@ public class LSTMLayer extends DynamicCustomOp {
 | 
			
		||||
 | 
			
		||||
    private LSTMConfiguration configuration;
 | 
			
		||||
 | 
			
		||||
    @Getter
 | 
			
		||||
    private LSTMWeights weights;
 | 
			
		||||
 | 
			
		||||
    public LSTMLayer() {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public LSTMLayer(@NonNull SameDiff sameDiff, @NonNull LSTMConfiguration configuration) {
 | 
			
		||||
        super(null, sameDiff, configuration.args());
 | 
			
		||||
    public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) {
 | 
			
		||||
        super(null, sameDiff, weights.argsWithInputs(maxTSLength, x, cLast, yLast));
 | 
			
		||||
        this.configuration = configuration;
 | 
			
		||||
        addIArgument(configuration.iArgs());
 | 
			
		||||
        this.weights = weights;
 | 
			
		||||
        addIArgument(configuration.iArgs(true));
 | 
			
		||||
        addTArgument(configuration.tArgs());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -107,7 +113,7 @@ public class LSTMLayer extends DynamicCustomOp {
 | 
			
		||||
                .peepHole(attributesForNode.get("use_peephole").getB())
 | 
			
		||||
                .dataFormat(RnnDataFormat.TNS)  //Always time major for TF BlockLSTM
 | 
			
		||||
                .build();
 | 
			
		||||
        addIArgument(configuration.iArgs());
 | 
			
		||||
        addIArgument(configuration.iArgs(true));
 | 
			
		||||
        addTArgument(configuration.tArgs());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -118,7 +124,7 @@ public class LSTMLayer extends DynamicCustomOp {
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Object> propertiesForFunction() {
 | 
			
		||||
        return configuration.toProperties();
 | 
			
		||||
        return configuration.toProperties(true);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
 | 
			
		||||
@ -16,11 +16,16 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
 | 
			
		||||
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import lombok.NonNull;
 | 
			
		||||
import onnx.Onnx;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
			
		||||
import org.nd4j.imports.NoOpNameFoundException;
 | 
			
		||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUConfiguration;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
 | 
			
		||||
import org.tensorflow.framework.AttrValue;
 | 
			
		||||
import org.tensorflow.framework.GraphDef;
 | 
			
		||||
import org.tensorflow.framework.NodeDef;
 | 
			
		||||
@ -34,13 +39,18 @@ import java.util.Map;
 | 
			
		||||
 */
 | 
			
		||||
public class SRU extends DynamicCustomOp {
 | 
			
		||||
 | 
			
		||||
    private SRUConfiguration configuration;
 | 
			
		||||
    @Getter
 | 
			
		||||
    private SRUWeights weights;
 | 
			
		||||
 | 
			
		||||
    @Getter
 | 
			
		||||
    private SDVariable mask;
 | 
			
		||||
 | 
			
		||||
    public SRU() { }
 | 
			
		||||
 | 
			
		||||
    public SRU(SameDiff sameDiff, SRUConfiguration configuration) {
 | 
			
		||||
        super(null, sameDiff, configuration.args());
 | 
			
		||||
        this.configuration = configuration;
 | 
			
		||||
    public SRU(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
 | 
			
		||||
        super(null, sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getBias(), initialC, mask));
 | 
			
		||||
        this.mask = mask;
 | 
			
		||||
        this.weights = weights;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
@ -68,6 +78,4 @@ public class SRU extends DynamicCustomOp {
 | 
			
		||||
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
 | 
			
		||||
        super.initFromOnnx(node, initWith, attributesForNode, graph);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -16,17 +16,18 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
 | 
			
		||||
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import onnx.Onnx;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
			
		||||
import org.nd4j.imports.NoOpNameFoundException;
 | 
			
		||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUCellConfiguration;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
 | 
			
		||||
import org.tensorflow.framework.AttrValue;
 | 
			
		||||
import org.tensorflow.framework.GraphDef;
 | 
			
		||||
import org.tensorflow.framework.NodeDef;
 | 
			
		||||
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * A simple recurrent unit cell.
 | 
			
		||||
 *
 | 
			
		||||
@ -34,14 +35,15 @@ import java.util.Map;
 | 
			
		||||
 */
 | 
			
		||||
public class SRUCell extends DynamicCustomOp {
 | 
			
		||||
 | 
			
		||||
    private SRUCellConfiguration configuration;
 | 
			
		||||
    @Getter
 | 
			
		||||
    private SRUWeights weights;
 | 
			
		||||
 | 
			
		||||
    public SRUCell() {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public SRUCell(SameDiff sameDiff, SRUCellConfiguration configuration) {
 | 
			
		||||
        super(null, sameDiff, configuration.args());
 | 
			
		||||
        this.configuration = configuration;
 | 
			
		||||
    public SRUCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SRUWeights weights) {
 | 
			
		||||
        super(null, sameDiff, weights.argsWithInputs(x, cLast));
 | 
			
		||||
        this.weights = weights;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
 | 
			
		||||
@ -1,57 +0,0 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
 | 
			
		||||
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.linalg.util.ArrayUtil;
 | 
			
		||||
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@Builder
 | 
			
		||||
@Data
 | 
			
		||||
public class LSTMBlockCellConfiguration {
 | 
			
		||||
 | 
			
		||||
    private boolean peepHole;           //IArg(0)
 | 
			
		||||
    private double forgetBias;          //TArg(0)
 | 
			
		||||
    private double clippingCellValue;   //TArg(1)
 | 
			
		||||
 | 
			
		||||
    private SDVariable xt, cLast, yLast, W, Wci, Wcf, Wco, b;
 | 
			
		||||
 | 
			
		||||
    public Map<String,Object> toProperties()  {
 | 
			
		||||
        Map<String,Object> ret = new LinkedHashMap<>();
 | 
			
		||||
        ret.put("peepHole",peepHole);
 | 
			
		||||
        ret.put("clippingCellValue",clippingCellValue);
 | 
			
		||||
        ret.put("forgetBias",forgetBias);
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public SDVariable[] args()  {
 | 
			
		||||
        return new SDVariable[] {xt,cLast, yLast, W, Wci, Wcf, Wco, b};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    public int[] iArgs() {
 | 
			
		||||
        return new int[] {ArrayUtil.fromBoolean(peepHole)};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public double[] tArgs() {
 | 
			
		||||
        return new double[] {forgetBias,clippingCellValue};
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -19,13 +19,15 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
 | 
			
		||||
import org.nd4j.linalg.util.ArrayUtil;
 | 
			
		||||
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * LSTM Configuration - for {@link org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer}
 | 
			
		||||
 * LSTM Configuration - for {@link LSTMLayer} and {@link LSTMBlockCell}
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alex Black
 | 
			
		||||
 */
 | 
			
		||||
@ -33,29 +35,41 @@ import java.util.Map;
 | 
			
		||||
@Data
 | 
			
		||||
public class LSTMConfiguration {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Whether to provide peephole connections.
 | 
			
		||||
     */
 | 
			
		||||
    private boolean peepHole;           //IArg(0)
 | 
			
		||||
    @Builder.Default private RnnDataFormat dataFormat = RnnDataFormat.TNS;  //IArg(1)
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * The data format of the input.  Only used in {@link LSTMLayer}, ignored in {@link LSTMBlockCell}.
 | 
			
		||||
     */
 | 
			
		||||
    @Builder.Default private RnnDataFormat dataFormat = RnnDataFormat.TNS;  //IArg(1) (only for lstmBlock, not lstmBlockCell)
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * The bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training.
 | 
			
		||||
     */
 | 
			
		||||
    private double forgetBias;          //TArg(0)
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Clipping value for cell state, if it is not equal to zero, then cell state is clipped.
 | 
			
		||||
     */
 | 
			
		||||
    private double clippingCellValue;   //TArg(1)
 | 
			
		||||
 | 
			
		||||
    private SDVariable xt, cLast, yLast, W, Wci, Wcf, Wco, b;
 | 
			
		||||
 | 
			
		||||
    public Map<String,Object> toProperties()  {
 | 
			
		||||
    public Map<String,Object> toProperties(boolean includeDataFormat)  {
 | 
			
		||||
        Map<String,Object> ret = new LinkedHashMap<>();
 | 
			
		||||
        ret.put("peepHole",peepHole);
 | 
			
		||||
        ret.put("clippingCellValue",clippingCellValue);
 | 
			
		||||
        ret.put("forgetBias",forgetBias);
 | 
			
		||||
        ret.put("dataFormat", dataFormat);
 | 
			
		||||
        if(includeDataFormat)
 | 
			
		||||
            ret.put("dataFormat", dataFormat);
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public SDVariable[] args()  {
 | 
			
		||||
        return new SDVariable[] {xt,cLast, yLast, W, Wci, Wcf, Wco, b};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    public int[] iArgs() {
 | 
			
		||||
        return new int[] {ArrayUtil.fromBoolean(peepHole), dataFormat.ordinal()};
 | 
			
		||||
    public int[] iArgs(boolean includeDataFormat) {
 | 
			
		||||
        if(includeDataFormat) {
 | 
			
		||||
            return new int[]{ArrayUtil.fromBoolean(peepHole), dataFormat.ordinal()};
 | 
			
		||||
        } else return new int[]{ArrayUtil.fromBoolean(peepHole)};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public double[] tArgs() {
 | 
			
		||||
 | 
			
		||||
@ -1,44 +0,0 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2018 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
 | 
			
		||||
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@Builder
 | 
			
		||||
public class SRUCellConfiguration {
 | 
			
		||||
    /**
 | 
			
		||||
     *
 | 
			
		||||
     NDArray<T>* xt   = INPUT_VARIABLE(0);               // input [batchSize x inSize], batchSize - batch size, inSize - number of features
 | 
			
		||||
     NDArray<T>* ct_1 = INPUT_VARIABLE(1);               // previous cell state ct  [batchSize x inSize], that is at previous time step t-1
 | 
			
		||||
     NDArray<T>* w    = INPUT_VARIABLE(2);               // weights [inSize x 3*inSize]
 | 
			
		||||
     NDArray<T>* b    = INPUT_VARIABLE(3);               // biases [1 x 2*inSize]
 | 
			
		||||
 | 
			
		||||
     NDArray<T>* ht   = OUTPUT_VARIABLE(0);              // current cell output [batchSize x inSize], that is at current time step t
 | 
			
		||||
     NDArray<T>* ct   = OUTPUT_VARIABLE(1);              // current cell state  [batchSize x inSize], that is at current time step t
 | 
			
		||||
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable xt,ct_1,w,b,h1,ct;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    public SDVariable[] args() {
 | 
			
		||||
        return new SDVariable[] {xt,ct_1,w,b,h1,ct};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -1,38 +0,0 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2018 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
 | 
			
		||||
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@Builder
 | 
			
		||||
public class SRUConfiguration {
 | 
			
		||||
    /**
 | 
			
		||||
     * NDArray<T>* input   = INPUT_VARIABLE(0);                // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
 | 
			
		||||
     NDArray<T>* weights = INPUT_VARIABLE(1);                // W, 2d tensor of weights [3K x K]
 | 
			
		||||
     NDArray<T>* bias    = INPUT_VARIABLE(2);                // B, row of biases with twice length [1 x 2*K]
 | 
			
		||||
     NDArray<T>* init    = INPUT_VARIABLE(3);                // C_{0}, 2d tensor of initial state [bS x K] at time t=0
 | 
			
		||||
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable inputs,weights,bias,init;
 | 
			
		||||
 | 
			
		||||
    public SDVariable[] args() {
 | 
			
		||||
        return new SDVariable[] {inputs,weights,bias,init};
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,62 @@
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
 | 
			
		||||
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The outputs of a GRU cell ({@link GRUCell}.
 | 
			
		||||
 */
 | 
			
		||||
@Getter
 | 
			
		||||
public class GRUCellOutputs {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Reset gate output [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable r;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Update gate output [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable u;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Cell gate output [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable c;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Current cell output [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable h;
 | 
			
		||||
 | 
			
		||||
    public GRUCellOutputs(SDVariable[] outputs){
 | 
			
		||||
        Preconditions.checkArgument(outputs.length == 4,
 | 
			
		||||
                "Must have 4 GRU cell outputs, got %s", outputs.length);
 | 
			
		||||
 | 
			
		||||
        r = outputs[0];
 | 
			
		||||
        u = outputs[1];
 | 
			
		||||
        c = outputs[2];
 | 
			
		||||
        h = outputs[3];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get all outputs returned by the cell.
 | 
			
		||||
     */
 | 
			
		||||
    public List<SDVariable> getAllOutputs(){
 | 
			
		||||
        return Arrays.asList(r, u, c, h);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get h, the output of the cell.
 | 
			
		||||
     *
 | 
			
		||||
     * Has shape [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getOutput(){
 | 
			
		||||
        return h;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,88 @@
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
 | 
			
		||||
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The outputs of a LSTM cell ({@link LSTMBlockCell}.
 | 
			
		||||
 */
 | 
			
		||||
@Getter
 | 
			
		||||
public class LSTMCellOutputs {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Output - input modulation gate activations [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable i;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Activations, cell state (pre tanh) [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable c;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Output - forget gate activations [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable f;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Output - output gate activations [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable o;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Output - input gate activations [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable z;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Cell state, post tanh [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable h;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Current cell output [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable y;
 | 
			
		||||
 | 
			
		||||
    public LSTMCellOutputs(SDVariable[] outputs){
 | 
			
		||||
        Preconditions.checkArgument(outputs.length == 7,
 | 
			
		||||
                "Must have 7 LSTM cell outputs, got %s", outputs.length);
 | 
			
		||||
 | 
			
		||||
        i = outputs[0];
 | 
			
		||||
        c = outputs[1];
 | 
			
		||||
        f = outputs[2];
 | 
			
		||||
        o = outputs[3];
 | 
			
		||||
        z = outputs[4];
 | 
			
		||||
        h = outputs[5];
 | 
			
		||||
        y = outputs[6];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get all outputs returned by the cell.
 | 
			
		||||
     */
 | 
			
		||||
    public List<SDVariable> getAllOutputs(){
 | 
			
		||||
        return Arrays.asList(i, c, f, o, z, h, y);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get y, the output of the cell.
 | 
			
		||||
     *
 | 
			
		||||
     * Has shape [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getOutput(){
 | 
			
		||||
        return y;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get c, the cell's state.
 | 
			
		||||
     *
 | 
			
		||||
     * Has shape [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getState(){
 | 
			
		||||
        return c;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,180 @@
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
 | 
			
		||||
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import lombok.AccessLevel;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDIndex;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The outputs of a LSTM layer ({@link LSTMLayer}.
 | 
			
		||||
 */
 | 
			
		||||
@Getter
 | 
			
		||||
public class LSTMLayerOutputs {
 | 
			
		||||
 | 
			
		||||
    private RnnDataFormat dataFormat;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Output - input modulation gate activations.
 | 
			
		||||
     * Shape depends on data format (in layer config):<br>
 | 
			
		||||
     * TNS -> [timeSteps, batchSize, numUnits]<br>
 | 
			
		||||
     * NST -> [batchSize, numUnits, timeSteps]<br>
 | 
			
		||||
     * NTS -> [batchSize, timeSteps, numUnits]<br>
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable i;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Activations, cell state (pre tanh).
 | 
			
		||||
     * Shape depends on data format (in layer config):<br>
 | 
			
		||||
     * TNS -> [timeSteps, batchSize, numUnits]<br>
 | 
			
		||||
     * NST -> [batchSize, numUnits, timeSteps]<br>
 | 
			
		||||
     * NTS -> [batchSize, timeSteps, numUnits]<br>
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable c;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Output - forget gate activations.
 | 
			
		||||
     * Shape depends on data format (in layer config):<br>
 | 
			
		||||
     * TNS -> [timeSteps, batchSize, numUnits]<br>
 | 
			
		||||
     * NST -> [batchSize, numUnits, timeSteps]<br>
 | 
			
		||||
     * NTS -> [batchSize, timeSteps, numUnits]<br>
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable f;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Output - output gate activations.
 | 
			
		||||
     * Shape depends on data format (in layer config):<br>
 | 
			
		||||
     * TNS -> [timeSteps, batchSize, numUnits]<br>
 | 
			
		||||
     * NST -> [batchSize, numUnits, timeSteps]<br>
 | 
			
		||||
     * NTS -> [batchSize, timeSteps, numUnits]<br>
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable o;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Output - input gate activations.
 | 
			
		||||
     * Shape depends on data format (in layer config):<br>
 | 
			
		||||
     * TNS -> [timeSteps, batchSize, numUnits]<br>
 | 
			
		||||
     * NST -> [batchSize, numUnits, timeSteps]<br>
 | 
			
		||||
     * NTS -> [batchSize, timeSteps, numUnits]<br>
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable z;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Cell state, post tanh.
 | 
			
		||||
     * Shape depends on data format (in layer config):<br>
 | 
			
		||||
     * TNS -> [timeSteps, batchSize, numUnits]<br>
 | 
			
		||||
     * NST -> [batchSize, numUnits, timeSteps]<br>
 | 
			
		||||
     * NTS -> [batchSize, timeSteps, numUnits]<br>
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable h;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Current cell output.
 | 
			
		||||
     * Shape depends on data format (in layer config):<br>
 | 
			
		||||
     * TNS -> [timeSteps, batchSize, numUnits]<br>
 | 
			
		||||
     * NST -> [batchSize, numUnits, timeSteps]<br>
 | 
			
		||||
     * NTS -> [batchSize, timeSteps, numUnits]<br>
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable y;
 | 
			
		||||
 | 
			
		||||
    public LSTMLayerOutputs(SDVariable[] outputs, RnnDataFormat dataFormat){
 | 
			
		||||
        Preconditions.checkArgument(outputs.length == 7,
 | 
			
		||||
                "Must have 7 LSTM layer outputs, got %s", outputs.length);
 | 
			
		||||
 | 
			
		||||
        i = outputs[0];
 | 
			
		||||
        c = outputs[1];
 | 
			
		||||
        f = outputs[2];
 | 
			
		||||
        o = outputs[3];
 | 
			
		||||
        z = outputs[4];
 | 
			
		||||
        h = outputs[5];
 | 
			
		||||
        y = outputs[6];
 | 
			
		||||
        this.dataFormat = dataFormat;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get all outputs returned by the cell.
 | 
			
		||||
     */
 | 
			
		||||
    public List<SDVariable> getAllOutputs(){
 | 
			
		||||
        return Arrays.asList(i, c, f, o, z, h, y);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get y, the output of the cell for all time steps.
 | 
			
		||||
     * 
 | 
			
		||||
     * Shape depends on data format (in layer config):<br>
 | 
			
		||||
     * TNS -> [timeSteps, batchSize, numUnits]<br>
 | 
			
		||||
     * NST -> [batchSize, numUnits, timeSteps]<br>
 | 
			
		||||
     * NTS -> [batchSize, timeSteps, numUnits]<br>
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getOutput(){
 | 
			
		||||
        return y;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get c, the cell's state for all time steps.
 | 
			
		||||
     *
 | 
			
		||||
     * Shape depends on data format (in layer config):<br>
 | 
			
		||||
     * TNS -> [timeSteps, batchSize, numUnits]<br>
 | 
			
		||||
     * NST -> [batchSize, numUnits, timeSteps]<br>
 | 
			
		||||
     * NTS -> [batchSize, timeSteps, numUnits]<br>
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getState(){
 | 
			
		||||
        return c;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private SDVariable lastOutput = null;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get y, the output of the cell, for the last time step.
 | 
			
		||||
     *
 | 
			
		||||
     * Has shape [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getLastOutput(){
 | 
			
		||||
        if(lastOutput != null)
 | 
			
		||||
            return lastOutput;
 | 
			
		||||
 | 
			
		||||
        switch (dataFormat){
 | 
			
		||||
            case TNS:
 | 
			
		||||
                lastOutput = getOutput().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
 | 
			
		||||
                break;
 | 
			
		||||
            case NST:
 | 
			
		||||
                lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
 | 
			
		||||
                break;
 | 
			
		||||
            case NTS:
 | 
			
		||||
                lastOutput = getOutput().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
 | 
			
		||||
                break;
 | 
			
		||||
        }
 | 
			
		||||
        return lastOutput;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private SDVariable lastState = null;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get c, the state of the cell, for the last time step.
 | 
			
		||||
     *
 | 
			
		||||
     * Has shape [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getLastState(){
 | 
			
		||||
        if(lastState != null)
 | 
			
		||||
            return lastState;
 | 
			
		||||
 | 
			
		||||
        switch (dataFormat){
 | 
			
		||||
            case TNS:
 | 
			
		||||
                lastState = getState().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
 | 
			
		||||
                break;
 | 
			
		||||
            case NST:
 | 
			
		||||
                lastState = getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
 | 
			
		||||
                break;
 | 
			
		||||
            case NTS:
 | 
			
		||||
                lastState = getState().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
 | 
			
		||||
                break;
 | 
			
		||||
        }
 | 
			
		||||
        return lastState;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,60 @@
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
 | 
			
		||||
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The outputs of a GRU cell ({@link GRUCell}.
 | 
			
		||||
 */
 | 
			
		||||
@Getter
 | 
			
		||||
public class SRUCellOutputs {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Current cell output [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable h;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Current cell state [batchSize, numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable c;
 | 
			
		||||
 | 
			
		||||
    public SRUCellOutputs(SDVariable[] outputs){
 | 
			
		||||
        Preconditions.checkArgument(outputs.length == 2,
 | 
			
		||||
                "Must have 2 SRU cell outputs, got %s", outputs.length);
 | 
			
		||||
 | 
			
		||||
        h = outputs[0];
 | 
			
		||||
        c = outputs[1];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get all outputs returned by the cell.
 | 
			
		||||
     */
 | 
			
		||||
    public List<SDVariable> getAllOutputs(){
 | 
			
		||||
        return Arrays.asList(h, c);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get h, the output of the cell.
 | 
			
		||||
     *
 | 
			
		||||
     * Has shape [batchSize, inSize].
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getOutput(){
 | 
			
		||||
        return h;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get c, the state of the cell.
 | 
			
		||||
     *
 | 
			
		||||
     * Has shape [batchSize, inSize].
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getState(){
 | 
			
		||||
        return c;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,92 @@
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
 | 
			
		||||
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import lombok.AccessLevel;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDIndex;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The outputs of a GRU cell ({@link GRUCell}.
 | 
			
		||||
 */
 | 
			
		||||
@Getter
 | 
			
		||||
public class SRULayerOutputs {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Current cell output [batchSize, inSize, timeSeriesLength].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable h;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Current cell state [batchSize, inSize, timeSeriesLength].
 | 
			
		||||
     */
 | 
			
		||||
    private SDVariable c;
 | 
			
		||||
 | 
			
		||||
    public SRULayerOutputs(SDVariable[] outputs){
 | 
			
		||||
        Preconditions.checkArgument(outputs.length == 2,
 | 
			
		||||
                "Must have 2 SRU cell outputs, got %s", outputs.length);
 | 
			
		||||
 | 
			
		||||
        h = outputs[0];
 | 
			
		||||
        c = outputs[1];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get all outputs returned by the cell.
 | 
			
		||||
     */
 | 
			
		||||
    public List<SDVariable> getAllOutputs(){
 | 
			
		||||
        return Arrays.asList(h, c);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get h, the output of the cell.
 | 
			
		||||
     *
 | 
			
		||||
     * Has shape [batchSize, inSize, timeSeriesLength].
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getOutput(){
 | 
			
		||||
        return h;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get c, the state of the cell.
 | 
			
		||||
     *
 | 
			
		||||
     * Has shape [batchSize, inSize, timeSeriesLength].
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getState(){
 | 
			
		||||
        return c;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private SDVariable lastOutput = null;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get y, the output of the cell, for the last time step.
 | 
			
		||||
     *
 | 
			
		||||
     * Has shape [batchSize, inSize].
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getLastOutput(){
 | 
			
		||||
        if(lastOutput != null)
 | 
			
		||||
            return lastOutput;
 | 
			
		||||
 | 
			
		||||
        lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
 | 
			
		||||
        return lastOutput;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private SDVariable lastState = null;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get c, the state of the cell, for the last time step.
 | 
			
		||||
     *
 | 
			
		||||
     * Has shape [batchSize, inSize].
 | 
			
		||||
     */
 | 
			
		||||
    public SDVariable getLastState(){
 | 
			
		||||
        if(lastState != null)
 | 
			
		||||
            return lastState;
 | 
			
		||||
 | 
			
		||||
        lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
 | 
			
		||||
        return lastState;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,51 @@
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
 | 
			
		||||
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.EqualsAndHashCode;
 | 
			
		||||
import lombok.NonNull;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The weight configuration of a GRU cell.  For {@link GRUCell}.
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
@EqualsAndHashCode(callSuper = true)
 | 
			
		||||
@Data
 | 
			
		||||
@Builder
 | 
			
		||||
public class GRUWeights extends RNNWeights {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Reset and Update gate weights, with a shape of [inSize + numUnits, 2*numUnits].
 | 
			
		||||
     *
 | 
			
		||||
     * The reset weights are the [:, 0:numUnits] subset and the update weights are the [:, numUnits:2*numUnits] subset.
 | 
			
		||||
     */
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private SDVariable ruWeight;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Cell gate weights, with a shape of [inSize + numUnits, numUnits]
 | 
			
		||||
     */
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private SDVariable cWeight;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Reset and Update gate bias, with a shape of [2*numUnits].  May be null.
 | 
			
		||||
     *
 | 
			
		||||
     * The reset bias is the [0:numUnits] subset and the update bias is the [numUnits:2*numUnits] subset.
 | 
			
		||||
     */
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private SDVariable ruBias;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Cell gate bias, with a shape of [numUnits].  May be null.
 | 
			
		||||
     */
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private SDVariable cBias;
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public SDVariable[] args() {
 | 
			
		||||
        return filterNonNull(ruWeight, cWeight, ruBias, cBias);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,57 @@
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
 | 
			
		||||
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.EqualsAndHashCode;
 | 
			
		||||
import lombok.NonNull;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The weight configuration of a LSTM layer.  For {@link LSTMLayer} and {@link LSTMBlockCell}.
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
@EqualsAndHashCode(callSuper = true)
 | 
			
		||||
@Data
 | 
			
		||||
@Builder
 | 
			
		||||
public class LSTMWeights extends RNNWeights {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Input to hidden weights and hidden to hidden weights, with a shape of [inSize + numUnits, 4*numUnits].
 | 
			
		||||
     *
 | 
			
		||||
     * Input to hidden and hidden to hidden are concatenated in dimension 0,
 | 
			
		||||
     * so the input to hidden weights are [:inSize, :] and the hidden to hidden weights are [inSize:, :].
 | 
			
		||||
     */
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private SDVariable weights;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Cell peephole (t-1) connections to input modulation gate, with a shape of [numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private SDVariable inputPeepholeWeights;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Cell peephole (t-1) connections to forget gate, with a shape of [numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private SDVariable forgetPeepholeWeights;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Cell peephole (t) connections to output gate, with a shape of [numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private SDVariable outputPeepholeWeights;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Input to hidden and hidden to hidden biases, with shape [1, 4*numUnits].
 | 
			
		||||
     */
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private SDVariable bias;
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public SDVariable[] args() {
 | 
			
		||||
        return filterNonNull(weights, inputPeepholeWeights, forgetPeepholeWeights, outputPeepholeWeights, bias);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,35 @@
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
 | 
			
		||||
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.linalg.util.ArrayUtil;
 | 
			
		||||
 | 
			
		||||
public abstract class RNNWeights {
 | 
			
		||||
    public abstract SDVariable[] args();
 | 
			
		||||
 | 
			
		||||
    protected static SDVariable[] filterNonNull(SDVariable... args){
 | 
			
		||||
        int count = 0;
 | 
			
		||||
        for(SDVariable v : args){
 | 
			
		||||
            if(v != null){
 | 
			
		||||
                count++;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        SDVariable[] res = new SDVariable[count];
 | 
			
		||||
 | 
			
		||||
        int i = 0;
 | 
			
		||||
 | 
			
		||||
        for(SDVariable v : args){
 | 
			
		||||
            if(v != null){
 | 
			
		||||
                res[i] = v;
 | 
			
		||||
                i++;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return res;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public SDVariable[] argsWithInputs(SDVariable... inputs){
 | 
			
		||||
        return ArrayUtil.combine(inputs, args());
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,37 @@
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
 | 
			
		||||
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.EqualsAndHashCode;
 | 
			
		||||
import lombok.NonNull;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The weight configuration of a SRU layer.  For {@link SRU} and {@link SRUCell}.
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
@EqualsAndHashCode(callSuper = true)
 | 
			
		||||
@Data
 | 
			
		||||
@Builder
 | 
			
		||||
public class SRUWeights extends RNNWeights {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Weights, with shape [inSize, 3*inSize].
 | 
			
		||||
     */
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private SDVariable weights;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Biases, with shape [2*inSize].
 | 
			
		||||
     */
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private SDVariable bias;
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public SDVariable[] args() {
 | 
			
		||||
        return new SDVariable[]{weights, bias};
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -16,14 +16,19 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.autodiff.opvalidation;
 | 
			
		||||
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import lombok.extern.slf4j.Slf4j;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDIndex;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
			
		||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4jBackend;
 | 
			
		||||
import org.nd4j.linalg.indexing.NDArrayIndex;
 | 
			
		||||
@ -59,23 +64,18 @@ public class RnnOpValidation extends BaseOpValidation {
 | 
			
		||||
        SDVariable b = sd.constant(Nd4j.rand(DataType.FLOAT, 4*nOut));
 | 
			
		||||
 | 
			
		||||
        double fb = 1.0;
 | 
			
		||||
        LSTMBlockCellConfiguration conf = LSTMBlockCellConfiguration.builder()
 | 
			
		||||
                .xt(x)
 | 
			
		||||
                .cLast(cLast)
 | 
			
		||||
                .yLast(yLast)
 | 
			
		||||
                .W(W)
 | 
			
		||||
                .Wci(Wci)
 | 
			
		||||
                .Wcf(Wcf)
 | 
			
		||||
                .Wco(Wco)
 | 
			
		||||
                .b(b)
 | 
			
		||||
        LSTMConfiguration conf = LSTMConfiguration.builder()
 | 
			
		||||
                .peepHole(true)
 | 
			
		||||
                .forgetBias(fb)
 | 
			
		||||
                .clippingCellValue(0.0)
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
        List<SDVariable> v = sd.rnn().lstmBlockCell("lstm", conf);  //Output order: i, c, f, o, z, h, y
 | 
			
		||||
        LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
 | 
			
		||||
                .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();
 | 
			
		||||
 | 
			
		||||
        LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf);  //Output order: i, c, f, o, z, h, y
 | 
			
		||||
        List<String> toExec = new ArrayList<>();
 | 
			
		||||
        for(SDVariable sdv : v){
 | 
			
		||||
        for(SDVariable sdv : v.getAllOutputs()){
 | 
			
		||||
            toExec.add(sdv.getVarName());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@ -167,23 +167,18 @@ public class RnnOpValidation extends BaseOpValidation {
 | 
			
		||||
        SDVariable b = sd.constant(Nd4j.zeros(DataType.FLOAT, 8));
 | 
			
		||||
 | 
			
		||||
        double fb = 1.0;
 | 
			
		||||
        LSTMBlockCellConfiguration conf = LSTMBlockCellConfiguration.builder()
 | 
			
		||||
                .xt(x)
 | 
			
		||||
                .cLast(cLast)
 | 
			
		||||
                .yLast(yLast)
 | 
			
		||||
                .W(W)
 | 
			
		||||
                .Wci(Wci)
 | 
			
		||||
                .Wcf(Wcf)
 | 
			
		||||
                .Wco(Wco)
 | 
			
		||||
                .b(b)
 | 
			
		||||
        LSTMConfiguration conf = LSTMConfiguration.builder()
 | 
			
		||||
                .peepHole(false)
 | 
			
		||||
                .forgetBias(fb)
 | 
			
		||||
                .clippingCellValue(0.0)
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
        List<SDVariable> v = sd.rnn().lstmBlockCell("lstm", conf);  //Output order: i, c, f, o, z, h, y
 | 
			
		||||
        LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
 | 
			
		||||
                .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();
 | 
			
		||||
 | 
			
		||||
        LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf);  //Output order: i, c, f, o, z, h, y
 | 
			
		||||
        List<String> toExec = new ArrayList<>();
 | 
			
		||||
        for(SDVariable sdv : v){
 | 
			
		||||
        for(SDVariable sdv : v.getAllOutputs()){
 | 
			
		||||
            toExec.add(sdv.getVarName());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@ -228,16 +223,14 @@ public class RnnOpValidation extends BaseOpValidation {
 | 
			
		||||
        SDVariable bc = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
 | 
			
		||||
 | 
			
		||||
        double fb = 1.0;
 | 
			
		||||
        GRUCellConfiguration conf = GRUCellConfiguration.builder()
 | 
			
		||||
                .xt(x)
 | 
			
		||||
                .hLast(hLast)
 | 
			
		||||
                .Wru(Wru)
 | 
			
		||||
                .Wc(Wc)
 | 
			
		||||
                .bru(bru)
 | 
			
		||||
                .bc(bc)
 | 
			
		||||
        GRUWeights weights = GRUWeights.builder()
 | 
			
		||||
                .ruWeight(Wru)
 | 
			
		||||
                .cWeight(Wc)
 | 
			
		||||
                .ruBias(bru)
 | 
			
		||||
                .cBias(bc)
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
        List<SDVariable> v = sd.rnn().gru("gru", conf);
 | 
			
		||||
        List<SDVariable> v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs();
 | 
			
		||||
        List<String> toExec = new ArrayList<>();
 | 
			
		||||
        for(SDVariable sdv : v){
 | 
			
		||||
            toExec.add(sdv.getVarName());
 | 
			
		||||
 | 
			
		||||
@ -23,6 +23,7 @@ import lombok.NoArgsConstructor;
 | 
			
		||||
 | 
			
		||||
import java.io.Serializable;
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Simple pair implementation
 | 
			
		||||
@ -86,4 +87,10 @@ public class Pair<K, V> implements Serializable {
 | 
			
		||||
    public static <T, E> Pair<T,E> pairOf(T key, E value) {
 | 
			
		||||
        return new Pair<T, E>(key, value);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static <T> Pair<T, T> fromArray(T[] arr){
 | 
			
		||||
        Preconditions.checkArgument(arr.length == 2,
 | 
			
		||||
                "Can only create a pair from an array with two values, got %s", arr.length);
 | 
			
		||||
        return new Pair<>(arr[0], arr[1]);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user