Add SequenceTrimToLengthTransform (#61)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-07-16 12:36:12 +10:00 committed by AlexDBlack
parent 9cf28ea6c9
commit 5cf6859fc4
3 changed files with 241 additions and 0 deletions

View File

@ -17,6 +17,7 @@
package org.datavec.api.transform;
import lombok.Data;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.transform.analysis.DataAnalysis;
@ -34,6 +35,7 @@ import org.datavec.api.transform.reduce.IAssociativeReducer;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.sequence.*;
import org.datavec.api.transform.sequence.trim.SequenceTrimToLengthTransform;
import org.datavec.api.transform.sequence.trim.SequenceTrimTransform;
import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform;
import org.datavec.api.transform.sequence.window.WindowFunction;
@ -1100,6 +1102,31 @@ public class TransformProcess implements Serializable {
return this;
}
/**
 * Append a trim-only step to this builder: any sequence with more than {@code maxLength} steps is
 * truncated to exactly {@code maxLength} steps, while sequences that are already short enough pass
 * through unchanged.
 *
 * @param maxLength Maximum sequence length (number of time steps)
 * @return this builder, to allow call chaining
 */
public Builder trimSequenceToLength(int maxLength) {
    // Mode.TRIM never pads, so no padding values are supplied
    SequenceTrimToLengthTransform trimStep =
            new SequenceTrimToLengthTransform(maxLength, SequenceTrimToLengthTransform.Mode.TRIM, null);
    DataAction trimAction = new DataAction(trimStep);
    actionList.add(trimAction);
    return this;
}
/**
 * Append a trim-or-pad step to this builder so that every sequence ends up with exactly
 * {@code length} steps.<br>
 * Longer sequences are cut down to {@code length} steps; shorter sequences have copies of the
 * {@code pad} step appended until they reach {@code length} steps.<br>
 * The {@code pad} list must contain one value per column (i.e., one value per entry in a time step).
 *
 * @param length Required length - trim sequences longer than this, pad sequences shorter than this
 * @param pad    Values to pad at the end of the sequence
 * @return this builder, to allow call chaining
 */
public Builder trimOrPadSequenceToLength(int length, @NonNull List<Writable> pad) {
    SequenceTrimToLengthTransform trimOrPadStep =
            new SequenceTrimToLengthTransform(length, SequenceTrimToLengthTransform.Mode.TRIM_OR_PAD, pad);
    DataAction trimOrPadAction = new DataAction(trimOrPadStep);
    actionList.add(trimOrPadAction);
    return this;
}
/**
* Perform a sequence of operations on the specified columns. Note that this also truncates sequences by the
* specified offset amount by default. Use {@code transform(new SequenceOffsetTransform(...))} to change this.

View File

@ -0,0 +1,129 @@
package org.datavec.api.transform.sequence.trim;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.nd4j.base.Preconditions;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.ArrayList;
import java.util.List;
/**
 * Trim or pad the sequence to the specified length (number of sequence steps). It supports 2 modes:<br>
 * TRIM: Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will not be modified.<br>
 * TRIM_OR_PAD: Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will be
 * padded with as many copies of the "pad" array to make the sequence length equal the specified maximum.<br>
 * Note that the 'pad' list (i.e., values to pad when using TRIM_OR_PAD mode) must be equal in length to the number of columns (values per time step)
 *
 * @author Alex Black
 */
@JsonIgnoreProperties({"schema"})
@EqualsAndHashCode(exclude = {"schema"})
@Data
public class SequenceTrimToLengthTransform implements Transform {

    /**
     * Mode. See {@link SequenceTrimToLengthTransform}
     */
    public enum Mode {TRIM, TRIM_OR_PAD}

    private int maxLength;
    private Mode mode;
    private List<Writable> pad;

    private Schema schema;

    /**
     * @param maxLength maximum sequence length. Must be positive.
     * @param mode      Mode - trim or trim/pad
     * @param pad       Padding value. Only used with Mode.TRIM_OR_PAD. Must be equal in length to the number of columns (values per time step)
     * @throws IllegalStateException if maxLength is not positive, or if pad is null and mode is not Mode.TRIM
     */
    public SequenceTrimToLengthTransform(@JsonProperty("maxLength") int maxLength, @JsonProperty("mode") Mode mode, @JsonProperty("pad") List<Writable> pad) {
        Preconditions.checkState(maxLength > 0, "Maximum length must be > 0, got %s", maxLength);
        Preconditions.checkState(mode == Mode.TRIM || pad != null,
                "If mode == Mode.TRIM_OR_PAD, padding values (pad) must be provided (got null)");
        this.maxLength = maxLength;
        this.mode = mode;
        this.pad = pad;
    }

    /**
     * Not supported: this transform only operates on sequences.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<Writable> map(List<Writable> writables) {
        throw new UnsupportedOperationException("SequenceTrimToLengthTransform cannot be applied to non-sequence values");
    }

    /**
     * Trim (and, in TRIM_OR_PAD mode, pad) the given sequence.
     *
     * @param sequence Sequence to process; each inner list is one time step
     * @return The input list itself when no change is required; otherwise a new list containing the first
     * maxLength steps (trim) or the original steps followed by the pad step repeated up to maxLength (pad)
     * @throws IllegalStateException if padding is required and the first step's size differs from the pad size
     */
    @Override
    public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
        final int length = sequence.size();

        //Both modes trim sequences that are too long
        if (length > maxLength) {
            return new ArrayList<>(sequence.subList(0, maxLength));
        }

        //TRIM mode never pads; in TRIM_OR_PAD mode a sequence of exactly maxLength needs no change
        if (mode == Mode.TRIM || length == maxLength) {
            return sequence;
        }

        //Need to pad. The column-count check only applies when there is a first step to compare against.
        //Reading sequence.get(0) must be guarded: message arguments are evaluated before the precondition call,
        //so an unguarded sequence.get(0) would throw IndexOutOfBoundsException for an empty sequence
        if (length > 0) {
            int valuesPerStep = sequence.get(0).size();
            Preconditions.checkState(valuesPerStep == pad.size(), "Invalid padding values: %s padding " +
                    "values were provided, but data has %s values per time step (columns)", pad.size(), valuesPerStep);
        }

        List<List<Writable>> out = new ArrayList<>(maxLength);
        out.addAll(sequence);
        //Note: the same pad list instance is appended for every padded step
        while (out.size() < maxLength) {
            out.add(pad);
        }
        return out;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Object map(Object input) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Object mapSequence(Object sequence) {
        throw new UnsupportedOperationException();
    }

    /**
     * This transform changes only the number of steps, not the columns: the input schema is returned as-is.
     */
    @Override
    public Schema transform(Schema inputSchema) {
        return inputSchema;
    }

    @Override
    public void setInputSchema(Schema inputSchema) {
        this.schema = inputSchema;
    }

    @Override
    public Schema getInputSchema() {
        return schema;
    }

    /**
     * @return always null; see {@link #outputColumnNames()}
     */
    @Override
    public String outputColumnName() {
        return null;
    }

    /**
     * @return the column names of the input schema. Requires the input schema to have been set via
     * {@link #setInputSchema(Schema)}
     */
    @Override
    public String[] outputColumnNames() {
        return schema.getColumnNames().toArray(new String[schema.numColumns()]);
    }

    /**
     * @return same as {@link #outputColumnNames()}
     */
    @Override
    public String[] columnNames() {
        return outputColumnNames();
    }

    /**
     * @return always null; see {@link #columnNames()}
     */
    @Override
    public String columnName() {
        return null;
    }
}

View File

@ -1321,6 +1321,91 @@ public class TestTransforms {
assertEquals(expTrimLast, tLast.mapSequence(seq));
}
@Test
public void testSequenceTrimToLengthTransform(){
    //Three double columns; the process under test trims (never pads) to at most 3 steps
    Schema schema = new Schema.Builder()
            .addColumnsDouble("first", "second", "third")
            .build();
    TransformProcess process = new TransformProcess.Builder(schema)
            .trimSequenceToLength(3)
            .build();

    //Case 1: a 4-step sequence is cut down to its first 3 steps
    List<List<Writable>> fourSteps = Arrays.asList(
            Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
            Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
            Arrays.<Writable>asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)),
            Arrays.<Writable>asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11)));
    List<List<Writable>> firstThreeSteps = Arrays.asList(
            Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
            Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
            Arrays.<Writable>asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)));
    List<List<Writable>> actual = process.executeSequence(fourSteps);
    assertEquals(firstThreeSteps, actual);

    //Case 2: a 2-step sequence is shorter than the maximum, so it comes back unchanged
    List<List<Writable>> twoSteps = Arrays.asList(
            Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
            Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)));
    actual = process.executeSequence(twoSteps);
    assertEquals(twoSteps, actual);

    //JSON round trip: the deserialized process must give the same results
    String asJson = process.toJson();
    TransformProcess restored = TransformProcess.fromJson(asJson);
    assertEquals(firstThreeSteps, restored.executeSequence(fourSteps));
    assertEquals(twoSteps, restored.executeSequence(twoSteps));
}
@Test
public void testSequenceTrimToLengthTransformTrimOrPad(){
    //Three double columns; the process under test trims or pads every sequence to exactly 4 steps
    Schema schema = new Schema.Builder()
            .addColumnsDouble("first", "second", "third")
            .build();
    List<Writable> padStep = Arrays.<Writable>asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902));
    TransformProcess process = new TransformProcess.Builder(schema)
            .trimOrPadSequenceToLength(4, padStep)
            .build();

    //Case 1: a 5-step sequence is cut down to its first 4 steps
    List<List<Writable>> fiveSteps = Arrays.asList(
            Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
            Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
            Arrays.<Writable>asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)),
            Arrays.<Writable>asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11)),
            Arrays.<Writable>asList(new DoubleWritable(12), new DoubleWritable(13), new DoubleWritable(14)));
    List<List<Writable>> firstFourSteps = Arrays.asList(
            Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
            Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
            Arrays.<Writable>asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)),
            Arrays.<Writable>asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11)));
    List<List<Writable>> actual = process.executeSequence(fiveSteps);
    assertEquals(firstFourSteps, actual);

    //Case 2: a 2-step sequence gains two copies of the pad step
    List<List<Writable>> twoSteps = Arrays.asList(
            Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
            Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)));
    List<List<Writable>> paddedToFour = Arrays.asList(
            Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
            Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
            Arrays.<Writable>asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902)),
            Arrays.<Writable>asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902)));
    actual = process.executeSequence(twoSteps);
    assertEquals(paddedToFour, actual);

    //JSON round trip: the deserialized process must give the same results
    String asJson = process.toJson();
    TransformProcess restored = TransformProcess.fromJson(asJson);
    assertEquals(firstFourSteps, restored.executeSequence(fiveSteps));
    assertEquals(paddedToFour, restored.executeSequence(twoSteps));
}
@Test
public void testSequenceOffsetTransform(){