cavis/pydatavec/pydatavec/executors/spark.py

################################################################################
# Copyright (c) 2015-2018 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
import os

_JVM_RUNNING = False


class StringRDD(object):
    """Thin Python wrapper around a JavaRDD&lt;String&gt;."""

    def __init__(self, java_rdd):
        self.java_rdd = java_rdd

    def __iter__(self):
        # Collect the RDD to the driver and iterate over the resulting
        # java.util.List element by element.
        jlist = self.java_rdd.collect()
        size = jlist.size()
        return iter([jlist.get(i) for i in range(size)])

    def iter(self):
        return self.__iter__()

    def save(self, path):
        # Distributed save: writes one part file per partition.
        self.java_rdd.saveAsTextFile(path)

    def save_to_csv(self, path):
        # Local save: collects everything to the driver and writes a single
        # CSV file, one record per line.
        with open(path, 'w') as f:
            for x in self:
                f.write(x + '\n')
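
# StringRDD is normally produced by SparkExecutor.__call__ below rather than
# constructed by hand. A sketch of how a result is typically consumed
# (assumes `result` is the StringRDD returned by an executor call):
#
#     for line in result:            # collects the JavaRDD to the driver
#         print(line)
#     result.save('out_dir')         # distributed save, one file per partition
#     result.save_to_csv('out.csv')  # single local CSV file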

class SparkExecutor(object):
    """Executes a pydatavec TransformProcess on Apache Spark."""

    def __init__(self, master='local[*]', app_name='pydatavec'):
        global _JVM_RUNNING
        # Importing the wrapper classes starts the JVM on first use; Python
        # caches module imports, so repeating them on later instantiations
        # is a no-op. Importing unconditionally keeps these names bound even
        # when the JVM is already running.
        from ..java_classes import SparkConf, SparkContext, SparkTransformExecutor
        from ..java_classes import CSVRecordReader, WritablesToStringFunction, StringToWritablesFunction
        _JVM_RUNNING = True
        spark_conf = SparkConf()
        spark_conf.setMaster(master)
        spark_conf.setAppName(app_name)
        self.spark_context = SparkContext(spark_conf)
        self.rr = CSVRecordReader()
        self.executor = SparkTransformExecutor
        self.str2wf = StringToWritablesFunction
        self.w2strf = WritablesToStringFunction
    def __call__(self, tp, source):
        source_type = type(source).__name__
        if isinstance(source, str):
            if os.path.isfile(source) or os.path.isdir(source):
                # Read the file (or directory of files) into a JavaRDD<String>
                string_data = self.spark_context.textFile(source)
            else:
                raise ValueError('Invalid source: ' + source)
        elif source_type == 'org.apache.spark.api.java.JavaRDD':
            string_data = source
        elif source_type.endswith('RDD'):
            # A pyspark RDD cannot be handed to the Java executor directly,
            # so round-trip it through a fresh temporary text directory.
            tempid = 0
            path = 'temp_0'
            while os.path.isdir(path):
                tempid += 1
                path = 'temp_' + str(tempid)
            print('Converting pyspark RDD to JavaRDD...')
            source.saveAsTextFile(path)
            string_data = self.spark_context.textFile(path)
        else:
            raise TypeError('Unexpected source type: ' + str(type(source)))
        # Parse each line into a List<Writable> with the CSV record reader,
        # execute the TransformProcess, then serialize the results back to
        # comma-separated strings.
        parsed_input_data = string_data.map(self.str2wf(self.rr))  # JavaRDD<List<Writable>>
        processed_data = self.executor.execute(parsed_input_data, tp.to_java())  # JavaRDD<List<Writable>>
        processed_as_string = processed_data.map(self.w2strf(','))  # JavaRDD<String>
        return StringRDD(processed_as_string)
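

if __name__ == '__main__':
    # Minimal usage sketch, not part of the library surface. It assumes a
    # local two-column CSV file named 'data.csv' (hypothetical) and the
    # Schema / TransformProcess classes from the pydatavec package; the
    # executor is invoked via __call__ above and returns a StringRDD.
    from pydatavec import Schema, TransformProcess

    schema = Schema()
    schema.add_string_column('first_name')
    schema.add_string_column('last_name')

    tp = TransformProcess(schema)
    tp.remove_column('last_name')

    executor = SparkExecutor(master='local[*]', app_name='pydatavec-example')
    result = executor(tp, 'data.csv')  # returns a StringRDD
    result.save_to_csv('output.csv')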