# cavis/jumpy/tests/jumpy/test_spark.py
################################################################################
# Copyright (c) 2015-2019 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
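"""Tests for jumpy's Spark converters.

These tests round-trip RDDs of arrays and DataSets between a PySpark context
and a Java SparkContext in both directions, asserting that the values survive
the conversion unchanged.
"""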
from numpy.testing import assert_allclose
from jumpy.spark import py2javaArrayRDD
from jumpy.spark import py2javaDatasetRDD
from jumpy.spark import java2pyArrayRDD
from jumpy.spark import java2pyDatasetRDD
from jumpy.spark import Dataset
from jumpy.java_classes import JDataset
from jumpy.java_classes import ArrayList
from jnius import autoclass
import jumpy as jp
import numpy as np
import pyspark
import pytest
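
# Java-side Spark classes, loaded through pyjnius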
SparkConf = autoclass('org.apache.spark.SparkConf')
SparkContext = autoclass('org.apache.spark.api.java.JavaSparkContext')


class TestSparkConverters(object):

    @pytest.fixture(scope='module')
    def java_sc(self):
        # Module-scoped Java SparkContext, built through pyjnius
        config = SparkConf()
        config.setAppName("test")
        config.setMaster("local[*]")
        return SparkContext(config)

    @pytest.fixture(scope='module')
    def py_sc(self):
        # Module-scoped PySpark context
        return pyspark.SparkContext(master='local[*]', appName='test')
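
    # Java -> Python: an RDD of INDArray should come back as numpy arrays.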
    def test_java2py_array(self, java_sc, py_sc):
        data = ArrayList()
        for _ in range(100):
            arr = jp.array(np.random.random((32, 20))).array
            data.add(arr)
        java_rdd = java_sc.parallelize(data)
        py_rdd = java2pyArrayRDD(java_rdd, py_sc)
        data2 = py_rdd.collect()
        # Rebuild a Python list from the Java ArrayList for comparison
        data = [data.get(i) for i in range(data.size())]
        assert len(data) == len(data2)
        for d1, d2 in zip(data, data2):
            assert_allclose(jp.array(d1).numpy(), d2)
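
    # Python -> Java: numpy arrays go out through a PySpark RDD and come back
    # as INDArrays, compared after converting back to numpy.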
    def test_py2java_array(self, java_sc, py_sc):
        data = [np.random.random((32, 20)) for _ in range(100)]
        jdata = [jp.array(x) for x in data]  # required
        py_rdd = py_sc.parallelize(data)
        java_rdd = py2javaArrayRDD(py_rdd, java_sc)
        data2 = java_rdd.collect()
        data2 = [data2.get(i) for i in range(data2.size())]
        assert len(data) == len(data2)
        for d1, d2 in zip(data, data2):
            d2 = jp.array(d2).numpy()
            assert_allclose(d1, d2)
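
    # Java -> Python: DataSet objects round-trip with their features intact.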
    def test_java2py_dataset(self, java_sc, py_sc):
        data = ArrayList()
        for _ in range(100):
            arr = jp.array(np.random.random((32, 20))).array
            ds = JDataset(arr, arr)
            data.add(ds)
        java_rdd = java_sc.parallelize(data)
        py_rdd = java2pyDatasetRDD(java_rdd, py_sc)
        data2 = py_rdd.collect()
        data = [data.get(i) for i in range(data.size())]
        assert len(data) == len(data2)
        for d1, d2 in zip(data, data2):
            assert_allclose(jp.array(d1.getFeatures()).numpy(), d2.features.numpy())
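
    # Python -> Java: jumpy Dataset objects (features == labels here) convert
    # to Java DataSet objects.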
    def test_py2java_dataset(self, java_sc, py_sc):
        data = [np.random.random((32, 20)) for _ in range(100)]
        jdata = [jp.array(x) for x in data]  # required
        data = [Dataset(x, x) for x in data]
        py_rdd = py_sc.parallelize(data)
        java_rdd = py2javaDatasetRDD(py_rdd, java_sc)
        data2 = java_rdd.collect()
        data2 = [data2.get(i) for i in range(data2.size())]
        assert len(data) == len(data2)
        for d1, d2 in zip(data, data2):
            d2 = jp.array(d2.getFeatures()).numpy()
            assert_allclose(d1.features.numpy(), d2)


if __name__ == '__main__':
    pytest.main([__file__])