/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.arrow;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.TimeStampMilliVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.datavec.api.Record;
import org.datavec.api.Writable;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataIndex;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.*;
import static java.nio.channels.Channels.newChannel;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@Slf4j
public class ArrowConverterTest extends BaseND4JTest {
private static final BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
@TempDir
public File testDir;
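/**
 * Round-trips a single NDArray column: builds a 4-row batch of NDArrayWritables,
 * converts it to Arrow columns, and checks that ArrowConverter.toArray reproduces
 * the expected 4x4 matrix.
 */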
@Test
public void testToArrayFromINDArray() {
Schema.Builder schemaBuilder = new Schema.Builder();
schemaBuilder.addColumnNDArray("outputArray",new long[]{1,4});
Schema schema = schemaBuilder.build();
int numRows = 4;
List<List<Writable>> ret = new ArrayList<>(numRows);
for(int i = 0; i < numRows; i++) {
ret.add(Collections.<Writable>singletonList(new NDArrayWritable(Nd4j.linspace(1, 4, 4).reshape(1, 4))));
}
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, ret);
ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors,schema);
INDArray array = ArrowConverter.toArray(arrowWritableRecordBatch);
assertArrayEquals(new long[]{4,4},array.shape());
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1,4,4),4).reshape(4,4);
assertEquals(assertion,array);
}
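/**
 * Converts a one-row record with two NDArray columns to Arrow field vectors and back,
 * verifying value counts, null flags, and that the original INDArray is recovered
 * both from the record batch and via fromEntry.
 */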
@Test
public void testArrowColumnINDArray() {
Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>();
int numCols = 2;
INDArray arr = Nd4j.linspace(1,4,4);
for(int i = 0; i < numCols; i++) {
schema.addColumnNDArray(String.valueOf(i),new long[]{1,4});
single.add(String.valueOf(i));
}
Schema buildSchema = schema.build();
List<List<Writable>> list = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>();
for(int i = 0 ; i < numCols; i++) {
firstRow.add(new NDArrayWritable(arr));
}
list.add(firstRow);
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, buildSchema, list);
assertEquals(numCols,fieldVectors.size());
assertEquals(1,fieldVectors.get(0).getValueCount());
assertFalse(fieldVectors.get(0).isNull(0));
ArrowWritableRecordBatch arrowWritableRecordBatch = ArrowConverter.toArrowWritables(fieldVectors, buildSchema);
assertEquals(1,arrowWritableRecordBatch.size());
Writable writable = arrowWritableRecordBatch.get(0).get(0);
assertTrue(writable instanceof NDArrayWritable);
NDArrayWritable ndArrayWritable = (NDArrayWritable) writable;
assertEquals(arr,ndArrayWritable.get());
Writable writable1 = ArrowConverter.fromEntry(0, fieldVectors.get(0), ColumnType.NDArray);
NDArrayWritable fromEntryWritable = (NDArrayWritable) writable1;
assertEquals(arr,fromEntryWritable.get());
}
@Test
public void testArrowColumnString() {
Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) {
schema.addColumnInteger(String.valueOf(i));
single.add(String.valueOf(i));
}
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringSingle(bufferAllocator, schema.build(), single);
List<List<Writable>> records = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
List<List<Writable>> assertion = new ArrayList<>();
assertion.add(Arrays.asList(new IntWritable(0),new IntWritable(1)));
assertEquals(assertion,records);
List<List<String>> batch = new ArrayList<>();
for(int i = 0; i < 2; i++) {
batch.add(Arrays.asList(String.valueOf(i),String.valueOf(i)));
}
List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch);
List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build());
List<List<Writable>> assertionBatch = new ArrayList<>();
assertionBatch.add(Arrays.asList(new IntWritable(0),new IntWritable(0)));
assertionBatch.add(Arrays.asList(new IntWritable(1),new IntWritable(1)));
assertEquals(assertionBatch,batchRecords);
}
@Test
public void testArrowBatchSetTime() {
Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) {
schema.addColumnTime(String.valueOf(i),TimeZone.getDefault());
single.add(String.valueOf(i));
}
List<List<Writable>> input = Arrays.asList(
Arrays.asList(new LongWritable(0),new LongWritable(1)),
Arrays.asList(new LongWritable(2),new LongWritable(3))
);
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
List<Writable> assertion = Arrays.asList(new LongWritable(4), new LongWritable(5));
writableRecordBatch.set(1, Arrays.asList(new LongWritable(4),new LongWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion,recordTest);
}
@Test
public void testArrowBatchSet() {
Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) {
schema.addColumnInteger(String.valueOf(i));
single.add(String.valueOf(i));
}
List<List<Writable>> input = Arrays.asList(
Arrays.asList(new IntWritable(0),new IntWritable(1)),
Arrays.asList(new IntWritable(2),new IntWritable(3))
);
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
List<Writable> assertion = Arrays.asList(new IntWritable(4), new IntWritable(5));
writableRecordBatch.set(1, Arrays.asList(new IntWritable(4),new IntWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion,recordTest);
}
@Test
public void testArrowColumnsStringTimeSeries() {
Schema.Builder schema = new Schema.Builder();
List<List<List<String>>> entries = new ArrayList<>();
for(int i = 0; i < 3; i++) {
schema.addColumnInteger(String.valueOf(i));
}
for(int i = 0; i < 5; i++) {
List<List<String>> arr = Collections.singletonList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
entries.add(arr);
}
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
assertEquals(3,fieldVectors.size());
assertEquals(5,fieldVectors.get(0).getValueCount());
INDArray exp = Nd4j.create(5, 3);
for( int i = 0; i < 5; i++) {
exp.getRow(i).assign(i);
}
//Convert to ArrowWritableRecordBatch - note we can't do this in general with time series...
ArrowWritableRecordBatch wri = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
INDArray arr = ArrowConverter.toArray(wri);
assertArrayEquals(new long[] {5,3}, arr.shape());
assertEquals(exp, arr);
}
@Test
public void testConvertVector() {
Schema.Builder schema = new Schema.Builder();
List<List<List<String>>> entries = new ArrayList<>();
for(int i = 0; i < 3; i++) {
schema.addColumnInteger(String.valueOf(i));
}
for(int i = 0; i < 5; i++) {
List<List<String>> arr = Collections.singletonList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
entries.add(arr);
}
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0),schema.build().getType(0));
assertEquals(5,arr.length());
}
@Test
public void testCreateNDArray() throws Exception {
val recordsToWrite = recordToWrite();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream);
File f = testDir;
File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID() + ".arrow");
FileOutputStream outputStream = new FileOutputStream(tmpFile);
tmpFile.deleteOnExit();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),outputStream);
outputStream.flush();
outputStream.close();
Pair<Schema, ArrowWritableRecordBatch> schemaArrowWritableRecordBatchPair = ArrowConverter.readFromFile(tmpFile);
assertEquals(recordsToWrite.getFirst(),schemaArrowWritableRecordBatchPair.getFirst());
assertEquals(recordsToWrite.getRight(),schemaArrowWritableRecordBatchPair.getRight().toArrayList());
byte[] arr = byteArrayOutputStream.toByteArray();
val read = ArrowConverter.readFromBytes(arr);
assertEquals(recordsToWrite,read);
//send file
File tmp = tmpDataFile(recordsToWrite);
ArrowRecordReader recordReader = new ArrowRecordReader();
recordReader.initialize(new FileSplit(tmp));
recordReader.next();
ArrowWritableRecordBatch currentBatch = recordReader.getCurrentBatch();
INDArray arr2 = ArrowConverter.toArray(currentBatch);
assertEquals(2,arr2.rows());
assertEquals(2,arr2.columns());
}
@Test
public void testConvertToArrowVectors() {
INDArray matrix = Nd4j.linspace(1,4,4).reshape(2,2);
val vectors = ArrowConverter.convertToArrowVector(matrix,Arrays.asList("test","test2"), ColumnType.Double,bufferAllocator);
assertEquals(matrix.rows(),vectors.size());
INDArray vector = Nd4j.linspace(1,4,4);
val vectors2 = ArrowConverter.convertToArrowVector(vector, Collections.singletonList("test"), ColumnType.Double,bufferAllocator);
assertEquals(1,vectors2.size());
assertEquals(matrix.length(),vectors2.get(0).getValueCount());
}
@Test
public void testSchemaConversionBasic() {
Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < 2; i++) {
schemaBuilder.addColumnDouble("test-" + i);
schemaBuilder.addColumnInteger("testi-" + i);
schemaBuilder.addColumnLong("testl-" + i);
schemaBuilder.addColumnFloat("testf-" + i);
}
Schema schema = schemaBuilder.build();
val schema2 = ArrowConverter.toArrowSchema(schema);
assertEquals(8,schema2.getFields().size());
val convertedSchema = ArrowConverter.toDatavecSchema(schema2);
assertEquals(schema,convertedSchema);
}
@Test
public void testReadSchemaAndRecordsFromByteArray() throws Exception {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
int valueCount = 3;
List<Field> fields = new ArrayList<>();
fields.add(ArrowConverter.field("field1",new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)));
fields.add(ArrowConverter.intField("field2"));
List<FieldVector> fieldVectors = new ArrayList<>();
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field1",new float[] {1,2,3}));
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field2",new int[] {1,2,3}));
org.apache.arrow.vector.types.pojo.Schema schema = new org.apache.arrow.vector.types.pojo.Schema(fields);
VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, fieldVectors, valueCount);
VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1);
vectorUnloader.getRecordBatch();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) {
arrowFileWriter.writeBatch();
} catch (IOException e) {
log.error("",e);
}
byte[] arr = byteArrayOutputStream.toByteArray();
val arr2 = ArrowConverter.readFromBytes(arr);
assertEquals(2,arr2.getFirst().numColumns());
assertEquals(3,arr2.getRight().size());
val arrowCols = ArrowConverter.toArrowColumns(allocator,arr2.getFirst(),arr2.getRight());
assertEquals(2,arrowCols.size());
assertEquals(valueCount,arrowCols.get(0).getValueCount());
}
@Test
public void testVectorForEdgeCases() {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{Float.MIN_VALUE,Float.MAX_VALUE});
assertEquals(Float.MIN_VALUE,vector.get(0),1e-2);
assertEquals(Float.MAX_VALUE,vector.get(1),1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{Integer.MIN_VALUE,Integer.MAX_VALUE});
assertEquals(Integer.MIN_VALUE,vectorInt.get(0),1e-2);
assertEquals(Integer.MAX_VALUE,vectorInt.get(1),1e-2);
}
@Test
public void testVectorFor() {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{1,2,3});
assertEquals(3,vector.getValueCount());
assertEquals(1,vector.get(0),1e-2);
assertEquals(2,vector.get(1),1e-2);
assertEquals(3,vector.get(2),1e-2);
val vectorLong = ArrowConverter.vectorFor(allocator,"field1",new long[]{1,2,3});
assertEquals(3,vectorLong.getValueCount());
assertEquals(1,vectorLong.get(0),1e-2);
assertEquals(2,vectorLong.get(1),1e-2);
assertEquals(3,vectorLong.get(2),1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{1,2,3});
assertEquals(3,vectorInt.getValueCount());
assertEquals(1,vectorInt.get(0),1e-2);
assertEquals(2,vectorInt.get(1),1e-2);
assertEquals(3,vectorInt.get(2),1e-2);
val vectorDouble = ArrowConverter.vectorFor(allocator,"field1",new double[]{1,2,3});
assertEquals(3,vectorDouble.getValueCount());
assertEquals(1,vectorDouble.get(0),1e-2);
assertEquals(2,vectorDouble.get(1),1e-2);
assertEquals(3,vectorDouble.get(2),1e-2);
val vectorBool = ArrowConverter.vectorFor(allocator,"field1",new boolean[]{true,true,false});
assertEquals(3,vectorBool.getValueCount());
assertEquals(1,vectorBool.get(0),1e-2);
assertEquals(1,vectorBool.get(1),1e-2);
assertEquals(0,vectorBool.get(2),1e-2);
}
@Test
public void testRecordReaderAndWriteFile() throws Exception {
val recordsToWrite = recordToWrite();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream);
byte[] arr = byteArrayOutputStream.toByteArray();
val read = ArrowConverter.readFromBytes(arr);
assertEquals(recordsToWrite,read);
//send file
File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader();
recordReader.initialize(new FileSplit(tmp));
List<Writable> record = recordReader.next();
assertEquals(2,record.size());
}
@Test
public void testRecordReaderMetaDataList() throws Exception {
val recordsToWrite = recordToWrite();
//send file
File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader();
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class);
recordReader.loadFromMetaData(Collections.<RecordMetaData>singletonList(recordMetaDataIndex));
Record record = recordReader.nextRecord();
assertEquals(2,record.getRecord().size());
}
@Test
public void testDates() {
Date now = new Date();
BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[]{now});
assertEquals(now.getTime(),timeStampMilliVector.get(0));
}
@Test
public void testRecordReaderMetaData() throws Exception {
val recordsToWrite = recordToWrite();
//send file
File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader();
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class);
recordReader.loadFromMetaData(recordMetaDataIndex);
Record record = recordReader.nextRecord();
assertEquals(2,record.getRecord().size());
}
private File tmpDataFile(Pair<Schema,List<List<Writable>>> recordsToWrite) throws IOException {
File f = testDir;
//send file
File tmp = new File(f,"tmp-file-" + UUID.randomUUID());
tmp.mkdirs();
File tmpFile = new File(tmp,"data.arrow");
tmpFile.deleteOnExit();
FileOutputStream bufferedOutputStream = new FileOutputStream(tmpFile);
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),bufferedOutputStream);
bufferedOutputStream.flush();
bufferedOutputStream.close();
return tmp;
}
private Pair<Schema,List<List<Writable>>> recordToWrite() {
List<List<Writable>> records = new ArrayList<>();
records.add(Arrays.asList(new DoubleWritable(0.0),new DoubleWritable(0.0)));
records.add(Arrays.asList(new DoubleWritable(0.0),new DoubleWritable(0.0)));
Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < 2; i++) {
schemaBuilder.addColumnFloat("col-" + i);
}
return Pair.of(schemaBuilder.build(),records);
}
}