diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java index 85905f15d..24a029179 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java @@ -17,6 +17,7 @@ package org.datavec.api.transform; import lombok.Data; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.transform.analysis.DataAnalysis; @@ -34,6 +35,7 @@ import org.datavec.api.transform.reduce.IAssociativeReducer; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.transform.sequence.*; +import org.datavec.api.transform.sequence.trim.SequenceTrimToLengthTransform; import org.datavec.api.transform.sequence.trim.SequenceTrimTransform; import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform; import org.datavec.api.transform.sequence.window.WindowFunction; @@ -1100,6 +1102,31 @@ public class TransformProcess implements Serializable { return this; } + /** + * Trim the sequence to the specified length (number of sequence steps).
+ * Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will not be modified. + * + * @param maxLength Maximum sequence length (number of time steps) + */ + public Builder trimSequenceToLength(int maxLength) { + actionList.add(new DataAction(new SequenceTrimToLengthTransform(maxLength, SequenceTrimToLengthTransform.Mode.TRIM, null))); + return this; + } + + /** + * Trim or pad the sequence to the specified length (number of sequence steps).
+ * Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will be + * padded with as many copies of the "pad" array to make the sequence length equal the specified maximum.
+ * Note that the 'pad' list (i.e., values to pad with) must be equal in length to the number of columns (values per time step) + * + * @param length Required length - trim sequences longer than this, pad sequences shorter than this + * @param pad Values to pad at the end of the sequence + */ + public Builder trimOrPadSequenceToLength(int length, @NonNull List pad) { + actionList.add(new DataAction(new SequenceTrimToLengthTransform(length, SequenceTrimToLengthTransform.Mode.TRIM_OR_PAD, pad))); + return this; + } + /** * Perform a sequence of operation on the specified columns. Note that this also truncates sequences by the * specified offset amount by default. Use {@code transform(new SequenceOffsetTransform(...)} to change this. diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimToLengthTransform.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimToLengthTransform.java new file mode 100644 index 000000000..c2527629b --- /dev/null +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimToLengthTransform.java @@ -0,0 +1,129 @@ +package org.datavec.api.transform.sequence.trim; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import org.datavec.api.transform.Transform; +import org.datavec.api.transform.schema.Schema; +import org.datavec.api.writable.Writable; +import org.nd4j.base.Preconditions; +import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; +import org.nd4j.shade.jackson.annotation.JsonProperty; + +import java.util.ArrayList; +import java.util.List; + +/** + * Trim or pad the sequence to the specified length (number of sequence steps). It supports 2 modes:
+ * TRIM: Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will not be modified.
+ * TRIM_OR_PAD: Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will be + * padded with as many copies of the "pad" array to make the sequence length equal the specified maximum.
+ * Note that the 'pad' list (i.e., values to pad when using TRIM_OR_PAD mode) must be equal in length to the number of columns (values per time step) + * + * @author Alex Black + */ +@JsonIgnoreProperties({"schema"}) +@EqualsAndHashCode(exclude = {"schema"}) +@Data +public class SequenceTrimToLengthTransform implements Transform { + /** + * Mode. See {@link SequenceTrimToLengthTransform} + */ + public enum Mode {TRIM, TRIM_OR_PAD} + + private int maxLength; + private Mode mode; + private List pad; + + private Schema schema; + + /** + * @param maxLength maximum sequence length. Must be positive. + * @param mode Mode - trim or trim/pad + * @param pad Padding value. Only used with Mode.TRIM_OR_PAD. Must be equal in length to the number of columns (values per time step) + */ + public SequenceTrimToLengthTransform(@JsonProperty("maxLength") int maxLength, @JsonProperty("mode") Mode mode, @JsonProperty("pad") List pad) { + Preconditions.checkState(maxLength > 0, "Maximum length must be > 0, got %s", maxLength); + Preconditions.checkState(mode == Mode.TRIM || pad != null, "If mode == Mode.TRIM_OR_PAD "); + this.maxLength = maxLength; + this.mode = mode; + this.pad = pad; + } + + @Override + public List map(List writables) { + throw new UnsupportedOperationException("SequenceTrimToLengthTransform cannot be applied to non-sequence values"); + } + + @Override + public List> mapSequence(List> sequence) { + if (mode == Mode.TRIM) { + if (sequence.size() <= maxLength) { + return sequence; + } + return new ArrayList<>(sequence.subList(0, maxLength)); + } else { + //Trim or pad + if (sequence.size() == maxLength) { + return sequence; + } else if (sequence.size() > maxLength) { + return new ArrayList<>(sequence.subList(0, maxLength)); + } else { + //Need to pad + Preconditions.checkState(sequence.size() == 0 || sequence.get(0).size() == pad.size(), "Invalid padding values: %s padding " + + "values were provided, but data has %s values per time step (columns)", pad.size(), sequence.get(0).size()); + + List> out = new ArrayList<>(maxLength); + out.addAll(sequence); + while (out.size() < maxLength) { + out.add(pad); + } + return out; + } + } + } + + @Override + public Object map(Object input) { + throw new UnsupportedOperationException(); + } + + @Override + public Object mapSequence(Object sequence) { + throw new UnsupportedOperationException(); + } + + @Override + public Schema transform(Schema inputSchema) { + return inputSchema; + } + + @Override + public void setInputSchema(Schema inputSchema) { + this.schema = inputSchema; + } + + @Override + public Schema getInputSchema() { + return schema; + } + + @Override + public String outputColumnName() { + return null; + } + + @Override + public String[] outputColumnNames() { + return schema.getColumnNames().toArray(new String[schema.numColumns()]); + } + + @Override + public String[] columnNames() { + return outputColumnNames(); + } + + @Override + public String columnName() { + return null; + } +} diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java index 5aca04000..600ee0b25 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java @@ -1321,6 +1321,91 @@ public class TestTransforms { assertEquals(expTrimLast, tLast.mapSequence(seq)); } + @Test + public void testSequenceTrimToLengthTransform(){ + List> seq = Arrays.asList( + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); + + List> expTrimLength3 = Arrays.asList( + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8))); + + Schema s = new Schema.Builder() + .addColumnsDouble("first", "second", "third") + .build(); + + TransformProcess p = new TransformProcess.Builder(s) + .trimSequenceToLength(3) + .build(); + + List> out = p.executeSequence(seq); + assertEquals(expTrimLength3, out); + + + List> seq2 = Arrays.asList( + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5))); + + out = p.executeSequence(seq2); + assertEquals(seq2, out); + + String json = p.toJson(); + TransformProcess tp2 = TransformProcess.fromJson(json); + assertEquals(expTrimLength3, tp2.executeSequence(seq)); + assertEquals(seq2, tp2.executeSequence(seq2)); + } + + @Test + public void testSequenceTrimToLengthTransformTrimOrPad(){ + List> seq = Arrays.asList( + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11)), + Arrays.asList(new DoubleWritable(12), new DoubleWritable(13), new DoubleWritable(14))); + + List> seq2 = Arrays.asList( + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5))); + + List> expTrimLength4 = Arrays.asList( + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); + + Schema s = new Schema.Builder() + .addColumnsDouble("first", "second", "third") + .build(); + + TransformProcess p = new TransformProcess.Builder(s) + .trimOrPadSequenceToLength(4, Arrays.asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902))) + .build(); + + List> out = p.executeSequence(seq); + assertEquals(expTrimLength4, out); + + + List> exp2 = Arrays.asList( + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902)), + Arrays.asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902))); + + out = p.executeSequence(seq2); + assertEquals(exp2, out); + + + String json = p.toJson(); + TransformProcess tp2 = TransformProcess.fromJson(json); + assertEquals(expTrimLength4, tp2.executeSequence(seq)); + assertEquals(exp2, tp2.executeSequence(seq2)); + } + @Test public void testSequenceOffsetTransform(){