2019-06-06 15:21:15 +03:00

144 lines
4.7 KiB
Java

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.api.transform;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.collection.ListStringRecordReader;
import org.datavec.api.split.ListStringSplit;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import java.util.*;
import static org.junit.Assert.assertEquals;
public class TestTransformProcess {
@Test
public void testExecution(){
Schema schema = new Schema.Builder()
.addColumnsString("col")
.addColumnsDouble("col2")
.build();
Map<Character,Integer> m = defaultCharIndex();
TransformProcess transformProcess = new TransformProcess.Builder(schema)
.doubleMathOp("col2", MathOp.Add, 1.0)
.build();
List<Writable> in = Arrays.<Writable>asList(new Text("Text"), new DoubleWritable(2.0));
List<Writable> exp = Arrays.<Writable>asList(new Text("Text"), new DoubleWritable(3.0));
List<Writable> out = transformProcess.execute(in);
assertEquals(exp, out);
}
@Test
public void testExecuteToSequence() {
Schema schema = new Schema.Builder()
.addColumnsString("action")
.build();
Map<Character,Integer> m = defaultCharIndex();
TransformProcess transformProcess = new TransformProcess.Builder(schema)
.removeAllColumnsExceptFor("action")
.convertToSequence()
.transform(new TextToCharacterIndexTransform("action", "action_sequence", m, true))
.build();
String s = "in text";
List<Writable> input = Collections.<Writable>singletonList(new Text(s));
List<List<Writable>> expSeq = new ArrayList<>(s.length());
for( int i = 0; i<s.length(); i++) {
expSeq.add(Collections.<Writable>singletonList(new IntWritable(m.get(s.charAt(i)))));
}
List<List<Writable>> out = transformProcess.executeToSequence(input);
assertEquals(expSeq, out);
}
@Test
public void testInferColumns() throws Exception {
List<List<String>> categories = Arrays.asList(
Arrays.asList("a","d") ,
Arrays.asList("b","e"),
Arrays.asList("c","f")
);
RecordReader listReader = new ListStringRecordReader();
listReader.initialize(new ListStringSplit(categories));
List<String> inferredSingle = TransformProcess.inferCategories(listReader,0);
assertEquals(3,inferredSingle.size());
listReader.initialize(new ListStringSplit(categories));
Map<Integer, List<String>> integerListMap = TransformProcess.inferCategories(listReader, new int[]{0,1});
for(int i = 0; i < 2; i++) {
assertEquals(3,integerListMap.get(i).size());
}
}
public static Map<Character,Integer> defaultCharIndex() {
Map<Character,Integer> ret = new TreeMap<>();
ret.put('a',0);
ret.put('b',1);
ret.put('c',2);
ret.put('d',3);
ret.put('e',4);
ret.put('f',5);
ret.put('g',6);
ret.put('h',7);
ret.put('i',8);
ret.put('j',9);
ret.put('k',10);
ret.put('l',11);
ret.put('m',12);
ret.put('n',13);
ret.put('o',14);
ret.put('p',15);
ret.put('q',16);
ret.put('r',17);
ret.put('s',18);
ret.put('t',19);
ret.put('u',20);
ret.put('v',21);
ret.put('w',22);
ret.put('x',23);
ret.put('y',24);
ret.put('z',25);
ret.put('/',26);
ret.put(' ',27);
ret.put('(',28);
ret.put(')',29);
return ret;
}
}