parent
9cf28ea6c9
commit
5cf6859fc4
|
@ -17,6 +17,7 @@
|
|||
package org.datavec.api.transform;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.transform.analysis.DataAnalysis;
|
||||
|
@ -34,6 +35,7 @@ import org.datavec.api.transform.reduce.IAssociativeReducer;
|
|||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.transform.schema.SequenceSchema;
|
||||
import org.datavec.api.transform.sequence.*;
|
||||
import org.datavec.api.transform.sequence.trim.SequenceTrimToLengthTransform;
|
||||
import org.datavec.api.transform.sequence.trim.SequenceTrimTransform;
|
||||
import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform;
|
||||
import org.datavec.api.transform.sequence.window.WindowFunction;
|
||||
|
@ -1100,6 +1102,31 @@ public class TransformProcess implements Serializable {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Trim the sequence to the specified length (number of sequence steps).<br>
|
||||
* Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will not be modified.
|
||||
*
|
||||
* @param maxLength Maximum sequence length (number of time steps)
|
||||
*/
|
||||
public Builder trimSequenceToLength(int maxLength) {
|
||||
actionList.add(new DataAction(new SequenceTrimToLengthTransform(maxLength, SequenceTrimToLengthTransform.Mode.TRIM, null)));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Trim or pad the sequence to the specified length (number of sequence steps).<br>
|
||||
* Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will be
|
||||
* padded with as many copies of the "pad" array to make the sequence length equal the specified maximum.<br>
|
||||
* Note that the 'pad' list (i.e., values to pad with) must be equal in length to the number of columns (values per time step)
|
||||
*
|
||||
* @param length Required length - trim sequences longer than this, pad sequences shorter than this
|
||||
* @param pad Values to pad at the end of the sequence
|
||||
*/
|
||||
public Builder trimOrPadSequenceToLength(int length, @NonNull List<Writable> pad) {
|
||||
actionList.add(new DataAction(new SequenceTrimToLengthTransform(length, SequenceTrimToLengthTransform.Mode.TRIM_OR_PAD, pad)));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform a sequence of operation on the specified columns. Note that this also truncates sequences by the
|
||||
* specified offset amount by default. Use {@code transform(new SequenceOffsetTransform(...)} to change this.
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
package org.datavec.api.transform.sequence.trim;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.datavec.api.transform.Transform;
|
||||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Trim or pad the sequence to the specified length (number of sequence steps). It supports 2 modes:<br>
|
||||
* TRIM: Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will not be modified.<br>
|
||||
* TRIM_OR_PAD: Sequences longer than the specified maximum will be trimmed to exactly the maximum. Shorter sequences will be
|
||||
* padded with as many copies of the "pad" array to make the sequence length equal the specified maximum.<br>
|
||||
* Note that the 'pad' list (i.e., values to pad when using TRIM_OR_PAD mode) must be equal in length to the number of columns (values per time step)
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@JsonIgnoreProperties({"schema"})
|
||||
@EqualsAndHashCode(exclude = {"schema"})
|
||||
@Data
|
||||
public class SequenceTrimToLengthTransform implements Transform {
|
||||
/**
|
||||
* Mode. See {@link SequenceTrimToLengthTransform}
|
||||
*/
|
||||
public enum Mode {TRIM, TRIM_OR_PAD}
|
||||
|
||||
private int maxLength;
|
||||
private Mode mode;
|
||||
private List<Writable> pad;
|
||||
|
||||
private Schema schema;
|
||||
|
||||
/**
|
||||
* @param maxLength maximum sequence length. Must be positive.
|
||||
* @param mode Mode - trim or trim/pad
|
||||
* @param pad Padding value. Only used with Mode.TRIM_OR_PAD. Must be equal in length to the number of columns (values per time step)
|
||||
*/
|
||||
public SequenceTrimToLengthTransform(@JsonProperty("maxLength") int maxLength, @JsonProperty("mode") Mode mode, @JsonProperty("pad") List<Writable> pad) {
|
||||
Preconditions.checkState(maxLength > 0, "Maximum length must be > 0, got %s", maxLength);
|
||||
Preconditions.checkState(mode == Mode.TRIM || pad != null, "If mode == Mode.TRIM_OR_PAD ");
|
||||
this.maxLength = maxLength;
|
||||
this.mode = mode;
|
||||
this.pad = pad;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Writable> map(List<Writable> writables) {
|
||||
throw new UnsupportedOperationException("SequenceTrimToLengthTransform cannot be applied to non-sequence values");
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
|
||||
if (mode == Mode.TRIM) {
|
||||
if (sequence.size() <= maxLength) {
|
||||
return sequence;
|
||||
}
|
||||
return new ArrayList<>(sequence.subList(0, maxLength));
|
||||
} else {
|
||||
//Trim or pad
|
||||
if (sequence.size() == maxLength) {
|
||||
return sequence;
|
||||
} else if (sequence.size() > maxLength) {
|
||||
return new ArrayList<>(sequence.subList(0, maxLength));
|
||||
} else {
|
||||
//Need to pad
|
||||
Preconditions.checkState(sequence.size() == 0 || sequence.get(0).size() == pad.size(), "Invalid padding values: %s padding " +
|
||||
"values were provided, but data has %s values per time step (columns)", pad.size(), sequence.get(0).size());
|
||||
|
||||
List<List<Writable>> out = new ArrayList<>(maxLength);
|
||||
out.addAll(sequence);
|
||||
while (out.size() < maxLength) {
|
||||
out.add(pad);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object map(Object input) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object mapSequence(Object sequence) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Schema transform(Schema inputSchema) {
|
||||
return inputSchema;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setInputSchema(Schema inputSchema) {
|
||||
this.schema = inputSchema;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Schema getInputSchema() {
|
||||
return schema;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String outputColumnName() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] outputColumnNames() {
|
||||
return schema.getColumnNames().toArray(new String[schema.numColumns()]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] columnNames() {
|
||||
return outputColumnNames();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String columnName() {
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -1321,6 +1321,91 @@ public class TestTransforms {
|
|||
assertEquals(expTrimLast, tLast.mapSequence(seq));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSequenceTrimToLengthTransform(){
|
||||
List<List<Writable>> seq = Arrays.asList(
|
||||
Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11)));
|
||||
|
||||
List<List<Writable>> expTrimLength3 = Arrays.asList(
|
||||
Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)));
|
||||
|
||||
Schema s = new Schema.Builder()
|
||||
.addColumnsDouble("first", "second", "third")
|
||||
.build();
|
||||
|
||||
TransformProcess p = new TransformProcess.Builder(s)
|
||||
.trimSequenceToLength(3)
|
||||
.build();
|
||||
|
||||
List<List<Writable>> out = p.executeSequence(seq);
|
||||
assertEquals(expTrimLength3, out);
|
||||
|
||||
|
||||
List<List<Writable>> seq2 = Arrays.asList(
|
||||
Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)));
|
||||
|
||||
out = p.executeSequence(seq2);
|
||||
assertEquals(seq2, out);
|
||||
|
||||
String json = p.toJson();
|
||||
TransformProcess tp2 = TransformProcess.fromJson(json);
|
||||
assertEquals(expTrimLength3, tp2.executeSequence(seq));
|
||||
assertEquals(seq2, tp2.executeSequence(seq2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSequenceTrimToLengthTransformTrimOrPad(){
|
||||
List<List<Writable>> seq = Arrays.asList(
|
||||
Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(12), new DoubleWritable(13), new DoubleWritable(14)));
|
||||
|
||||
List<List<Writable>> seq2 = Arrays.asList(
|
||||
Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)));
|
||||
|
||||
List<List<Writable>> expTrimLength4 = Arrays.asList(
|
||||
Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11)));
|
||||
|
||||
Schema s = new Schema.Builder()
|
||||
.addColumnsDouble("first", "second", "third")
|
||||
.build();
|
||||
|
||||
TransformProcess p = new TransformProcess.Builder(s)
|
||||
.trimOrPadSequenceToLength(4, Arrays.<Writable>asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902)))
|
||||
.build();
|
||||
|
||||
List<List<Writable>> out = p.executeSequence(seq);
|
||||
assertEquals(expTrimLength4, out);
|
||||
|
||||
|
||||
List<List<Writable>> exp2 = Arrays.asList(
|
||||
Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902)),
|
||||
Arrays.<Writable>asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902)));
|
||||
|
||||
out = p.executeSequence(seq2);
|
||||
assertEquals(exp2, out);
|
||||
|
||||
|
||||
String json = p.toJson();
|
||||
TransformProcess tp2 = TransformProcess.fromJson(json);
|
||||
assertEquals(expTrimLength4, tp2.executeSequence(seq));
|
||||
assertEquals(exp2, tp2.executeSequence(seq2));
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testSequenceOffsetTransform(){
|
||||
|
|
Loading…
Reference in New Issue