cleanup SDRNN and rnn ops (#238)

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-09-04 19:25:03 -07:00 committed by Alex Black
parent 7d85775934
commit 79867f5c5a
23 changed files with 943 additions and 277 deletions

View File

@ -34,7 +34,7 @@ CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) {
auto xt = INPUT_VARIABLE(0); // input [bS x inSize], bS - batch size, inSize - number of features auto xt = INPUT_VARIABLE(0); // input [bS x inSize], bS - batch size, inSize - number of features
auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that is at previous time step t-1 auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that is at previous time step t-1
auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize] auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize]
auto b = INPUT_VARIABLE(3); // biases [1 x 2*inSize] auto b = INPUT_VARIABLE(3); // biases [2*inSize]
auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x inSize], that is at current time step t auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x inSize], that is at current time step t
auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x inSize], that is at current time step t auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x inSize], that is at current time step t

View File

@ -6511,4 +6511,22 @@ public class SameDiff extends SDBaseOps {
public String generateNewVarName(String base, int argIndex) { public String generateNewVarName(String base, int argIndex) {
return generateNewVarName(base, argIndex, true); return generateNewVarName(base, argIndex, true);
} }
/**
* Returns an unused variable name of the format &lt;base&gt;_#.
*
* Intended to be used for custom variables (like weights), arguments and op outputs should use {@link #generateNewVarName(String, int)}.
*/
public String generateDistinctCustomVariableName(String base){
if(!variables.containsKey(base))
return base;
int inc = 1;
while(variables.containsKey(base + "_" + inc)){
inc++;
}
return base + "_" + inc;
}
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.autodiff.samediff.ops; package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.*; import org.nd4j.linalg.api.ops.impl.layers.recurrent.*;
@ -23,6 +24,15 @@ import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRUCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRULayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
import org.nd4j.linalg.primitives.Pair;
/** /**
* SameDiff Recurrent Neural Network operations<br> * SameDiff Recurrent Neural Network operations<br>
@ -39,90 +49,163 @@ public class SDRNN extends SDOps {
/** /**
* The gru cell * See {@link #gru(String, SDVariable, SDVariable, GRUWeights)}.
*
* @param configuration the configuration to use
* @return
*/ */
public List<SDVariable> gru(GRUCellConfiguration configuration) { public GRUCellOutputs gru(@NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
GRUCell c = new GRUCell(sd, configuration); GRUCell c = new GRUCell(sd, x, hLast, weights);
return Arrays.asList(c.outputVariables()); return new GRUCellOutputs(c.outputVariables());
} }
/** /**
* The gru cell * The GRU cell. Does a single time step operation.
* *
* @param baseName the base name for the gru cell * @param baseName The base name for the gru cell
* @param configuration the configuration to use * @param x Input, with shape [batchSize, inSize]
* @return * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits]
* @param weights The cell's weights.
* @return The cell's outputs.
*/ */
public List<SDVariable> gru(String baseName, GRUCellConfiguration configuration) { public GRUCellOutputs gru(String baseName, @NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
GRUCell c = new GRUCell(sd, configuration); GRUCell c = new GRUCell(sd, x, hLast, weights);
return Arrays.asList(c.outputVariables(baseName)); return new GRUCellOutputs(c.outputVariables(baseName));
}
/**
* LSTM unit
*
* @param baseName the base name for outputs
* @param configuration the configuration to use
* @return
*/
public SDVariable lstmCell(String baseName, LSTMCellConfiguration configuration) {
return new LSTMCell(sd, configuration).outputVariables(baseName)[0];
}
public List<SDVariable> lstmBlockCell(String name, LSTMBlockCellConfiguration configuration){
SDVariable[] v = new LSTMBlockCell(sd, configuration).outputVariables(name);
return Arrays.asList(v);
}
public List<SDVariable> lstmLayer(String name, LSTMConfiguration configuration){
SDVariable[] v = new LSTMLayer(sd, configuration).outputVariables(name);
return Arrays.asList(v);
} }
/** /**
* Simple recurrent unit * See {@link #lstmCell(String, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}.
*
* @param configuration the configuration for the sru
* @return
*/ */
public SDVariable sru(SRUConfiguration configuration) { public LSTMCellOutputs lstmCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
return new SRU(sd, configuration).outputVariables()[0]; LSTMWeights weights, LSTMConfiguration config){
LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
return new LSTMCellOutputs(c.outputVariables());
} }
/** /**
* Simiple recurrent unit * The LSTM cell. Does a single time step operation.
* *
* @param baseName the base name to use for output variables * @param baseName The base name for the lstm cell
* @param configuration the configuration for the sru * @param x Input, with shape [batchSize, inSize]
* @return * @param cLast Previous cell state, with shape [batchSize, numUnits]
* @param yLast Previous cell output, with shape [batchSize, numUnits]
* @param weights The cell's weights.
* @param config The cell's config.
* @return The cell's outputs.
*/ */
public SDVariable sru(String baseName, SRUConfiguration configuration) { public LSTMCellOutputs lstmCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
return new SRU(sd, configuration).outputVariables(baseName)[0]; @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
return new LSTMCellOutputs(c.outputVariables(baseName));
} }
/** /**
* An sru cell * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
*
* @param configuration the configuration for the sru cell
* @return
*/ */
public SDVariable sruCell(SRUCellConfiguration configuration) { public LSTMLayerOutputs lstmLayer(@NonNull SDVariable maxTSLength,
return new SRUCell(sd, configuration).outputVariables()[0]; @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
return new LSTMLayerOutputs(c.outputVariables(), config.getDataFormat());
} }
/** /**
* An sru cell * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
*
* @param baseName the base name to use for the output variables
* @param configuration the configuration for the sru cell
* @return
*/ */
public SDVariable sruCell(String baseName, SRUCellConfiguration configuration) { public LSTMLayerOutputs lstmLayer(int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
return new SRUCell(sd, configuration).outputVariables(baseName)[0]; @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
return lstmLayer(
sd.scalar("lstm_max_ts_length", maxTSLength),
x, cLast, yLast, weights, config);
}
/**
* See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
*/
public LSTMLayerOutputs lstmLayer(String baseName, int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
if(baseName != null) {
return lstmLayer(baseName,
sd.scalar(sd.generateDistinctCustomVariableName(baseName + "_max_ts_length"), maxTSLength),
x, cLast, yLast, weights, config);
} else {
return lstmLayer(maxTSLength, x, cLast, yLast, weights, config);
}
}
/**
* The LSTM layer. Does multiple time steps.
*
* Input shape depends on data format (in config):<br>
* TNS -> [timeSteps, batchSize, inSize]<br>
* NST -> [batchSize, inSize, timeSteps]<br>
* NTS -> [batchSize, timeSteps, inSize]<br>
*
* @param baseName The base name for the lstm layer
* @param x Input, with shape dependent on the data format (in config).
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits]
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits]
* @param weights The layer's weights.
* @param config The layer's config.
* @return The layer's outputs.
*/
public LSTMLayerOutputs lstmLayer(String baseName, @NonNull SDVariable maxTSLength,
@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
return new LSTMLayerOutputs(c.outputVariables(baseName), config.getDataFormat());
}
/**
* See {@link #sruCell(String, SDVariable, SDVariable, SRUWeights)}.
*/
public SRUCellOutputs sruCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables());
}
/**
* The SRU cell. Does a single time step operation.
*
* @param baseName The base name for the sru cell
* @param x Input, with shape [batchSize, inSize]
* @param cLast Previous cell state, with shape [batchSize, inSize]
* @param weights The cell's weights.
* @return The cell's outputs.
*/
public SRUCellOutputs sruCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables(baseName));
}
/**
* See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
*/
public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
return sru(x, initialC, null, weights);
}
/**
* See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
*/
public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
return sru(baseName, x, initialC, null, weights);
}
/**
* See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
*/
public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables());
}
/**
* The SRU layer. Does a single time step operation.
*
* @param baseName The base name for the sru layer
* @param x Input, with shape [batchSize, inSize, timeSeriesLength]
* @param initialC Initial cell state, with shape [batchSize, inSize]
* @param mask An optional dropout mask, with shape [batchSize, inSize]
* @param weights The layer's weights.
* @return The layer's outputs.
*/
public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables(baseName));
} }
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent; package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import lombok.Getter;
import onnx.Onnx; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
@ -39,14 +41,15 @@ import java.util.Map;
*/ */
public class GRUCell extends DynamicCustomOp { public class GRUCell extends DynamicCustomOp {
private GRUCellConfiguration configuration; @Getter
private GRUWeights weights;
public GRUCell() { public GRUCell() {
} }
public GRUCell(SameDiff sameDiff, GRUCellConfiguration configuration) { public GRUCell(SameDiff sameDiff, SDVariable x, SDVariable hLast, GRUWeights weights) {
super(null, sameDiff, configuration.args()); super(null, sameDiff, weights.argsWithInputs(x, hLast));
this.configuration = configuration; this.weights = weights;
} }
@Override @Override

View File

@ -16,12 +16,15 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent; package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import lombok.Getter;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.primitives.Pair;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
@ -49,10 +52,12 @@ import java.util.Map;
* 6: weights - cell peephole (t) connections to output gate, [numUnits]<br> * 6: weights - cell peephole (t) connections to output gate, [numUnits]<br>
* 7: biases, shape [4*numUnits]<br> * 7: biases, shape [4*numUnits]<br>
* <br> * <br>
* Input integer arguments: set via {@link LSTMBlockCellConfiguration}<br> * Weights are set via {@link LSTMWeights}.<br>
* <br>
* Input integer arguments: set via {@link LSTMConfiguration}<br>
* 0: if not zero, provide peephole connections<br> * 0: if not zero, provide peephole connections<br>
* <br> * <br>
* Input float arguments: set via {@link LSTMBlockCellConfiguration}<br> * Input float arguments: set via {@link LSTMConfiguration}<br>
* 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training<br> * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training<br>
* 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped<br> * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped<br>
* <br> * <br>
@ -69,15 +74,19 @@ import java.util.Map;
*/ */
public class LSTMBlockCell extends DynamicCustomOp { public class LSTMBlockCell extends DynamicCustomOp {
private LSTMBlockCellConfiguration configuration; private LSTMConfiguration configuration;
@Getter
private LSTMWeights weights;
public LSTMBlockCell() { public LSTMBlockCell() {
} }
public LSTMBlockCell(SameDiff sameDiff, LSTMBlockCellConfiguration configuration) { public LSTMBlockCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) {
super(null, sameDiff, configuration.args()); super(null, sameDiff, weights.argsWithInputs(x, cLast, yLast));
this.configuration = configuration; this.configuration = configuration;
addIArgument(configuration.iArgs()); this.weights = weights;
addIArgument(configuration.iArgs(false));
addTArgument(configuration.tArgs()); addTArgument(configuration.tArgs());
} }
@ -97,12 +106,12 @@ public class LSTMBlockCell extends DynamicCustomOp {
@Override @Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
configuration = LSTMBlockCellConfiguration.builder() configuration = LSTMConfiguration.builder()
.forgetBias(attributesForNode.get("forget_bias").getF()) .forgetBias(attributesForNode.get("forget_bias").getF())
.clippingCellValue(attributesForNode.get("cell_clip").getF()) .clippingCellValue(attributesForNode.get("cell_clip").getF())
.peepHole(attributesForNode.get("use_peephole").getB()) .peepHole(attributesForNode.get("use_peephole").getB())
.build(); .build();
addIArgument(configuration.iArgs()); addIArgument(configuration.iArgs(false));
addTArgument(configuration.tArgs()); addTArgument(configuration.tArgs());
} }
@ -113,7 +122,7 @@ public class LSTMBlockCell extends DynamicCustomOp {
@Override @Override
public Map<String, Object> propertiesForFunction() { public Map<String, Object> propertiesForFunction() {
return configuration.toProperties(); return configuration.toProperties(false);
} }
@Override @Override

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent; package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -24,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
@ -75,13 +77,17 @@ public class LSTMLayer extends DynamicCustomOp {
private LSTMConfiguration configuration; private LSTMConfiguration configuration;
@Getter
private LSTMWeights weights;
public LSTMLayer() { public LSTMLayer() {
} }
public LSTMLayer(@NonNull SameDiff sameDiff, @NonNull LSTMConfiguration configuration) { public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) {
super(null, sameDiff, configuration.args()); super(null, sameDiff, weights.argsWithInputs(maxTSLength, x, cLast, yLast));
this.configuration = configuration; this.configuration = configuration;
addIArgument(configuration.iArgs()); this.weights = weights;
addIArgument(configuration.iArgs(true));
addTArgument(configuration.tArgs()); addTArgument(configuration.tArgs());
} }
@ -107,7 +113,7 @@ public class LSTMLayer extends DynamicCustomOp {
.peepHole(attributesForNode.get("use_peephole").getB()) .peepHole(attributesForNode.get("use_peephole").getB())
.dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM .dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM
.build(); .build();
addIArgument(configuration.iArgs()); addIArgument(configuration.iArgs(true));
addTArgument(configuration.tArgs()); addTArgument(configuration.tArgs());
} }
@ -118,7 +124,7 @@ public class LSTMLayer extends DynamicCustomOp {
@Override @Override
public Map<String, Object> propertiesForFunction() { public Map<String, Object> propertiesForFunction() {
return configuration.toProperties(); return configuration.toProperties(true);
} }
@Override @Override

View File

@ -16,11 +16,16 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent; package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import java.util.Arrays;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
import onnx.Onnx; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
@ -34,13 +39,18 @@ import java.util.Map;
*/ */
public class SRU extends DynamicCustomOp { public class SRU extends DynamicCustomOp {
private SRUConfiguration configuration; @Getter
private SRUWeights weights;
@Getter
private SDVariable mask;
public SRU() { } public SRU() { }
public SRU(SameDiff sameDiff, SRUConfiguration configuration) { public SRU(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
super(null, sameDiff, configuration.args()); super(null, sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getBias(), initialC, mask));
this.configuration = configuration; this.mask = mask;
this.weights = weights;
} }
@Override @Override
@ -68,6 +78,4 @@ public class SRU extends DynamicCustomOp {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }
} }

View File

@ -16,17 +16,18 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent; package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import java.util.Map;
import lombok.Getter;
import onnx.Onnx; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUCellConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
import java.util.Map;
/** /**
* A simple recurrent unit cell. * A simple recurrent unit cell.
* *
@ -34,14 +35,15 @@ import java.util.Map;
*/ */
public class SRUCell extends DynamicCustomOp { public class SRUCell extends DynamicCustomOp {
private SRUCellConfiguration configuration; @Getter
private SRUWeights weights;
public SRUCell() { public SRUCell() {
} }
public SRUCell(SameDiff sameDiff, SRUCellConfiguration configuration) { public SRUCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SRUWeights weights) {
super(null, sameDiff, configuration.args()); super(null, sameDiff, weights.argsWithInputs(x, cLast));
this.configuration = configuration; this.weights = weights;
} }
@Override @Override

View File

@ -1,57 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
import lombok.Builder;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.LinkedHashMap;
import java.util.Map;
@Builder
@Data
public class LSTMBlockCellConfiguration {
private boolean peepHole; //IArg(0)
private double forgetBias; //TArg(0)
private double clippingCellValue; //TArg(1)
private SDVariable xt, cLast, yLast, W, Wci, Wcf, Wco, b;
public Map<String,Object> toProperties() {
Map<String,Object> ret = new LinkedHashMap<>();
ret.put("peepHole",peepHole);
ret.put("clippingCellValue",clippingCellValue);
ret.put("forgetBias",forgetBias);
return ret;
}
public SDVariable[] args() {
return new SDVariable[] {xt,cLast, yLast, W, Wci, Wcf, Wco, b};
}
public int[] iArgs() {
return new int[] {ArrayUtil.fromBoolean(peepHole)};
}
public double[] tArgs() {
return new double[] {forgetBias,clippingCellValue};
}
}

View File

@ -19,13 +19,15 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
/** /**
* LSTM Configuration - for {@link org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer} * LSTM Configuration - for {@link LSTMLayer} and {@link LSTMBlockCell}
* *
* @author Alex Black * @author Alex Black
*/ */
@ -33,29 +35,41 @@ import java.util.Map;
@Data @Data
public class LSTMConfiguration { public class LSTMConfiguration {
/**
* Whether to provide peephole connections.
*/
private boolean peepHole; //IArg(0) private boolean peepHole; //IArg(0)
@Builder.Default private RnnDataFormat dataFormat = RnnDataFormat.TNS; //IArg(1)
/**
* The data format of the input. Only used in {@link LSTMLayer}, ignored in {@link LSTMBlockCell}.
*/
@Builder.Default private RnnDataFormat dataFormat = RnnDataFormat.TNS; //IArg(1) (only for lstmBlock, not lstmBlockCell)
/**
* The bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training.
*/
private double forgetBias; //TArg(0) private double forgetBias; //TArg(0)
/**
* Clipping value for cell state, if it is not equal to zero, then cell state is clipped.
*/
private double clippingCellValue; //TArg(1) private double clippingCellValue; //TArg(1)
private SDVariable xt, cLast, yLast, W, Wci, Wcf, Wco, b; public Map<String,Object> toProperties(boolean includeDataFormat) {
public Map<String,Object> toProperties() {
Map<String,Object> ret = new LinkedHashMap<>(); Map<String,Object> ret = new LinkedHashMap<>();
ret.put("peepHole",peepHole); ret.put("peepHole",peepHole);
ret.put("clippingCellValue",clippingCellValue); ret.put("clippingCellValue",clippingCellValue);
ret.put("forgetBias",forgetBias); ret.put("forgetBias",forgetBias);
ret.put("dataFormat", dataFormat); if(includeDataFormat)
ret.put("dataFormat", dataFormat);
return ret; return ret;
} }
public SDVariable[] args() {
return new SDVariable[] {xt,cLast, yLast, W, Wci, Wcf, Wco, b};
}
public int[] iArgs(boolean includeDataFormat) {
public int[] iArgs() { if(includeDataFormat) {
return new int[] {ArrayUtil.fromBoolean(peepHole), dataFormat.ordinal()}; return new int[]{ArrayUtil.fromBoolean(peepHole), dataFormat.ordinal()};
} else return new int[]{ArrayUtil.fromBoolean(peepHole)};
} }
public double[] tArgs() { public double[] tArgs() {

View File

@ -1,44 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
import lombok.Builder;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
@Data
@Builder
public class SRUCellConfiguration {
/**
*
NDArray<T>* xt = INPUT_VARIABLE(0); // input [batchSize x inSize], batchSize - batch size, inSize - number of features
NDArray<T>* ct_1 = INPUT_VARIABLE(1); // previous cell state ct [batchSize x inSize], that is at previous time step t-1
NDArray<T>* w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize]
NDArray<T>* b = INPUT_VARIABLE(3); // biases [1 x 2*inSize]
NDArray<T>* ht = OUTPUT_VARIABLE(0); // current cell output [batchSize x inSize], that is at current time step t
NDArray<T>* ct = OUTPUT_VARIABLE(1); // current cell state [batchSize x inSize], that is at current time step t
*/
private SDVariable xt,ct_1,w,b,h1,ct;
public SDVariable[] args() {
return new SDVariable[] {xt,ct_1,w,b,h1,ct};
}
}

View File

@ -1,38 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
import lombok.Builder;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
@Data
@Builder
public class SRUConfiguration {
/**
* NDArray<T>* input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
NDArray<T>* weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K]
NDArray<T>* bias = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K]
NDArray<T>* init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0
*/
private SDVariable inputs,weights,bias,init;
public SDVariable[] args() {
return new SDVariable[] {inputs,weights,bias,init};
}
}

View File

@ -0,0 +1,62 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
import java.util.Arrays;
import java.util.List;
import lombok.Getter;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
/**
* The outputs of a GRU cell ({@link GRUCell}.
*/
@Getter
public class GRUCellOutputs {
/**
* Reset gate output [batchSize, numUnits].
*/
private SDVariable r;
/**
* Update gate output [batchSize, numUnits].
*/
private SDVariable u;
/**
* Cell gate output [batchSize, numUnits].
*/
private SDVariable c;
/**
* Current cell output [batchSize, numUnits].
*/
private SDVariable h;
public GRUCellOutputs(SDVariable[] outputs){
Preconditions.checkArgument(outputs.length == 4,
"Must have 4 GRU cell outputs, got %s", outputs.length);
r = outputs[0];
u = outputs[1];
c = outputs[2];
h = outputs[3];
}
/**
* Get all outputs returned by the cell.
*/
public List<SDVariable> getAllOutputs(){
return Arrays.asList(r, u, c, h);
}
/**
* Get h, the output of the cell.
*
* Has shape [batchSize, numUnits].
*/
public SDVariable getOutput(){
return h;
}
}

View File

@ -0,0 +1,88 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
import java.util.Arrays;
import java.util.List;
import lombok.Getter;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
/**
* The outputs of a LSTM cell ({@link LSTMBlockCell}.
*/
@Getter
public class LSTMCellOutputs {
/**
* Output - input modulation gate activations [batchSize, numUnits].
*/
private SDVariable i;
/**
* Activations, cell state (pre tanh) [batchSize, numUnits].
*/
private SDVariable c;
/**
* Output - forget gate activations [batchSize, numUnits].
*/
private SDVariable f;
/**
* Output - output gate activations [batchSize, numUnits].
*/
private SDVariable o;
/**
* Output - input gate activations [batchSize, numUnits].
*/
private SDVariable z;
/**
* Cell state, post tanh [batchSize, numUnits].
*/
private SDVariable h;
/**
* Current cell output [batchSize, numUnits].
*/
private SDVariable y;
public LSTMCellOutputs(SDVariable[] outputs){
Preconditions.checkArgument(outputs.length == 7,
"Must have 7 LSTM cell outputs, got %s", outputs.length);
i = outputs[0];
c = outputs[1];
f = outputs[2];
o = outputs[3];
z = outputs[4];
h = outputs[5];
y = outputs[6];
}
/**
* Get all outputs returned by the cell.
*/
public List<SDVariable> getAllOutputs(){
return Arrays.asList(i, c, f, o, z, h, y);
}
/**
* Get y, the output of the cell.
*
* Has shape [batchSize, numUnits].
*/
public SDVariable getOutput(){
return y;
}
/**
* Get c, the cell's state.
*
* Has shape [batchSize, numUnits].
*/
public SDVariable getState(){
return c;
}
}

View File

@ -0,0 +1,180 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
import java.util.Arrays;
import java.util.List;
import lombok.AccessLevel;
import lombok.Getter;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
/**
* The outputs of a LSTM layer ({@link LSTMLayer}.
*/
@Getter
public class LSTMLayerOutputs {
private RnnDataFormat dataFormat;
/**
* Output - input modulation gate activations.
* Shape depends on data format (in layer config):<br>
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
private SDVariable i;
/**
* Activations, cell state (pre tanh).
* Shape depends on data format (in layer config):<br>
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
private SDVariable c;
/**
* Output - forget gate activations.
* Shape depends on data format (in layer config):<br>
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
private SDVariable f;
/**
* Output - output gate activations.
* Shape depends on data format (in layer config):<br>
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
private SDVariable o;
/**
* Output - input gate activations.
* Shape depends on data format (in layer config):<br>
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
private SDVariable z;
/**
* Cell state, post tanh.
* Shape depends on data format (in layer config):<br>
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
private SDVariable h;
/**
* Current cell output.
* Shape depends on data format (in layer config):<br>
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
private SDVariable y;
public LSTMLayerOutputs(SDVariable[] outputs, RnnDataFormat dataFormat){
Preconditions.checkArgument(outputs.length == 7,
"Must have 7 LSTM layer outputs, got %s", outputs.length);
i = outputs[0];
c = outputs[1];
f = outputs[2];
o = outputs[3];
z = outputs[4];
h = outputs[5];
y = outputs[6];
this.dataFormat = dataFormat;
}
/**
* Get all outputs returned by the cell.
*/
public List<SDVariable> getAllOutputs(){
return Arrays.asList(i, c, f, o, z, h, y);
}
/**
* Get y, the output of the cell for all time steps.
*
* Shape depends on data format (in layer config):<br>
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
public SDVariable getOutput(){
return y;
}
/**
* Get c, the cell's state for all time steps.
*
* Shape depends on data format (in layer config):<br>
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
public SDVariable getState(){
return c;
}
private SDVariable lastOutput = null;
/**
* Get y, the output of the cell, for the last time step.
*
* Has shape [batchSize, numUnits].
*/
public SDVariable getLastOutput(){
if(lastOutput != null)
return lastOutput;
switch (dataFormat){
case TNS:
lastOutput = getOutput().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
break;
case NST:
lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
break;
case NTS:
lastOutput = getOutput().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
break;
}
return lastOutput;
}
private SDVariable lastState = null;
/**
* Get c, the state of the cell, for the last time step.
*
* Has shape [batchSize, numUnits].
*/
public SDVariable getLastState(){
if(lastState != null)
return lastState;
switch (dataFormat){
case TNS:
lastState = getState().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
break;
case NST:
lastState = getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
break;
case NTS:
lastState = getState().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
break;
}
return lastState;
}
}

View File

@ -0,0 +1,60 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
import java.util.Arrays;
import java.util.List;
import lombok.Getter;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
/**
* The outputs of a GRU cell ({@link GRUCell}.
*/
@Getter
public class SRUCellOutputs {
/**
* Current cell output [batchSize, numUnits].
*/
private SDVariable h;
/**
* Current cell state [batchSize, numUnits].
*/
private SDVariable c;
public SRUCellOutputs(SDVariable[] outputs){
Preconditions.checkArgument(outputs.length == 2,
"Must have 2 SRU cell outputs, got %s", outputs.length);
h = outputs[0];
c = outputs[1];
}
/**
* Get all outputs returned by the cell.
*/
public List<SDVariable> getAllOutputs(){
return Arrays.asList(h, c);
}
/**
* Get h, the output of the cell.
*
* Has shape [batchSize, inSize].
*/
public SDVariable getOutput(){
return h;
}
/**
* Get c, the state of the cell.
*
* Has shape [batchSize, inSize].
*/
public SDVariable getState(){
return c;
}
}

View File

@ -0,0 +1,92 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
import java.util.Arrays;
import java.util.List;
import lombok.AccessLevel;
import lombok.Getter;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
/**
* The outputs of a GRU cell ({@link GRUCell}.
*/
@Getter
public class SRULayerOutputs {
/**
* Current cell output [batchSize, inSize, timeSeriesLength].
*/
private SDVariable h;
/**
* Current cell state [batchSize, inSize, timeSeriesLength].
*/
private SDVariable c;
public SRULayerOutputs(SDVariable[] outputs){
Preconditions.checkArgument(outputs.length == 2,
"Must have 2 SRU cell outputs, got %s", outputs.length);
h = outputs[0];
c = outputs[1];
}
/**
* Get all outputs returned by the cell.
*/
public List<SDVariable> getAllOutputs(){
return Arrays.asList(h, c);
}
/**
* Get h, the output of the cell.
*
* Has shape [batchSize, inSize, timeSeriesLength].
*/
public SDVariable getOutput(){
return h;
}
/**
* Get c, the state of the cell.
*
* Has shape [batchSize, inSize, timeSeriesLength].
*/
public SDVariable getState(){
return c;
}
private SDVariable lastOutput = null;
/**
* Get y, the output of the cell, for the last time step.
*
* Has shape [batchSize, inSize].
*/
public SDVariable getLastOutput(){
if(lastOutput != null)
return lastOutput;
lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
return lastOutput;
}
private SDVariable lastState = null;
/**
* Get c, the state of the cell, for the last time step.
*
* Has shape [batchSize, inSize].
*/
public SDVariable getLastState(){
if(lastState != null)
return lastState;
lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
return lastState;
}
}

View File

@ -0,0 +1,51 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
/**
* The weight configuration of a GRU cell. For {@link GRUCell}.
*
*/
@EqualsAndHashCode(callSuper = true)
@Data
@Builder
public class GRUWeights extends RNNWeights {
/**
* Reset and Update gate weights, with a shape of [inSize + numUnits, 2*numUnits].
*
* The reset weights are the [:, 0:numUnits] subset and the update weights are the [:, numUnits:2*numUnits] subset.
*/
@NonNull
private SDVariable ruWeight;
/**
* Cell gate weights, with a shape of [inSize + numUnits, numUnits]
*/
@NonNull
private SDVariable cWeight;
/**
* Reset and Update gate bias, with a shape of [2*numUnits]. May be null.
*
* The reset bias is the [0:numUnits] subset and the update bias is the [numUnits:2*numUnits] subset.
*/
@NonNull
private SDVariable ruBias;
/**
* Cell gate bias, with a shape of [numUnits]. May be null.
*/
@NonNull
private SDVariable cBias;
@Override
public SDVariable[] args() {
return filterNonNull(ruWeight, cWeight, ruBias, cBias);
}
}

View File

@ -0,0 +1,57 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
/**
* The weight configuration of a LSTM layer. For {@link LSTMLayer} and {@link LSTMBlockCell}.
*
*/
@EqualsAndHashCode(callSuper = true)
@Data
@Builder
public class LSTMWeights extends RNNWeights {
/**
* Input to hidden weights and hidden to hidden weights, with a shape of [inSize + numUnits, 4*numUnits].
*
* Input to hidden and hidden to hidden are concatenated in dimension 0,
* so the input to hidden weights are [:inSize, :] and the hidden to hidden weights are [inSize:, :].
*/
@NonNull
private SDVariable weights;
/**
* Cell peephole (t-1) connections to input modulation gate, with a shape of [numUnits].
*/
@NonNull
private SDVariable inputPeepholeWeights;
/**
* Cell peephole (t-1) connections to forget gate, with a shape of [numUnits].
*/
@NonNull
private SDVariable forgetPeepholeWeights;
/**
* Cell peephole (t) connections to output gate, with a shape of [numUnits].
*/
@NonNull
private SDVariable outputPeepholeWeights;
/**
* Input to hidden and hidden to hidden biases, with shape [1, 4*numUnits].
*/
@NonNull
private SDVariable bias;
@Override
public SDVariable[] args() {
return filterNonNull(weights, inputPeepholeWeights, forgetPeepholeWeights, outputPeepholeWeights, bias);
}
}

View File

@ -0,0 +1,35 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
import java.util.Arrays;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.util.ArrayUtil;
public abstract class RNNWeights {
public abstract SDVariable[] args();
protected static SDVariable[] filterNonNull(SDVariable... args){
int count = 0;
for(SDVariable v : args){
if(v != null){
count++;
}
}
SDVariable[] res = new SDVariable[count];
int i = 0;
for(SDVariable v : args){
if(v != null){
res[i] = v;
i++;
}
}
return res;
}
public SDVariable[] argsWithInputs(SDVariable... inputs){
return ArrayUtil.combine(inputs, args());
}
}

View File

@ -0,0 +1,37 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
/**
* The weight configuration of a SRU layer. For {@link SRU} and {@link SRUCell}.
*
*/
@EqualsAndHashCode(callSuper = true)
@Data
@Builder
public class SRUWeights extends RNNWeights {
/**
* Weights, with shape [inSize, 3*inSize].
*/
@NonNull
private SDVariable weights;
/**
* Biases, with shape [2*inSize].
*/
@NonNull
private SDVariable bias;
@Override
public SDVariable[] args() {
return new SDVariable[]{weights, bias};
}
}

View File

@ -16,14 +16,19 @@
package org.nd4j.autodiff.opvalidation; package org.nd4j.autodiff.opvalidation;
import java.util.Arrays;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.Test; import org.junit.Test;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -59,23 +64,18 @@ public class RnnOpValidation extends BaseOpValidation {
SDVariable b = sd.constant(Nd4j.rand(DataType.FLOAT, 4*nOut)); SDVariable b = sd.constant(Nd4j.rand(DataType.FLOAT, 4*nOut));
double fb = 1.0; double fb = 1.0;
LSTMBlockCellConfiguration conf = LSTMBlockCellConfiguration.builder() LSTMConfiguration conf = LSTMConfiguration.builder()
.xt(x)
.cLast(cLast)
.yLast(yLast)
.W(W)
.Wci(Wci)
.Wcf(Wcf)
.Wco(Wco)
.b(b)
.peepHole(true) .peepHole(true)
.forgetBias(fb) .forgetBias(fb)
.clippingCellValue(0.0) .clippingCellValue(0.0)
.build(); .build();
List<SDVariable> v = sd.rnn().lstmBlockCell("lstm", conf); //Output order: i, c, f, o, z, h, y LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
.inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();
LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y
List<String> toExec = new ArrayList<>(); List<String> toExec = new ArrayList<>();
for(SDVariable sdv : v){ for(SDVariable sdv : v.getAllOutputs()){
toExec.add(sdv.getVarName()); toExec.add(sdv.getVarName());
} }
@ -167,23 +167,18 @@ public class RnnOpValidation extends BaseOpValidation {
SDVariable b = sd.constant(Nd4j.zeros(DataType.FLOAT, 8)); SDVariable b = sd.constant(Nd4j.zeros(DataType.FLOAT, 8));
double fb = 1.0; double fb = 1.0;
LSTMBlockCellConfiguration conf = LSTMBlockCellConfiguration.builder() LSTMConfiguration conf = LSTMConfiguration.builder()
.xt(x)
.cLast(cLast)
.yLast(yLast)
.W(W)
.Wci(Wci)
.Wcf(Wcf)
.Wco(Wco)
.b(b)
.peepHole(false) .peepHole(false)
.forgetBias(fb) .forgetBias(fb)
.clippingCellValue(0.0) .clippingCellValue(0.0)
.build(); .build();
List<SDVariable> v = sd.rnn().lstmBlockCell("lstm", conf); //Output order: i, c, f, o, z, h, y LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
.inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();
LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y
List<String> toExec = new ArrayList<>(); List<String> toExec = new ArrayList<>();
for(SDVariable sdv : v){ for(SDVariable sdv : v.getAllOutputs()){
toExec.add(sdv.getVarName()); toExec.add(sdv.getVarName());
} }
@ -228,16 +223,14 @@ public class RnnOpValidation extends BaseOpValidation {
SDVariable bc = sd.constant(Nd4j.rand(DataType.FLOAT, nOut)); SDVariable bc = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
double fb = 1.0; double fb = 1.0;
GRUCellConfiguration conf = GRUCellConfiguration.builder() GRUWeights weights = GRUWeights.builder()
.xt(x) .ruWeight(Wru)
.hLast(hLast) .cWeight(Wc)
.Wru(Wru) .ruBias(bru)
.Wc(Wc) .cBias(bc)
.bru(bru)
.bc(bc)
.build(); .build();
List<SDVariable> v = sd.rnn().gru("gru", conf); List<SDVariable> v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs();
List<String> toExec = new ArrayList<>(); List<String> toExec = new ArrayList<>();
for(SDVariable sdv : v){ for(SDVariable sdv : v){
toExec.add(sdv.getVarName()); toExec.add(sdv.getVarName());

View File

@ -23,6 +23,7 @@ import lombok.NoArgsConstructor;
import java.io.Serializable; import java.io.Serializable;
import java.util.Arrays; import java.util.Arrays;
import org.nd4j.base.Preconditions;
/** /**
* Simple pair implementation * Simple pair implementation
@ -86,4 +87,10 @@ public class Pair<K, V> implements Serializable {
public static <T, E> Pair<T,E> pairOf(T key, E value) { public static <T, E> Pair<T,E> pairOf(T key, E value) {
return new Pair<T, E>(key, value); return new Pair<T, E>(key, value);
} }
public static <T> Pair<T, T> fromArray(T[] arr){
Preconditions.checkArgument(arr.length == 2,
"Can only create a pair from an array with two values, got %s", arr.length);
return new Pair<>(arr[0], arr[1]);
}
} }