cavis/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java
2023-04-17 10:37:01 +02:00

143 lines
5.7 KiB
Java

/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.local.transforms;
import org.datavec.api.Record;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.inmemory.InMemorySequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.condition.ConditionOp;
import org.datavec.api.transform.condition.column.CategoricalColumnCondition;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.Writable;
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class LocalTransformProcessRecordReaderTests {
@Test
public void simpleTransformTest() throws Exception {
Schema schema = new Schema.Builder().addColumnDouble("0").addColumnDouble("1").addColumnDouble("2")
.addColumnDouble("3").addColumnDouble("4").build();
TransformProcess transformProcess = new TransformProcess.Builder(schema).removeColumns("0").build();
CSVRecordReader csvRecordReader = new CSVRecordReader();
csvRecordReader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));
LocalTransformProcessRecordReader transformProcessRecordReader =
new LocalTransformProcessRecordReader(csvRecordReader, transformProcess);
assertEquals(4, transformProcessRecordReader.next().size());
}
@Test
public void simpleTransformTestSequence() {
List<List<Writable>> sequence = new ArrayList<>();
//First window:
sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0),
new IntWritable(0)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1),
new IntWritable(0)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2),
new IntWritable(0)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
.addColumnInteger("intcolumn").addColumnInteger("intcolumn2").build();
TransformProcess transformProcess = new TransformProcess.Builder(schema).removeColumns("intcolumn2").build();
InMemorySequenceRecordReader inMemorySequenceRecordReader =
new InMemorySequenceRecordReader(Collections.singletonList(sequence));
LocalTransformProcessSequenceRecordReader transformProcessSequenceRecordReader =
new LocalTransformProcessSequenceRecordReader(inMemorySequenceRecordReader, transformProcess);
List<List<Writable>> next = transformProcessSequenceRecordReader.sequenceRecord();
assertEquals(2, next.get(0).size());
}
@Test
public void testLocalFilter(){
List<List<Writable>> in = new ArrayList<>();
in.add(Arrays.asList(new Text("Keep"), new IntWritable(0)));
in.add(Arrays.asList(new Text("Remove"), new IntWritable(1)));
in.add(Arrays.asList(new Text("Keep"), new IntWritable(2)));
in.add(Arrays.asList(new Text("Remove"), new IntWritable(3)));
Schema s = new Schema.Builder()
.addColumnCategorical("cat", "Keep", "Remove")
.addColumnInteger("int")
.build();
TransformProcess tp = new TransformProcess.Builder(s)
.filter(new CategoricalColumnCondition("cat", ConditionOp.Equal, "Remove"))
.build();
RecordReader rr = new CollectionRecordReader(in);
LocalTransformProcessRecordReader ltprr = new LocalTransformProcessRecordReader(rr, tp);
List<List<Writable>> out = new ArrayList<>();
while(ltprr.hasNext()){
out.add(ltprr.next());
}
List<List<Writable>> exp = Arrays.asList(in.get(0), in.get(2));
assertEquals(exp, out);
//Check reset:
ltprr.reset();
out.clear();
while(ltprr.hasNext()){
out.add(ltprr.next());
}
assertEquals(exp, out);
//Also test Record method:
List<Record> rl = new ArrayList<>();
rr.reset();
while(rr.hasNext()){
rl.add(rr.nextRecord());
}
List<Record> exp2 = Arrays.asList(rl.get(0), rl.get(2));
List<Record> act = new ArrayList<>();
ltprr.reset();
while(ltprr.hasNext()){
act.add(ltprr.nextRecord());
}
}
}