/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform;

import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.analysis.AnalysisCounter;
import org.datavec.api.transform.analysis.DataAnalysis;
import org.datavec.api.transform.analysis.DataVecAnalysisUtils;
import org.datavec.api.transform.analysis.SequenceDataAnalysis;
import org.datavec.api.transform.analysis.columns.*;
import org.datavec.api.transform.analysis.histogram.HistogramCounter;
import org.datavec.api.transform.analysis.quality.QualityAnalysisAddFunction;
import org.datavec.api.transform.analysis.quality.QualityAnalysisCombineFunction;
import org.datavec.api.transform.analysis.quality.QualityAnalysisState;
import org.datavec.api.transform.analysis.sequence.SequenceLengthAnalysis;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.quality.DataQualityAnalysis;
import org.datavec.api.transform.quality.columns.ColumnQuality;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.comparator.Comparators;
import org.datavec.spark.transform.analysis.SelectColumnFunction;
import org.datavec.spark.transform.analysis.SequenceFlatMapFunction;
import org.datavec.spark.transform.analysis.SequenceLengthFunction;
import org.datavec.spark.transform.analysis.aggregate.AnalysisAddFunction;
import org.datavec.spark.transform.analysis.aggregate.AnalysisCombineFunction;
import org.datavec.spark.transform.analysis.histogram.HistogramAddFunction;
import org.datavec.spark.transform.analysis.histogram.HistogramCombineFunction;
import org.datavec.spark.transform.analysis.seqlength.IntToDoubleFunction;
import org.datavec.spark.transform.analysis.seqlength.SequenceLengthAnalysisAddFunction;
import org.datavec.spark.transform.analysis.seqlength.SequenceLengthAnalysisCounter;
import org.datavec.spark.transform.analysis.seqlength.SequenceLengthAnalysisMergeFunction;
import org.datavec.spark.transform.analysis.unique.UniqueAddFunction;
import org.datavec.spark.transform.analysis.unique.UniqueMergeFunction;
import org.datavec.spark.transform.filter.FilterWritablesBySchemaFunction;
import org.datavec.spark.transform.misc.ColumnToKeyPairTransform;
import org.datavec.spark.transform.misc.SumLongsFunction2;
import org.datavec.spark.transform.misc.comparator.Tuple2Comparator;
import org.datavec.spark.transform.utils.adapter.BiFunctionAdapter;
import scala.Tuple2;
import java.util.*;

/**
 * AnalyzeSpark: static methods for analyzing and processing
 * {@code RDD<List<Writable>>} and {@code RDD<List<List<Writable>>>}
 *
 * @author Alex Black
 */
public class AnalyzeSpark {
public static final int DEFAULT_HISTOGRAM_BUCKETS = 30;
public static SequenceDataAnalysis analyzeSequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
return analyzeSequence(schema, data, DEFAULT_HISTOGRAM_BUCKETS);
}
/**
 * Analyse the specified sequence data - returns a SequenceDataAnalysis object with summary information
 * about each column, plus an analysis of the sequence lengths
 *
 * @param schema Schema for the data
 * @param data Sequence data to analyze
 * @param maxHistogramBuckets Maximum number of histogram buckets to use
 * @return SequenceDataAnalysis for the data
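 *
 * <p>Example usage (a minimal sketch; {@code schema} and the sequence RDD
 * {@code seqData} are assumed to exist already and to match each other):
 * <pre>{@code
 * SequenceDataAnalysis sda = AnalyzeSpark.analyzeSequence(schema, seqData, 20);
 * System.out.println(sda);    //Per-column statistics plus sequence length analysis
 * }</pre>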
*/
public static SequenceDataAnalysis analyzeSequence(Schema schema, JavaRDD<List<List<Writable>>> data,
int maxHistogramBuckets) {
data.cache();
JavaRDD<List<Writable>> fmSeq = data.flatMap(new SequenceFlatMapFunction());
DataAnalysis da = analyze(schema, fmSeq);
//Analyze the length of the sequences:
JavaRDD<Integer> seqLengths = data.map(new SequenceLengthFunction());
seqLengths.cache();
SequenceLengthAnalysisCounter counter = new SequenceLengthAnalysisCounter();
counter = seqLengths.aggregate(counter, new SequenceLengthAnalysisAddFunction(),
new SequenceLengthAnalysisMergeFunction());
int max = counter.getMaxLengthSeen();
int min = counter.getMinLengthSeen();
int nBuckets = counter.getMaxLengthSeen() - counter.getMinLengthSeen();
Tuple2<double[], long[]> hist;
if (max == min) {
//Edge case that spark doesn't like
hist = new Tuple2<>(new double[] {min}, new long[] {counter.getCountTotal()});
} else if (nBuckets < maxHistogramBuckets) {
JavaDoubleRDD drdd = seqLengths.mapToDouble(new IntToDoubleFunction());
hist = drdd.histogram(nBuckets);
} else {
JavaDoubleRDD drdd = seqLengths.mapToDouble(new IntToDoubleFunction());
hist = drdd.histogram(maxHistogramBuckets);
}
seqLengths.unpersist();
SequenceLengthAnalysis lengthAnalysis = SequenceLengthAnalysis.builder()
.totalNumSequences(counter.getCountTotal()).minSeqLength(counter.getMinLengthSeen())
.maxSeqLength(counter.getMaxLengthSeen()).countZeroLength(counter.getCountZeroLength())
.countOneLength(counter.getCountOneLength()).meanLength(counter.getMean())
.histogramBuckets(hist._1()).histogramBucketCounts(hist._2()).build();
return new SequenceDataAnalysis(schema, da.getColumnAnalysis(), lengthAnalysis);
}
/**
* Analyse the specified data - returns a DataAnalysis object with summary information about each column
*
* @param schema Schema for data
* @param data Data to analyze
* @return DataAnalysis for data
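 *
 * <p>Example usage (a minimal sketch; the column names are hypothetical, and
 * {@code data} is assumed to be an existing {@code JavaRDD<List<Writable>>}
 * that matches the schema):
 * <pre>{@code
 * Schema schema = new Schema.Builder()
 *         .addColumnInteger("age")       //hypothetical column
 *         .addColumnDouble("income")     //hypothetical column
 *         .build();
 * DataAnalysis analysis = AnalyzeSpark.analyze(schema, data);
 * System.out.println(analysis);          //Summary statistics for each column
 * }</pre>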
*/
public static DataAnalysis analyze(Schema schema, JavaRDD<List<Writable>> data) {
return analyze(schema, data, DEFAULT_HISTOGRAM_BUCKETS);
}
public static DataAnalysis analyze(Schema schema, JavaRDD<List<Writable>> data, int maxHistogramBuckets) {
data.cache();
/*
* TODO: Some care should be given to add histogramBuckets and histogramBucketCounts to this in the future
*/
List<ColumnType> columnTypes = schema.getColumnTypes();
List<AnalysisCounter> counters =
data.aggregate(null, new AnalysisAddFunction(schema), new AnalysisCombineFunction());
double[][] minsMaxes = new double[counters.size()][2];
List<ColumnAnalysis> list = DataVecAnalysisUtils.convertCounters(counters, minsMaxes, columnTypes);
List<HistogramCounter> histogramCounters =
data.aggregate(null, new HistogramAddFunction(maxHistogramBuckets, schema, minsMaxes),
new HistogramCombineFunction());
DataVecAnalysisUtils.mergeCounters(list, histogramCounters);
return new DataAnalysis(schema, list);
}
/**
* Randomly sample values from a single column
*
* @param count Number of values to sample
* @param columnName Name of the column to sample from
* @param schema Schema
* @param data Data to sample from
* @return A list of random samples
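 *
 * <p>Example usage (a minimal sketch; "age" is a hypothetical column name):
 * <pre>{@code
 * List<Writable> ageSamples = AnalyzeSpark.sampleFromColumn(100, "age", schema, data);
 * }</pre>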
*/
public static List<Writable> sampleFromColumn(int count, String columnName, Schema schema,
JavaRDD<List<Writable>> data) {
int colIdx = schema.getIndexOfColumn(columnName);
JavaRDD<Writable> ithColumn = data.map(new SelectColumnFunction(colIdx));
return ithColumn.takeSample(false, count);
}
/**
* Randomly sample values from a single column, in all sequences.
* Values may be taken from any sequence (i.e., sequence order is not preserved)
*
* @param count Number of values to sample
* @param columnName Name of the column to sample from
* @param schema Schema
* @param sequenceData Data to sample from
* @return A list of random samples
*/
public static List<Writable> sampleFromColumnSequence(int count, String columnName, Schema schema,
JavaRDD<List<List<Writable>>> sequenceData) {
JavaRDD<List<Writable>> flattenedSequence = sequenceData.flatMap(new SequenceFlatMapFunction());
return sampleFromColumn(count, columnName, schema, flattenedSequence);
}
/**
 * Get a list of unique values from the specified column.
 * For sequence data, use {@link #getUniqueSequence(String, Schema, JavaRDD)}
*
* @param columnName Name of the column to get unique values from
* @param schema Data schema
* @param data Data to get unique values from
* @return List of unique values
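 *
 * <p>Example usage (a minimal sketch; "state" is a hypothetical column name):
 * <pre>{@code
 * List<Writable> states = AnalyzeSpark.getUnique("state", schema, data);
 * }</pre>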
*/
public static List<Writable> getUnique(String columnName, Schema schema, JavaRDD<List<Writable>> data) {
int colIdx = schema.getIndexOfColumn(columnName);
JavaRDD<Writable> ithColumn = data.map(new SelectColumnFunction(colIdx));
return ithColumn.distinct().collect();
}
/**
 * Get a list of unique values from each of the specified columns.
 * For sequence data, use {@link #getUniqueSequence(List, Schema, JavaRDD)}
 *
 * @param columnNames Names of the columns to get unique values from
* @param schema Data schema
* @param data Data to get unique values from
* @return List of unique values, for each of the specified columns
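 *
 * <p>Example usage (a minimal sketch; the column names are hypothetical):
 * <pre>{@code
 * Map<String, List<Writable>> unique =
 *         AnalyzeSpark.getUnique(Arrays.asList("state", "city"), schema, data);
 * List<Writable> states = unique.get("state");
 * }</pre>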
*/
public static Map<String,List<Writable>> getUnique(List<String> columnNames, Schema schema, JavaRDD<List<Writable>> data){
Map<String,Set<Writable>> m = data.aggregate(null, new UniqueAddFunction(columnNames, schema), new UniqueMergeFunction());
Map<String,List<Writable>> out = new HashMap<>();
for(String s : m.keySet()){
out.put(s, new ArrayList<>(m.get(s)));
}
return out;
}
/**
* Get a list of unique values from the specified column of a sequence
*
* @param columnName Name of the column to get unique values from
* @param schema Data schema
* @param sequenceData Sequence data to get unique values from
 * @return List of unique values from the specified column
*/
public static List<Writable> getUniqueSequence(String columnName, Schema schema,
JavaRDD<List<List<Writable>>> sequenceData) {
JavaRDD<List<Writable>> flattenedSequence = sequenceData.flatMap(new SequenceFlatMapFunction());
return getUnique(columnName, schema, flattenedSequence);
}
/**
* Get a list of unique values from the specified columns of a sequence
*
 * @param columnNames Names of the columns to get unique values from
* @param schema Data schema
* @param sequenceData Sequence data to get unique values from
 * @return List of unique values, for each of the specified columns
*/
public static Map<String,List<Writable>> getUniqueSequence(List<String> columnNames, Schema schema,
JavaRDD<List<List<Writable>>> sequenceData) {
JavaRDD<List<Writable>> flattenedSequence = sequenceData.flatMap(new SequenceFlatMapFunction());
return getUnique(columnNames, schema, flattenedSequence);
}
/**
* Randomly sample a set of examples
*
* @param count Number of samples to generate
* @param data Data to sample from
* @return Samples
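 *
 * <p>Example usage (a minimal sketch; {@code data} is an existing RDD of examples):
 * <pre>{@code
 * List<List<Writable>> examples = AnalyzeSpark.sample(50, data);
 * }</pre>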
*/
public static List<List<Writable>> sample(int count, JavaRDD<List<Writable>> data) {
return data.takeSample(false, count);
}
/**
* Randomly sample a number of sequences from the data
* @param count Number of sequences to sample
* @param data Data to sample from
* @return Sequence samples
*/
public static List<List<List<Writable>>> sampleSequence(int count, JavaRDD<List<List<Writable>>> data) {
return data.takeSample(false, count);
}
/**
* Analyze the data quality of sequence data - provides a report on missing values, values that don't comply with schema, etc
* @param schema Schema for data
* @param data Data to analyze
* @return DataQualityAnalysis object
*/
public static DataQualityAnalysis analyzeQualitySequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
JavaRDD<List<Writable>> fmSeq = data.flatMap(new SequenceFlatMapFunction());
return analyzeQuality(schema, fmSeq);
}
/**
* Analyze the data quality of data - provides a report on missing values, values that don't comply with schema, etc
* @param schema Schema for data
* @param data Data to analyze
* @return DataQualityAnalysis object
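 *
 * <p>Example usage (a minimal sketch; {@code data} is assumed to match the schema):
 * <pre>{@code
 * DataQualityAnalysis dqa = AnalyzeSpark.analyzeQuality(schema, data);
 * System.out.println(dqa);    //Counts of missing/invalid values, per column
 * }</pre>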
*/
public static DataQualityAnalysis analyzeQuality(final Schema schema, final JavaRDD<List<Writable>> data) {
int nColumns = schema.numColumns();
List<QualityAnalysisState> states = data.aggregate(null,
new BiFunctionAdapter<>(new QualityAnalysisAddFunction(schema)),
new BiFunctionAdapter<>(new QualityAnalysisCombineFunction()));
List<ColumnQuality> list = new ArrayList<>(nColumns);
for (QualityAnalysisState qualityState : states) {
list.add(qualityState.getColumnQuality());
}
return new DataQualityAnalysis(schema, list);
}
/**
* Randomly sample a set of invalid values from a specified column.
* Values are considered invalid according to the Schema / ColumnMetaData
*
* @param numToSample Maximum number of invalid values to sample
 * @param columnName Name of the column from which to sample invalid values
* @param schema Data schema
* @param data Data
* @return List of invalid examples
*/
public static List<Writable> sampleInvalidFromColumn(int numToSample, String columnName, Schema schema,
JavaRDD<List<Writable>> data) {
return sampleInvalidFromColumn(numToSample, columnName, schema, data, false);
}
/**
* Randomly sample a set of invalid values from a specified column.
* Values are considered invalid according to the Schema / ColumnMetaData
*
* @param numToSample Maximum number of invalid values to sample
 * @param columnName Name of the column from which to sample invalid values
* @param schema Data schema
* @param data Data
* @param ignoreMissing If true: ignore missing values (NullWritable or empty/null string) when sampling. If false: include missing values in sampling
* @return List of invalid examples
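 *
 * <p>Example usage (a minimal sketch; "age" is a hypothetical column name -
 * this samples up to 10 invalid values, skipping missing ones):
 * <pre>{@code
 * List<Writable> invalid = AnalyzeSpark.sampleInvalidFromColumn(10, "age", schema, data, true);
 * }</pre>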
*/
public static List<Writable> sampleInvalidFromColumn(int numToSample, String columnName, Schema schema,
JavaRDD<List<Writable>> data, boolean ignoreMissing) {
//First: filter out all valid entries, to leave only invalid entries
int colIdx = schema.getIndexOfColumn(columnName);
JavaRDD<Writable> ithColumn = data.map(new SelectColumnFunction(colIdx));
ColumnMetaData meta = schema.getMetaData(columnName);
JavaRDD<Writable> invalid = ithColumn.filter(new FilterWritablesBySchemaFunction(meta, false, ignoreMissing));
return invalid.takeSample(false, numToSample);
}
/**
* Randomly sample a set of invalid values from a specified column, for a sequence data set.
* Values are considered invalid according to the Schema / ColumnMetaData
*
* @param numToSample Maximum number of invalid values to sample
 * @param columnName Name of the column from which to sample invalid values
* @param schema Data schema
* @param data Data
* @return List of invalid examples
*/
public static List<Writable> sampleInvalidFromColumnSequence(int numToSample, String columnName, Schema schema,
JavaRDD<List<List<Writable>>> data) {
JavaRDD<List<Writable>> flattened = data.flatMap(new SequenceFlatMapFunction());
return sampleInvalidFromColumn(numToSample, columnName, schema, flattened);
}
/**
* Sample the N most frequently occurring values in the specified column
*
* @param nMostFrequent Top N values to sample
* @param columnName Name of the column to sample from
* @param schema Schema of the data
* @param data RDD containing the data
* @return List of the most frequently occurring Writable objects in that column, along with their counts
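 *
 * <p>Example usage (a minimal sketch; "state" is a hypothetical column name):
 * <pre>{@code
 * Map<Writable, Long> top5 = AnalyzeSpark.sampleMostFrequentFromColumn(5, "state", schema, data);
 * }</pre>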
*/
public static Map<Writable, Long> sampleMostFrequentFromColumn(int nMostFrequent, String columnName, Schema schema,
JavaRDD<List<Writable>> data) {
int columnIdx = schema.getIndexOfColumn(columnName);
JavaPairRDD<Writable, Long> keyedByWritable = data.mapToPair(new ColumnToKeyPairTransform(columnIdx));
JavaPairRDD<Writable, Long> reducedByWritable = keyedByWritable.reduceByKey(new SumLongsFunction2());
List<Tuple2<Writable, Long>> list =
reducedByWritable.takeOrdered(nMostFrequent, new Tuple2Comparator<Writable>(false));
List<Tuple2<Writable, Long>> sorted = new ArrayList<>(list);
Collections.sort(sorted, new Tuple2Comparator<Writable>(false));
Map<Writable, Long> map = new LinkedHashMap<>();
for (Tuple2<Writable, Long> t2 : sorted) {
map.put(t2._1(), t2._2());
}
return map;
}
/**
* Get the minimum value for the specified column
*
* @param allData All data
* @param columnName Name of the column to get the minimum value for
* @param schema Schema of the data
* @return Minimum value for the column
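 *
 * <p>Example usage (a minimal sketch; "age" is a hypothetical column name):
 * <pre>{@code
 * Writable minAge = AnalyzeSpark.min(data, "age", schema);
 * }</pre>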
*/
public static Writable min(JavaRDD<List<Writable>> allData, String columnName, Schema schema){
int columnIdx = schema.getIndexOfColumn(columnName);
JavaRDD<Writable> col = allData.map(new SelectColumnFunction(columnIdx));
return col.min(Comparators.forType(schema.getType(columnName).getWritableType()));
}
/**
* Get the maximum value for the specified column
*
* @param allData All data
 * @param columnName Name of the column to get the maximum value for
* @param schema Schema of the data
* @return Maximum value for the column
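 *
 * <p>Example usage (a minimal sketch; "age" is a hypothetical column name):
 * <pre>{@code
 * Writable maxAge = AnalyzeSpark.max(data, "age", schema);
 * }</pre>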
*/
public static Writable max(JavaRDD<List<Writable>> allData, String columnName, Schema schema){
int columnIdx = schema.getIndexOfColumn(columnName);
JavaRDD<Writable> col = allData.map(new SelectColumnFunction(columnIdx));
return col.max(Comparators.forType(schema.getType(columnName).getWritableType()));
}
}