diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java
index 85905f15d..24a029179 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java
@@ -17,6 +17,7 @@
package org.datavec.api.transform;
import lombok.Data;
+import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.transform.analysis.DataAnalysis;
@@ -34,6 +35,7 @@ import org.datavec.api.transform.reduce.IAssociativeReducer;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.sequence.*;
+import org.datavec.api.transform.sequence.trim.SequenceTrimToLengthTransform;
import org.datavec.api.transform.sequence.trim.SequenceTrimTransform;
import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform;
import org.datavec.api.transform.sequence.window.WindowFunction;
@@ -1100,6 +1102,31 @@ public class TransformProcess implements Serializable {
return this;
}
+ /**
+ * Trim the sequence to the specified length (number of sequence steps).
+ * Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will not be modified.
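+ * <p>
+ * A minimal usage sketch (the schema and column names below are hypothetical, not part of this API):
+ * <pre>{@code
+ * Schema schema = new Schema.Builder().addColumnsDouble("a", "b").build();
+ * TransformProcess tp = new TransformProcess.Builder(schema)
+ *     .trimSequenceToLength(100)      //Keep at most 100 time steps per sequence
+ *     .build();
+ * }</pre>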
+ *
+ * @param maxLength Maximum sequence length (number of time steps)
+ */
+ public Builder trimSequenceToLength(int maxLength) {
+ actionList.add(new DataAction(new SequenceTrimToLengthTransform(maxLength, SequenceTrimToLengthTransform.Mode.TRIM, null)));
+ return this;
+ }
+
+ /**
+ * Trim or pad the sequence to the specified length (number of sequence steps).
+ * Sequences longer than the specified length will be trimmed to exactly that length. Shorter sequences will be
+ * padded with as many copies of the "pad" array as required to make the sequence length equal to the specified length.
+ * Note that the 'pad' list (i.e., the values to pad with) must be equal in length to the number of columns (values per time step).
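+ * <p>
+ * A minimal usage sketch (the schema and pad values below are hypothetical); note there is one pad value per column:
+ * <pre>{@code
+ * Schema schema = new Schema.Builder().addColumnsDouble("a", "b").build();
+ * TransformProcess tp = new TransformProcess.Builder(schema)
+ *     .trimOrPadSequenceToLength(100, Arrays.asList(new DoubleWritable(0), new DoubleWritable(0)))
+ *     .build();
+ * }</pre>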
+ *
+ * @param length Required length - trim sequences longer than this, pad sequences shorter than this
+ * @param pad Values to pad at the end of the sequence
+ */
+ public Builder trimOrPadSequenceToLength(int length, @NonNull List<Writable> pad) {
+ actionList.add(new DataAction(new SequenceTrimToLengthTransform(length, SequenceTrimToLengthTransform.Mode.TRIM_OR_PAD, pad)));
+ return this;
+ }
+
/**
* Perform a sequence of operation on the specified columns. Note that this also truncates sequences by the
* specified offset amount by default. Use {@code transform(new SequenceOffsetTransform(...)} to change this.
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimToLengthTransform.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimToLengthTransform.java
new file mode 100644
index 000000000..c2527629b
--- /dev/null
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimToLengthTransform.java
@@ -0,0 +1,129 @@
+package org.datavec.api.transform.sequence.trim;
+
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import org.datavec.api.transform.Transform;
+import org.datavec.api.transform.schema.Schema;
+import org.datavec.api.writable.Writable;
+import org.nd4j.base.Preconditions;
+import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Trim or pad the sequence to the specified length (number of sequence steps). It supports 2 modes:
+ * TRIM: Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will not be modified.
+ * TRIM_OR_PAD: Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will be
+ * padded with as many copies of the "pad" array as required to make the sequence length equal to the specified maximum.
+ * Note that the 'pad' list (i.e., the values to pad with when using TRIM_OR_PAD mode) must be equal in length to the number of columns (values per time step).
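+ * <p>
+ * Normally this transform is added via {@code TransformProcess.Builder.trimSequenceToLength(int)} or
+ * {@code TransformProcess.Builder.trimOrPadSequenceToLength(int, List)}; a minimal sketch of direct use
+ * (the input {@code sequence} variable here is hypothetical):
+ * <pre>{@code
+ * Transform t = new SequenceTrimToLengthTransform(3, SequenceTrimToLengthTransform.Mode.TRIM, null);
+ * List<List<Writable>> trimmed = t.mapSequence(sequence);  //At most 3 time steps are retained
+ * }</pre>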
+ *
+ * @author Alex Black
+ */
+@JsonIgnoreProperties({"schema"})
+@EqualsAndHashCode(exclude = {"schema"})
+@Data
+public class SequenceTrimToLengthTransform implements Transform {
+ /**
+ * Mode. See {@link SequenceTrimToLengthTransform}
+ */
+ public enum Mode {TRIM, TRIM_OR_PAD}
+
+ private int maxLength;
+ private Mode mode;
+ private List<Writable> pad;
+
+ private Schema schema;
+
+ /**
+ * @param maxLength maximum sequence length. Must be positive.
+ * @param mode Mode - trim or trim/pad
+ * @param pad Padding value. Only used with Mode.TRIM_OR_PAD. Must be equal in length to the number of columns (values per time step)
+ */
+ public SequenceTrimToLengthTransform(@JsonProperty("maxLength") int maxLength, @JsonProperty("mode") Mode mode, @JsonProperty("pad") List<Writable> pad) {
+ Preconditions.checkState(maxLength > 0, "Maximum length must be > 0, got %s", maxLength);
+ Preconditions.checkState(mode == Mode.TRIM || pad != null, "If mode == Mode.TRIM_OR_PAD, the list of padding values must be provided (pad was null)");
+ this.maxLength = maxLength;
+ this.mode = mode;
+ this.pad = pad;
+ }
+
+ @Override
+ public List<Writable> map(List<Writable> writables) {
+ throw new UnsupportedOperationException("SequenceTrimToLengthTransform cannot be applied to non-sequence values");
+ }
+
+ @Override
+ public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
+ if (mode == Mode.TRIM) {
+ if (sequence.size() <= maxLength) {
+ return sequence;
+ }
+ return new ArrayList<>(sequence.subList(0, maxLength));
+ } else {
+ //Trim or pad
+ if (sequence.size() == maxLength) {
+ return sequence;
+ } else if (sequence.size() > maxLength) {
+ return new ArrayList<>(sequence.subList(0, maxLength));
+ } else {
+ //Need to pad
+ if (!sequence.isEmpty()) {
+ Preconditions.checkState(sequence.get(0).size() == pad.size(), "Invalid padding values: %s padding " +
+ "values were provided, but data has %s values per time step (columns)", pad.size(), sequence.get(0).size());
+ }
+
+ List<List<Writable>> out = new ArrayList<>(maxLength);
+ out.addAll(sequence);
+ while (out.size() < maxLength) {
+ out.add(pad);
+ }
+ return out;
+ }
+ }
+ }
+
+ @Override
+ public Object map(Object input) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Object mapSequence(Object sequence) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Schema transform(Schema inputSchema) {
+ return inputSchema;
+ }
+
+ @Override
+ public void setInputSchema(Schema inputSchema) {
+ this.schema = inputSchema;
+ }
+
+ @Override
+ public Schema getInputSchema() {
+ return schema;
+ }
+
+ @Override
+ public String outputColumnName() {
+ return null;
+ }
+
+ @Override
+ public String[] outputColumnNames() {
+ return schema.getColumnNames().toArray(new String[schema.numColumns()]);
+ }
+
+ @Override
+ public String[] columnNames() {
+ return outputColumnNames();
+ }
+
+ @Override
+ public String columnName() {
+ return null;
+ }
+}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java
index 5aca04000..600ee0b25 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java
@@ -1321,6 +1321,91 @@ public class TestTransforms {
assertEquals(expTrimLast, tLast.mapSequence(seq));
}
+ @Test
+ public void testSequenceTrimToLengthTransform(){
+ List<List<Writable>> seq = Arrays.asList(
+ Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
+ Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
+ Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)),
+ Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11)));
+
+ List<List<Writable>> expTrimLength3 = Arrays.asList(
+ Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
+ Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
+ Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)));
+
+ Schema s = new Schema.Builder()
+ .addColumnsDouble("first", "second", "third")
+ .build();
+
+ TransformProcess p = new TransformProcess.Builder(s)
+ .trimSequenceToLength(3)
+ .build();
+
+ List<List<Writable>> out = p.executeSequence(seq);
+ assertEquals(expTrimLength3, out);
+
+
+ List<List<Writable>> seq2 = Arrays.asList(
+ Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
+ Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)));
+
+ out = p.executeSequence(seq2);
+ assertEquals(seq2, out);
+
+ String json = p.toJson();
+ TransformProcess tp2 = TransformProcess.fromJson(json);
+ assertEquals(expTrimLength3, tp2.executeSequence(seq));
+ assertEquals(seq2, tp2.executeSequence(seq2));
+ }
+
+ @Test
+ public void testSequenceTrimToLengthTransformTrimOrPad(){
+ List<List<Writable>> seq = Arrays.asList(
+ Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
+ Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
+ Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)),
+ Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11)),
+ Arrays.asList(new DoubleWritable(12), new DoubleWritable(13), new DoubleWritable(14)));
+
+ List<List<Writable>> seq2 = Arrays.asList(
+ Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
+ Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)));
+
+ List<List<Writable>> expTrimLength4 = Arrays.asList(
+ Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
+ Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
+ Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)),
+ Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11)));
+
+ Schema s = new Schema.Builder()
+ .addColumnsDouble("first", "second", "third")
+ .build();
+
+ TransformProcess p = new TransformProcess.Builder(s)
+ .trimOrPadSequenceToLength(4, Arrays.asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902)))
+ .build();
+
+ List<List<Writable>> out = p.executeSequence(seq);
+ assertEquals(expTrimLength4, out);
+
+
+ List<List<Writable>> exp2 = Arrays.asList(
+ Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
+ Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
+ Arrays.asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902)),
+ Arrays.asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902)));
+
+ out = p.executeSequence(seq2);
+ assertEquals(exp2, out);
+
+
+ String json = p.toJson();
+ TransformProcess tp2 = TransformProcess.fromJson(json);
+ assertEquals(expTrimLength4, tp2.executeSequence(seq));
+ assertEquals(exp2, tp2.executeSequence(seq2));
+ }
+
@Test
public void testSequenceOffsetTransform(){