################################################################################
# Copyright (c) 2015-2019 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
from .jarmgr import *
from .jarmgr import _MY_DIR
from .pom import *
from .docker import docker_file
import platform
import os
import warnings
from subprocess import call as py_call
import json


def call(arglist):
    error = py_call(arglist)
    if error:
        raise Exception('Subprocess error for command: ' + str(arglist))


_CONFIG_FILE = os.path.join(_MY_DIR, 'config.json')

# Default config
_CONFIG = {
    'dl4j_version': '1.0.0-SNAPSHOT',
    'dl4j_core': True,
    'datavec': True,
    'spark': True,
    'spark_version': '2',
    'scala_version': '2.11',
    'nd4j_backend': 'cpu',
    'validate_jars': True
}
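
# The config can be customized via set_config (defined below); a minimal
# sketch, assuming only the backend needs to change:
#
#   set_config({'nd4j_backend': 'gpu'})
#
# Changes are persisted to config.json in _MY_DIR.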


def _is_sub_set(config1, config2):
    # Check if config1 is a subset of config2:
    # if config1 < config2, then we can use config2's jar
    # for config1 as well.
    if config1['dl4j_version'] != config2['dl4j_version']:
        return False
    if config1['dl4j_core'] > config2['dl4j_core']:
        return False
    if config1['nd4j_backend'] != config2['nd4j_backend']:
        return False
    if config1['datavec']:
        if not config2['datavec']:
            return False
        if config1['spark'] > config2['spark']:
            return False
        if config1['spark_version'] != config2['spark_version']:
            return False
        if config1['scala_version'] != config2['scala_version']:
            return False
    return True
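
# Example: a config without Spark is a subset of an otherwise identical config
# with Spark enabled, so the jar built for the larger config serves both.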


def _write_config(filepath=None):
    if not filepath:
        filepath = _CONFIG_FILE
    with open(filepath, 'w') as f:
        json.dump(_CONFIG, f)


if os.path.isfile(_CONFIG_FILE):
    with open(_CONFIG_FILE, 'r') as f:
        _CONFIG.update(json.load(f))
else:
    _write_config()


def set_config(config):
    _CONFIG.update(config)
    _write_config()


def get_config():
    return _CONFIG


def validate_config(config=None):
    if config is None:
        config = _CONFIG
    valid_options = {
        'spark_version': ['1', '2'],
        'scala_version': ['2.10', '2.11'],
        'nd4j_backend': ['cpu', 'gpu']
    }
    for k, vs in valid_options.items():
        v = config.get(k)
        if v is None:
            raise KeyError('Key not found in config: {}.'.format(k))
        if v not in vs:
            raise ValueError(
                'Invalid value {} for key {} in config. Valid values are: {}.'.format(v, k, vs))
    # Spark 2 does not work with Scala 2.10
    if config['spark_version'] == '2' and config['scala_version'] == '2.10':
        raise ValueError(
            'Scala 2.10 does not work with Spark 2. Set scala_version to 2.11 in the pydl4j config.')
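
# For example, a config with spark_version '2' and scala_version '2.10' fails
# validation with a ValueError, since Spark 2 builds require Scala 2.11 here.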


def _get_context_from_config(config=None):
    if not config:
        config = _CONFIG
    # e.g. pydl4j-1.0.0-SNAPSHOT-cpu-core-datavec-spark2-2.11
    context = 'pydl4j-{}'.format(config['dl4j_version'])
    context += '-' + config['nd4j_backend']
    if config['dl4j_core']:
        context += '-core'
    if config['datavec']:
        context += '-datavec'
        if config['spark']:
            spark_version = config['spark_version']
            scala_version = config['scala_version']
            context += '-spark' + spark_version + '-' + scala_version
    return context


def _get_config_from_context(context):
    config = {}
    backends = ['cpu', 'gpu']
    for b in backends:
        if '-' + b in context:
            config['nd4j_backend'] = b
            config['dl4j_version'] = context.split('-' + b)[0][len('pydl4j-'):]
            break
    config['dl4j_core'] = '-core' in context
    set_defs = False
    if '-datavec' in context:
        config['datavec'] = True
        if '-spark' in context:
            config['spark'] = True
            sp_sc_ver = context.split('-spark')[1]
            sp_ver, sc_ver = sp_sc_ver.split('-')
            config['spark_version'] = sp_ver
            config['scala_version'] = sc_ver
        else:
            config['spark'] = False
            set_defs = True
    else:
        config['datavec'] = False
        config['spark'] = False  # no datavec implies no spark dependencies
        set_defs = True
    if set_defs:
        config['spark_version'] = '2'
        config['scala_version'] = '2.11'
    validate_config(config)
    return config
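
# _get_context_from_config and _get_config_from_context are (roughly) inverses:
# e.g. 'pydl4j-1.0.0-SNAPSHOT-cpu-core-datavec-spark2-2.11' parses back into
# the default _CONFIG above (minus the 'validate_jars' flag).
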
set_context(_get_context_from_config())


def create_pom_from_config():
    config = get_config()
    pom = pom_template()
    dl4j_version = config['dl4j_version']
    nd4j_backend = config['nd4j_backend']
    use_spark = config['spark']
    scala_version = config['scala_version']
    spark_version = config['spark_version']
    use_dl4j_core = config['dl4j_core']
    use_datavec = config['datavec']
    datavec_deps = datavec_dependencies() if use_datavec else ""
    pom = pom.replace('{datavec.dependencies}', datavec_deps)
    core_deps = dl4j_core_dependencies() if use_dl4j_core else ""
    pom = pom.replace('{dl4j.core.dependencies}', core_deps)
    spark_deps = spark_dependencies() if use_spark else ""
    pom = pom.replace('{spark.dependencies}', spark_deps)
    pom = pom.replace('{dl4j.version}', dl4j_version)
    if nd4j_backend == 'cpu':
        platform_backend = "nd4j-native-platform"
        backend = "nd4j-native"
    else:
        platform_backend = "nd4j-cuda-9.2-platform"
        backend = "nd4j-cuda-9.2"
    pom = pom.replace('{nd4j.backend}', backend)
    pom = pom.replace('{nd4j.platform.backend}', platform_backend)
    if use_spark:
        pom = pom.replace('{scala.binary.version}', scala_version)
        # this naming convention seems a little off
        if "SNAPSHOT" in dl4j_version:
            dl4j_version = dl4j_version.replace("-SNAPSHOT", "")
            dl4j_spark_version = dl4j_version + "_spark_" + spark_version + "-SNAPSHOT"
        else:
            dl4j_spark_version = dl4j_version + "_spark_" + spark_version
        pom = pom.replace('{dl4j.spark.version}', dl4j_spark_version)
    # TODO replace if exists
    pom_xml = os.path.join(_MY_DIR, 'pom.xml')
    with open(pom_xml, 'w') as pom_file:
        pom_file.write(pom)
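
# Note: the '{...}' tokens replaced above are plain string placeholders in the
# template returned by pom_template() (from .pom), not Maven properties.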


def docker_build():
    docker_path = os.path.join(_MY_DIR, 'Dockerfile')
    docker_string = docker_file()
    with open(docker_path, 'w') as f:
        f.write(docker_string)
    call(["docker", "build", _MY_DIR, "-t", "pydl4j"])


def docker_run():
    create_pom_from_config()
    py_call(["docker", "run", "--mount",
             "src=" + _MY_DIR + ",target=/app,type=bind", "pydl4j"])
    # docker builds into <_MY_DIR>/target; move the jar to the context dir
    context_dir = get_dir()
    config = get_config()
    dl4j_version = config['dl4j_version']
    jar_name = "pydl4j-{}-bin.jar".format(dl4j_version)
    base_target_dir = os.path.join(_MY_DIR, "target")
    source = os.path.join(base_target_dir, jar_name)
    target = os.path.join(context_dir, jar_name)
    _write_config(os.path.join(context_dir, 'config.json'))
    if os.path.isfile(target):
        os.remove(target)
    os.rename(source, target)


def is_docker_available():
    try:
        with open(os.devnull, 'w') as devnull:
            py_call(["docker", "--help"], stdout=devnull, stderr=devnull)
        return True
    except Exception:
        return False
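
# subprocess.call raises OSError (FileNotFoundError) when the docker binary is
# missing from the PATH; the broad except above converts that into False.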


def _maven_build(use_docker):
    if use_docker:
        docker_build()
        docker_run()
    else:
        create_pom_from_config()
        pom_xml = os.path.join(_MY_DIR, 'pom.xml')
        command = 'mvn clean install -f ' + pom_xml
        os.system(command)
        version = _CONFIG['dl4j_version']
        jar_name = "pydl4j-{}-bin.jar".format(version)
        source = os.path.join(_MY_DIR, 'target', jar_name)
        target = os.path.join(get_dir(), jar_name)
        if os.path.isfile(target):
            os.remove(target)
        os.rename(source, target)
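
# The non-docker path shells out to `mvn`, so a working Maven installation on
# the PATH is assumed for local builds.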


def maven_build():
    if is_docker_available():
        print("Docker available. Starting build...")
        _maven_build(use_docker=True)
    else:
        warnings.warn(
            "Docker unavailable. Falling back to a local Maven build.")
        _maven_build(use_docker=False)


def validate_jars():
    if not _CONFIG['validate_jars']:
        return
    # Builds the jar if it is not available for the given context.
    jars = get_jars()
    dl4j_version = _CONFIG['dl4j_version']
    jar = "pydl4j-{}-bin.jar".format(dl4j_version)
    if jar not in jars:
        # The jar was not found, but it is possible that a jar exists in a
        # different context. If that context is a "super set" of the current
        # one, we can use its jar!
        original_context = context()
        contexts = _get_all_contexts()
        found_super_set_jar = False
        for c in contexts:
            config = _get_config_from_context(c)
            if _is_sub_set(_CONFIG, config):
                set_context(c)
                jars = get_jars()
                if jar in jars:
                    found_super_set_jar = True
                    break
        if not found_super_set_jar:
            set_context(original_context)
            print("pydl4j: required uberjar not found, building it now...")
            maven_build()


def validate_nd4j_jars():
    validate_jars()


def validate_datavec_jars():
    if not _CONFIG['datavec']:
        _CONFIG['datavec'] = True
        _write_config()
        context = _get_context_from_config()
        set_context(context)
    validate_jars()


def _get_all_contexts():
    c = os.listdir(_MY_DIR)
    return [x for x in c if x.startswith('pydl4j')]
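
# Returns entries like 'pydl4j-1.0.0-SNAPSHOT-cpu-core-datavec-spark2-2.11',
# one per configuration previously set up under _MY_DIR.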


def set_jnius_config():
    try:
        import jnius_config
        path = get_dir()
        # A trailing wildcard or a single jar file can be used directly;
        # otherwise add all jars in the directory to the classpath.
        if path[-1] == '*' or os.path.isfile(path):
            jnius_config.add_classpath(path)
        else:
            path = os.path.join(path, '*')
            jnius_config.add_classpath(path)
        # Further options can be set by individual projects
    except ImportError:
        warnings.warn('Pyjnius not installed.')


def add_classpath(path):
    try:
        import jnius_config
        jnius_config.add_classpath(path)
    except ImportError:
        warnings.warn('Pyjnius not installed.')


set_jnius_config()
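
# Typical usage, as a sketch (assumes pyjnius is installed and that these
# helpers are re-exported at the package level):
#
#   import pydl4j
#   pydl4j.set_config({'nd4j_backend': 'cpu'})
#   pydl4j.validate_jars()  # builds the uberjar if it is missing
#   from jnius import autoclass
#   Nd4j = autoclass('org.nd4j.linalg.factory.Nd4j')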