From 881a672fa13f3b37177eb2ea023e39b0de893645 Mon Sep 17 00:00:00 2001 From: Alexei KLENIN Date: Mon, 5 Oct 2020 19:25:01 -0700 Subject: [PATCH] FEATURE: possibility to define global Classload for ND4J (#8972) Signed-off-by: hosuaby --- .../org/datavec/api/util/ReflectionUtils.java | 18 +- .../autodiff/validation/OpValidation.java | 11 +- .../compression/BasicNDArrayCompressor.java | 3 +- .../serializer/NormalizerSerializer.java | 8 +- .../java/org/nd4j/linalg/factory/Nd4j.java | 46 +- .../org/nd4j/linalg/factory/Nd4jBackend.java | 16 +- .../serde/json/BaseLegacyDeserializer.java | 12 +- .../java/org/nd4j/systeminfo/SystemInfo.java | 44 +- .../org/nd4j/versioncheck/VersionCheck.java | 35 +- .../org/nd4j/nativeblas/NativeOpsHolder.java | 8 +- .../java/org/nd4j/linalg/BaseNd4jTest.java | 45 +- .../java/org/nd4j/linalg/Nd4jTestSuite.java | 17 +- .../org/nd4j/common/base/Preconditions.java | 24 +- .../nd4j/common/config/ND4JClassLoading.java | 61 ++ .../org/nd4j/common/io/ClassPathResource.java | 7 +- .../java/org/nd4j/common/io/ClassUtils.java | 710 ------------------ .../org/nd4j/common/io/ReflectionUtils.java | 35 + .../org/nd4j/common/io/ResourceUtils.java | 20 +- .../org/nd4j/common/resources/Resources.java | 12 +- .../org/nd4j/common/util/ReflectionUtils.java | 122 --- .../nd4j/jdbc/driverfinder/DriverFinder.java | 30 +- .../org/nd4j/jdbc/hsql/HSqlLoaderTest.java | 8 +- .../distributed/messages/VoidMessage.java | 18 +- .../distributed/training/TrainerProvider.java | 12 +- .../ParameterServerSubscriber.java | 12 +- 25 files changed, 288 insertions(+), 1046 deletions(-) create mode 100644 nd4j/nd4j-common/src/main/java/org/nd4j/common/config/ND4JClassLoading.java delete mode 100644 nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ClassUtils.java delete mode 100644 nd4j/nd4j-common/src/main/java/org/nd4j/common/util/ReflectionUtils.java diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/util/ReflectionUtils.java b/datavec/datavec-api/src/main/java/org/datavec/api/util/ReflectionUtils.java index 100e72c44..a34a86a27 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/util/ReflectionUtils.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/util/ReflectionUtils.java @@ -25,14 +25,13 @@ import org.datavec.api.io.serializers.SerializationFactory; import org.datavec.api.io.serializers.Serializer; import java.io.IOException; -import java.lang.reflect.Constructor; import java.lang.reflect.Method; /** - * @deprecated Use {@link org.nd4j.common.util.ReflectionUtils} + * @deprecated Use {@link org.nd4j.common.io.ReflectionUtils} */ @Deprecated -public class ReflectionUtils extends org.nd4j.common.util.ReflectionUtils { +public class ReflectionUtils { private static final Class[] EMPTY_ARRAY = new Class[] {}; private static SerializationFactory serialFactory = null; @@ -48,18 +47,7 @@ public class ReflectionUtils extends org.nd4j.common.util.ReflectionUtils { */ @SuppressWarnings("unchecked") public static T newInstance(Class theClass, Configuration conf) { - T result; - try { - Constructor meth = (Constructor) CONSTRUCTOR_CACHE.get(theClass); - if (meth == null) { - meth = theClass.getDeclaredConstructor(EMPTY_ARRAY); - meth.setAccessible(true); - CONSTRUCTOR_CACHE.put(theClass, meth); - } - result = meth.newInstance(); - } catch (Exception e) { - throw new RuntimeException(e); - } + T result = org.nd4j.common.io.ReflectionUtils.newInstance(theClass); setConf(result, conf); return result; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index e9aea8968..38c7cbf07 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -16,6 +16,7 @@ package org.nd4j.autodiff.validation; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.linalg.api.ops.custom.*; import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; @@ -613,14 +614,8 @@ public class OpValidation { allOps = new ArrayList<>(gradCheckCoverageCountPerClass.keySet()); for (ClassPath.ClassInfo c : info) { //Load method: Loads (but doesn't link or initialize) the class. - Class clazz; - try { - clazz = Class.forName(c.getName()); - } catch (ClassNotFoundException e) { - //Should never happen as this was found on the classpath - throw new RuntimeException(e); - } - + Class clazz = ND4JClassLoading.loadClassByName(c.getName()); + Objects.requireNonNull(clazz); if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || !DifferentialFunction.class.isAssignableFrom(clazz)) continue; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/BasicNDArrayCompressor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/BasicNDArrayCompressor.java index 665d6ce19..0ada386d7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/BasicNDArrayCompressor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/BasicNDArrayCompressor.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.compression; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -50,7 +51,7 @@ public class BasicNDArrayCompressor { */ codecs = new ConcurrentHashMap<>(); - ServiceLoader loader = ServiceLoader.load(NDArrayCompressor.class); + ServiceLoader loader = ND4JClassLoading.loadService(NDArrayCompressor.class); for (NDArrayCompressor compressor : loader) { codecs.put(compressor.getDescriptor().toUpperCase(), compressor); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/serializer/NormalizerSerializer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/serializer/NormalizerSerializer.java index cf8e5fa99..1dd76e664 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/serializer/NormalizerSerializer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/serializer/NormalizerSerializer.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.dataset.api.preprocessor.serializer; import lombok.NonNull; import lombok.Value; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.linalg.dataset.api.preprocessor.Normalizer; import java.io.*; @@ -215,7 +216,7 @@ public class NormalizerSerializer { * @throws IOException * @throws IllegalArgumentException if the data format is invalid */ - private Header parseHeader(InputStream stream) throws IOException, ClassNotFoundException { + private Header parseHeader(InputStream stream) throws IOException { DataInputStream dis = new DataInputStream(stream); // Check if the stream starts with the expected header String header = dis.readUTF(); @@ -237,8 +238,9 @@ public class NormalizerSerializer { if (type.equals(NormalizerType.CUSTOM)) { // For custom serializers, the next value is a string with the class opName String strategyClassName = dis.readUTF(); - //noinspection unchecked - return new Header(type, (Class) Class.forName(strategyClassName)); + Class strategyClass = ND4JClassLoading + .loadClassByName(strategyClassName); + return new Header(type, strategyClass); } else { return new Header(type, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 12a76f2b9..48e0855e4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.factory; import lombok.extern.slf4j.Slf4j; import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.linalg.factory.ops.*; import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.shade.guava.primitives.Longs; @@ -5135,37 +5136,36 @@ public class Nd4j { compressDebug = pp.toBoolean(COMPRESSION_DEBUG); char ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C); - Class affinityManagerClazz = (Class) Class - .forName(pp.toString(AFFINITY_MANAGER)); + Class affinityManagerClazz = ND4JClassLoading + .loadClassByName(pp.toString(AFFINITY_MANAGER)); affinityManager = affinityManagerClazz.newInstance(); - Class ndArrayFactoryClazz = (Class) Class.forName( - pp.toString(NDARRAY_FACTORY_CLASS)); - Class convolutionInstanceClazz = (Class) Class - .forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName())); + Class ndArrayFactoryClazz = ND4JClassLoading + .loadClassByName(pp.toString(NDARRAY_FACTORY_CLASS)); + Class convolutionInstanceClazz = ND4JClassLoading + .loadClassByName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName())); String defaultName = pp.toString(DATA_BUFFER_OPS, "org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory"); - Class dataBufferFactoryClazz = (Class) Class - .forName(pp.toString(DATA_BUFFER_OPS, defaultName)); - Class shapeInfoProviderClazz = (Class) Class - .forName(pp.toString(SHAPEINFO_PROVIDER)); + Class dataBufferFactoryClazz = ND4JClassLoading + .loadClassByName(pp.toString(DATA_BUFFER_OPS, defaultName)); + Class shapeInfoProviderClazz = ND4JClassLoading + .loadClassByName(pp.toString(SHAPEINFO_PROVIDER)); - Class constantProviderClazz = (Class) Class - .forName(pp.toString(CONSTANT_PROVIDER)); + Class constantProviderClazz = ND4JClassLoading + .loadClassByName(pp.toString(CONSTANT_PROVIDER)); - Class memoryManagerClazz = (Class) Class - .forName(pp.toString(MEMORY_MANAGER)); + Class memoryManagerClazz = ND4JClassLoading + .loadClassByName(pp.toString(MEMORY_MANAGER)); allowsOrder = backend.allowsOrder(); String rand = pp.toString(RANDOM_PROVIDER, DefaultRandom.class.getName()); - Class randomClazz = (Class) Class.forName(rand); + Class randomClazz = ND4JClassLoading.loadClassByName(rand); randomFactory = new RandomFactory(randomClazz); - Class workspaceManagerClazz = (Class) Class - .forName(pp.toString(WORKSPACE_MANAGER)); + Class workspaceManagerClazz = ND4JClassLoading + .loadClassByName(pp.toString(WORKSPACE_MANAGER)); - Class blasWrapperClazz = (Class) Class - .forName(pp.toString(BLAS_OPS)); + Class blasWrapperClazz = ND4JClassLoading.loadClassByName(pp.toString(BLAS_OPS)); String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName()); - Class distributionFactoryClazz = (Class) Class.forName(clazzName); + Class distributionFactoryClazz = ND4JClassLoading.loadClassByName(clazzName); memoryManager = memoryManagerClazz.newInstance(); @@ -5173,8 +5173,8 @@ public class Nd4j { shapeInfoProvider = shapeInfoProviderClazz.newInstance(); workspaceManager = workspaceManagerClazz.newInstance(); - Class opExecutionerClazz = (Class) Class - .forName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName())); + Class opExecutionerClazz = ND4JClassLoading + .loadClassByName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName())); OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance(); Constructor c2 = ndArrayFactoryClazz.getConstructor(DataType.class, char.class); @@ -5197,7 +5197,7 @@ public class Nd4j { OP_EXECUTIONER_INSTANCE.printEnvironmentInformation(); } - val actions = ServiceLoader.load(EnvironmentalAction.class); + val actions = ND4JClassLoading.loadService(EnvironmentalAction.class); val mappedActions = new HashMap(); for (val a: actions) { if (!mappedActions.containsKey(a.targetVariable())) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java index c25f89a95..d7d69de57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.factory; import lombok.extern.slf4j.Slf4j; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.common.config.ND4JEnvironmentVars; import org.nd4j.common.config.ND4JSystemProperties; import org.nd4j.context.Nd4jContext; @@ -25,6 +26,7 @@ import org.nd4j.common.io.Resource; import java.io.File; import java.io.IOException; +import java.net.URLClassLoader; import java.security.PrivilegedActionException; import java.util.*; @@ -156,14 +158,12 @@ public abstract class Nd4jBackend { String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true"); boolean logInit = Boolean.parseBoolean(logInitProperty); - List backends = new ArrayList<>(1); - ServiceLoader loader = ServiceLoader.load(Nd4jBackend.class); + List backends = new ArrayList<>(); + ServiceLoader loader = ND4JClassLoading.loadService(Nd4jBackend.class); try { - - Iterator backendIterator = loader.iterator(); - while (backendIterator.hasNext()) - backends.add(backendIterator.next()); - + for (Nd4jBackend nd4jBackend : loader) { + backends.add(nd4jBackend); + } } catch (ServiceConfigurationError serviceError) { // a fatal error due to a syntax or provider construction error. // backends mustn't throw an exception during construction. @@ -240,7 +240,7 @@ public abstract class Nd4jBackend { public static synchronized void loadLibrary(File jar) throws NoAvailableBackendException { try { /*We are using reflection here to circumvent encapsulation; addURL is not public*/ - java.net.URLClassLoader loader = (java.net.URLClassLoader) ClassLoader.getSystemClassLoader(); + java.net.URLClassLoader loader = (URLClassLoader) ND4JClassLoading.getNd4jClassloader(); java.net.URL url = jar.toURI().toURL(); /*Disallow if already loaded*/ for (java.net.URL it : java.util.Arrays.asList(loader.getURLs())) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/json/BaseLegacyDeserializer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/json/BaseLegacyDeserializer.java index 0a11c9489..b7a358917 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/json/BaseLegacyDeserializer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/json/BaseLegacyDeserializer.java @@ -17,6 +17,7 @@ package org.nd4j.serde.json; import lombok.extern.slf4j.Slf4j; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.shade.jackson.core.JsonParser; import org.nd4j.shade.jackson.databind.DeserializationContext; import org.nd4j.shade.jackson.databind.JsonDeserializer; @@ -28,6 +29,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; /** * A base deserialization class used to handle deserializing of a specific class given changes from subtype wrapper @@ -79,13 +81,9 @@ public abstract class BaseLegacyDeserializer extends JsonDeserializer { + "\": legacy class mapping with this name is unknown"); } - Class lClass; - try { - lClass = (Class) Class.forName(layerClass); - } catch (Exception e){ - throw new RuntimeException("Could not find class for deserialization of \"" + name + "\" of type " + - getDeserializedType() + ": class " + layerClass + " is not on the classpath?", e); - } + Class lClass = ND4JClassLoading.loadClassByName(layerClass); + Objects.requireNonNull(lClass, "Could not find class for deserialization of \"" + name + "\" of type " + + getDeserializedType() + ": class " + layerClass + " is not on the classpath?"); ObjectMapper m = getLegacyJsonMapper(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/systeminfo/SystemInfo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/systeminfo/SystemInfo.java index db4cec0b2..e2948ea96 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/systeminfo/SystemInfo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/systeminfo/SystemInfo.java @@ -35,6 +35,7 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.lang3.SystemUtils; import org.apache.commons.lang3.exception.ExceptionUtils; import org.bytedeco.javacpp.Pointer; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.linalg.api.environment.Nd4jEnvironment; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.factory.Nd4j; @@ -212,22 +213,22 @@ public class SystemInfo { boolean hasGPUs = false; - ServiceLoader loader = ServiceLoader.load(GPUInfoProvider.class); + ServiceLoader loader = ND4JClassLoading.loadService(GPUInfoProvider.class); Iterator iter = loader.iterator(); - if(iter.hasNext()){ + if (iter.hasNext()) { List gpus = iter.next().getGPUs(); sb.append(f("Number of GPUs Detected", gpus.size())); - if(!gpus.isEmpty()) + if (!gpus.isEmpty()) { hasGPUs = true; + } sb.append(String.format(fGpu, "Name", "CC", "Total Memory", "Used Memory", "Free Memory")).append("\n"); - for(GPUInfo gpuInfo : gpus){ + for (GPUInfo gpuInfo : gpus) { sb.append(gpuInfo).append("\n"); } - } else { sb.append("GPU Provider not found (are you missing nd4j-native?)"); } @@ -327,28 +328,24 @@ public class SystemInfo { appendProperty(sb, "Library Path", "java.library.path"); - - //classpath appendHeader(sb, "Classpath"); - ClassLoader cl = ClassLoader.getSystemClassLoader(); - URL[] urls = null; - try{ - urls = ((URLClassLoader)cl).getURLs(); - } catch (ClassCastException e){ - try { - urls = ((URLClassLoader) SystemInfo.class.getClassLoader()).getURLs(); - } catch (ClassCastException e1){ - try{ - urls = ((URLClassLoader) (Thread.currentThread().getContextClassLoader())).getURLs(); - } catch (ClassCastException e2) { - sb.append("Can't cast class loader to URLClassLoader\n"); - } - } + URLClassLoader urlClassLoader = null; + + if (ND4JClassLoading.getNd4jClassloader() instanceof URLClassLoader) { + urlClassLoader = (URLClassLoader) ND4JClassLoading.getNd4jClassloader(); + } else if (ClassLoader.getSystemClassLoader() instanceof URLClassLoader) { + urlClassLoader = (URLClassLoader) ClassLoader.getSystemClassLoader(); + } else if (SystemInfo.class.getClassLoader() instanceof URLClassLoader) { + urlClassLoader = (URLClassLoader) SystemInfo.class.getClassLoader(); + } else if (Thread.currentThread().getContextClassLoader() instanceof URLClassLoader) { + urlClassLoader = (URLClassLoader) Thread.currentThread().getContextClassLoader(); + } else { + sb.append("Can't cast class loader to URLClassLoader\n"); } - if(urls != null) { - for (URL url : urls) { + if (urlClassLoader != null) { + for (URL url : urlClassLoader.getURLs()) { sb.append(url.getFile()).append("\n"); } } else { @@ -359,7 +356,6 @@ public class SystemInfo { } } - //launch command appendHeader(sb, "Launch Command"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/versioncheck/VersionCheck.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/versioncheck/VersionCheck.java index 28853e75a..93d6ce561 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/versioncheck/VersionCheck.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/versioncheck/VersionCheck.java @@ -17,14 +17,27 @@ package org.nd4j.versioncheck; import lombok.extern.slf4j.Slf4j; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.common.config.ND4JSystemProperties; import java.io.IOException; import java.net.URI; import java.net.URL; -import java.nio.file.*; +import java.nio.file.FileSystem; +import java.nio.file.FileSystems; +import java.nio.file.FileVisitResult; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.SimpleFileVisitor; import java.nio.file.attribute.BasicFileAttributes; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.List; +import java.util.Set; /** * A runtime version check utility that does 2 things:
@@ -92,14 +105,14 @@ public class VersionCheck { return; } - if(classExists(ND4J_JBLAS_CLASS)) { + if(ND4JClassLoading.classPresentOnClasspath(ND4J_JBLAS_CLASS)) { //nd4j-jblas is ancient and incompatible log.error("Found incompatible/obsolete backend and version (nd4j-jblas) on classpath. ND4J is unlikely to" + " function correctly with nd4j-jblas on the classpath. JVM will now exit."); System.exit(1); } - if(classExists(CANOVA_CLASS)) { + if(ND4JClassLoading.classPresentOnClasspath(CANOVA_CLASS)) { //Canova is ancient and likely to pull in incompatible dependencies log.error("Found incompatible/obsolete library Canova on classpath. ND4J is unlikely to" + " function correctly with this library on the classpath. JVM will now exit."); @@ -281,13 +294,13 @@ public class VersionCheck { } } - if(classExists(ND4J_JBLAS_CLASS)){ + if(ND4JClassLoading.classPresentOnClasspath(ND4J_JBLAS_CLASS)){ //nd4j-jblas is ancient and incompatible log.error("Found incompatible/obsolete backend and version (nd4j-jblas) on classpath. ND4J is unlikely to" + " function correctly with nd4j-jblas on the classpath."); } - if(classExists(CANOVA_CLASS)){ + if(ND4JClassLoading.classPresentOnClasspath(CANOVA_CLASS)){ //Canova is anchient and likely to pull in incompatible log.error("Found incompatible/obsolete library Canova on classpath. ND4J is unlikely to" + " function correctly with this library on the classpath."); @@ -296,16 +309,6 @@ public class VersionCheck { return repState; } - private static boolean classExists(String className){ - try{ - Class.forName(className); - return true; - } catch (ClassNotFoundException e ){ - //OK - not found - } - return false; - } - /** * @return A string representation of the version information, with the default (GAV) detail level */ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java index 32f316ca0..5fad1d18e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java @@ -19,8 +19,10 @@ package org.nd4j.nativeblas; import java.util.Properties; import lombok.Getter; import org.bytedeco.javacpp.Loader; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.common.config.ND4JEnvironmentVars; import org.nd4j.common.config.ND4JSystemProperties; +import org.nd4j.common.io.ReflectionUtils; import org.nd4j.context.Nd4jContext; import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; @@ -82,8 +84,10 @@ public class NativeOpsHolder { Properties props = Nd4jContext.getInstance().getConf(); String name = System.getProperty(Nd4j.NATIVE_OPS, props.get(Nd4j.NATIVE_OPS).toString()); - Class nativeOpsClazz = Class.forName(name).asSubclass(NativeOps.class); - deviceNativeOps = nativeOpsClazz.newInstance(); + Class nativeOpsClass = ND4JClassLoading + .loadClassByName(name) + .asSubclass(NativeOps.class); + deviceNativeOps = ReflectionUtils.newInstance(nativeOpsClass); deviceNativeOps.initializeDevicesAndFunctions(); int numThreads; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java index 4568a0a4b..056ff4fec 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java @@ -16,18 +16,18 @@ package org.nd4j.linalg; - import lombok.extern.slf4j.Slf4j; import org.junit.Before; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import org.nd4j.common.config.ND4JClassLoading; +import org.nd4j.common.io.ReflectionUtils; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import java.util.*; - /** * Base Nd4j test * @author Adam Gibson @@ -35,6 +35,18 @@ import java.util.*; @RunWith(Parameterized.class) @Slf4j public abstract class BaseNd4jTest extends BaseND4JTest { + private static List BACKENDS = new ArrayList<>(); + static { + List backendsToRun = Nd4jTestSuite.backendsToRun(); + + ServiceLoader loadedBackends = ND4JClassLoading.loadService(Nd4jBackend.class); + for (Nd4jBackend backend : loadedBackends) { + if (backend.canRun() && backendsToRun.contains(backend.getClass().getName()) + || backendsToRun.isEmpty()) { + BACKENDS.add(backend); + } + } + } protected Nd4jBackend backend; protected String name; @@ -57,24 +69,10 @@ public abstract class BaseNd4jTest extends BaseND4JTest { this(backend.getClass().getName() + UUID.randomUUID().toString(), backend); } - private static List backends; - static { - ServiceLoader loadedBackends = ServiceLoader.load(Nd4jBackend.class); - Iterator backendIterator = loadedBackends.iterator(); - backends = new ArrayList<>(); - List backendsToRun = Nd4jTestSuite.backendsToRun(); - - while (backendIterator.hasNext()) { - Nd4jBackend backend = backendIterator.next(); - if (backend.canRun() && backendsToRun.contains(backend.getClass().getName()) || backendsToRun.isEmpty()) - backends.add(backend); - } - } - @Parameterized.Parameters(name = "{index}: backend({0})={1}") public static Collection configs() { List ret = new ArrayList<>(); - for (Nd4jBackend backend : backends) + for (Nd4jBackend backend : BACKENDS) ret.add(new Object[] {backend}); return ret; } @@ -93,16 +91,11 @@ public abstract class BaseNd4jTest extends BaseND4JTest { */ public static Nd4jBackend getDefaultBackend() { String cpuBackend = "org.nd4j.linalg.cpu.nativecpu.CpuBackend"; - //String cpuBackend = "org.nd4j.linalg.cpu.CpuBackend"; - String gpuBackend = "org.nd4j.linalg.jcublas.JCublasBackend"; - String clazz = System.getProperty(DEFAULT_BACKEND, cpuBackend); - try { - return (Nd4jBackend) Class.forName(clazz).newInstance(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } + String defaultBackendClass = System.getProperty(DEFAULT_BACKEND, cpuBackend); + Class backendClass = ND4JClassLoading.loadClassByName(defaultBackendClass); + return ReflectionUtils.newInstance(backendClass); + } /** * The ordering for this test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestSuite.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestSuite.java index 72f4ea0e9..415ef64a8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestSuite.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestSuite.java @@ -17,10 +17,10 @@ package org.nd4j.linalg; import org.junit.runners.BlockJUnit4ClassRunner; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.linalg.factory.Nd4jBackend; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; import java.util.ServiceLoader; @@ -32,19 +32,15 @@ import java.util.ServiceLoader; * * @author Adam Gibson */ - public class Nd4jTestSuite extends BlockJUnit4ClassRunner { //the system property for what backends should run public final static String BACKENDS_TO_LOAD = "backends"; - private static List backends; + private static List BACKENDS; static { - ServiceLoader loadedBackends = ServiceLoader.load(Nd4jBackend.class); - Iterator backendIterator = loadedBackends.iterator(); - backends = new ArrayList<>(); - while (backendIterator.hasNext()) - backends.add(backendIterator.next()); - - + ServiceLoader loadedBackends = ND4JClassLoading.loadService(Nd4jBackend.class); + for (Nd4jBackend backend : loadedBackends) { + BACKENDS.add(backend); + } } /** @@ -56,7 +52,6 @@ public class Nd4jTestSuite extends BlockJUnit4ClassRunner { super(klass); } - /** * Based on the jvm arguments, an empty list is returned * if all backends should be run. diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/common/base/Preconditions.java b/nd4j/nd4j-common/src/main/java/org/nd4j/common/base/Preconditions.java index 7019c5fcc..ef6a0a97a 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/common/base/Preconditions.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/common/base/Preconditions.java @@ -16,6 +16,8 @@ package org.nd4j.common.base; +import org.nd4j.common.config.ND4JClassLoading; + import java.util.*; /** @@ -23,24 +25,20 @@ import java.util.*; * * @author Alex Black */ -public class Preconditions { - - private static final Map formatters = new HashMap<>(); - +public final class Preconditions { + private static final Map FORMATTERS = new HashMap<>(); static { - ServiceLoader sl = ServiceLoader.load(PreconditionsFormat.class); - Iterator iter = sl.iterator(); - while(iter.hasNext()){ - PreconditionsFormat pf = iter.next(); + ServiceLoader sl = ND4JClassLoading.loadService(PreconditionsFormat.class); + for (PreconditionsFormat pf : sl) { List formatTags = pf.formatTags(); for(String s : formatTags){ - formatters.put(s, pf); + FORMATTERS.put(s, pf); } } - } - private Preconditions(){ } + private Preconditions() { + } /** * Check the specified boolean argument. Throws an IllegalArgumentException if {@code b} is false @@ -664,7 +662,7 @@ public class Preconditions { int nextCustom = -1; String nextCustomTag = null; - for(String s : formatters.keySet()){ + for(String s : FORMATTERS.keySet()){ int idxThis = message.indexOf(s, indexOfStart); if(idxThis > 0 && (nextCustom < 0 || idxThis < nextCustom)){ nextCustom = idxThis; @@ -696,7 +694,7 @@ public class Preconditions { } else { //Custom tag sb.append(message.substring(indexOfStart, nextCustom)); - String s = formatters.get(nextCustomTag).format(nextCustomTag, args[i]); + String s = FORMATTERS.get(nextCustomTag).format(nextCustomTag, args[i]); sb.append(s); indexOfStart = nextCustom + nextCustomTag.length(); } diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/common/config/ND4JClassLoading.java b/nd4j/nd4j-common/src/main/java/org/nd4j/common/config/ND4JClassLoading.java new file mode 100644 index 000000000..56d95642d --- /dev/null +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/common/config/ND4JClassLoading.java @@ -0,0 +1,61 @@ +package org.nd4j.common.config; + +import lombok.extern.slf4j.Slf4j; + +import java.util.ServiceLoader; + +/** + * Global context for class-loading in ND4J. + * + * @author Alexei KLENIN + */ +@Slf4j +public final class ND4JClassLoading { + private static ClassLoader nd4jClassloader = Thread.currentThread().getContextClassLoader(); + + private ND4JClassLoading() { + } + + public static ClassLoader getNd4jClassloader() { + return ND4JClassLoading.nd4jClassloader; + } + + public static void setNd4jClassloaderFromClass(Class clazz) { + setNd4jClassloader(clazz.getClassLoader()); + } + + public static void setNd4jClassloader(ClassLoader nd4jClassloader) { + ND4JClassLoading.nd4jClassloader = nd4jClassloader; + log.debug("Global class-loader for ND4J was changed."); + } + + public static boolean classPresentOnClasspath(String className) { + return classPresentOnClasspath(className, nd4jClassloader); + } + + public static boolean classPresentOnClasspath(String className, ClassLoader classLoader) { + return loadClassByName(className, false, classLoader) != null; + } + + public static Class loadClassByName(String className) { + return loadClassByName(className, true, nd4jClassloader); + } + + @SuppressWarnings("unchecked") + public static Class loadClassByName(String className, boolean initialize, ClassLoader classLoader) { + try { + return (Class) Class.forName(className, initialize, classLoader); + } catch (ClassNotFoundException classNotFoundException) { + log.error(String.format("Cannot find class [%s] of provided class-loader.", className)); + return null; + } + } + + public static ServiceLoader loadService(Class serviceClass) { + return loadService(serviceClass, nd4jClassloader); + } + + public static ServiceLoader loadService(Class serviceClass, ClassLoader classLoader) { + return ServiceLoader.load(serviceClass, classLoader); + } +} diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ClassPathResource.java b/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ClassPathResource.java index 93c37051d..ce44a7991 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ClassPathResource.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ClassPathResource.java @@ -20,6 +20,7 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.IOUtils; import org.nd4j.common.base.Preconditions; +import org.nd4j.common.config.ND4JClassLoading; import java.io.*; import java.net.MalformedURLException; @@ -55,7 +56,7 @@ public class ClassPathResource extends AbstractFileResolvingResource { } this.path = pathToUse; - this.classLoader = classLoader != null ? classLoader : ClassUtils.getDefaultClassLoader(); + this.classLoader = classLoader != null ? classLoader : ND4JClassLoading.getNd4jClassloader(); } public ClassPathResource(String path, Class clazz) { @@ -283,7 +284,7 @@ public class ClassPathResource extends AbstractFileResolvingResource { StringBuilder builder = new StringBuilder("class path resource ["); String pathToUse = this.path; if (this.clazz != null && !pathToUse.startsWith("/")) { - builder.append(ClassUtils.classPackageAsResourcePath(this.clazz)); + builder.append(ResourceUtils.classPackageAsResourcePath(this.clazz)); builder.append('/'); } @@ -320,7 +321,7 @@ public class ClassPathResource extends AbstractFileResolvingResource { private URL getUrl() { ClassLoader loader = null; try { - loader = Thread.currentThread().getContextClassLoader(); + loader = ND4JClassLoading.getNd4jClassloader(); } catch (Exception e) { // do nothing } diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ClassUtils.java b/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ClassUtils.java deleted file mode 100644 index 8ea3d2e0c..000000000 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ClassUtils.java +++ /dev/null @@ -1,710 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.common.io; - - -import java.beans.Introspector; -import java.lang.reflect.*; -import java.security.AccessControlException; -import java.util.*; - - -public abstract class ClassUtils { - public static final String ARRAY_SUFFIX = "[]"; - private static final String INTERNAL_ARRAY_PREFIX = "["; - private static final String NON_PRIMITIVE_ARRAY_PREFIX = "[L"; - private static final char PACKAGE_SEPARATOR = '.'; - private static final char INNER_CLASS_SEPARATOR = '$'; - public static final String CGLIB_CLASS_SEPARATOR = "$$"; - public static final String CLASS_FILE_SUFFIX = ".class"; - private static final Map, Class> primitiveWrapperTypeMap = new HashMap(8); - private static final Map primitiveTypeToWrapperMap = new HashMap(8); - private static final Map> primitiveTypeNameMap = new HashMap(32); - private static final Map> commonClassCache = new HashMap(32); - - public ClassUtils() {} - - private static void registerCommonClasses(Class... commonClasses) { - Class[] arr$ = commonClasses; - int len$ = commonClasses.length; - - for (int i$ = 0; i$ < len$; ++i$) { - Class clazz = arr$[i$]; - commonClassCache.put(clazz.getName(), clazz); - } - - } - - public static ClassLoader getDefaultClassLoader() { - ClassLoader cl = null; - - try { - cl = Thread.currentThread().getContextClassLoader(); - } catch (Throwable var2) { - ; - } - - if (cl == null) { - cl = ClassUtils.class.getClassLoader(); - } - - return cl; - } - - public static ClassLoader overrideThreadContextClassLoader(ClassLoader classLoaderToUse) { - Thread currentThread = Thread.currentThread(); - ClassLoader threadContextClassLoader = currentThread.getContextClassLoader(); - if (classLoaderToUse != null && !classLoaderToUse.equals(threadContextClassLoader)) { - currentThread.setContextClassLoader(classLoaderToUse); - return threadContextClassLoader; - } else { - return null; - } - } - - /** @deprecated */ - @Deprecated - public static Class forName(String name) throws ClassNotFoundException, LinkageError { - return forName(name, getDefaultClassLoader()); - } - - public static Class forName(String name, ClassLoader classLoader) throws ClassNotFoundException, LinkageError { - Assert.notNull(name, "Name must not be null"); - Class clazz = resolvePrimitiveClassName(name); - if (clazz == null) { - clazz = commonClassCache.get(name); - } - - if (clazz != null) { - return clazz; - } else { - Class ex; - String classLoaderToUse1; - if (name.endsWith("[]")) { - classLoaderToUse1 = name.substring(0, name.length() - "[]".length()); - ex = forName(classLoaderToUse1, classLoader); - return Array.newInstance(ex, 0).getClass(); - } else if (name.startsWith("[L") && name.endsWith(";")) { - classLoaderToUse1 = name.substring("[L".length(), name.length() - 1); - ex = forName(classLoaderToUse1, classLoader); - return Array.newInstance(ex, 0).getClass(); - } else if (name.startsWith("[")) { - classLoaderToUse1 = name.substring("[".length()); - ex = forName(classLoaderToUse1, classLoader); - return Array.newInstance(ex, 0).getClass(); - } else { - ClassLoader classLoaderToUse = classLoader; - if (classLoader == null) { - classLoaderToUse = getDefaultClassLoader(); - } - - try { - return classLoaderToUse.loadClass(name); - } catch (ClassNotFoundException var9) { - int lastDotIndex = name.lastIndexOf(46); - if (lastDotIndex != -1) { - String innerClassName = - name.substring(0, lastDotIndex) + '$' + name.substring(lastDotIndex + 1); - - try { - return classLoaderToUse.loadClass(innerClassName); - } catch (ClassNotFoundException var8) { - ; - } - } - - throw var9; - } - } - } - } - - public static Class resolveClassName(String className, ClassLoader classLoader) throws IllegalArgumentException { - try { - return forName(className, classLoader); - } catch (ClassNotFoundException var3) { - throw new IllegalArgumentException("Cannot find class [" + className + "]", var3); - } catch (LinkageError var4) { - throw new IllegalArgumentException( - "Error loading class [" + className + "]: problem with class file or dependent class.", - var4); - } - } - - public static Class resolvePrimitiveClassName(String name) { - Class result = null; - if (name != null && name.length() <= 8) { - result = primitiveTypeNameMap.get(name); - } - - return result; - } - - /** @deprecated */ - @Deprecated - public static boolean isPresent(String className) { - return isPresent(className, getDefaultClassLoader()); - } - - public static boolean isPresent(String className, ClassLoader classLoader) { - try { - forName(className, classLoader); - return true; - } catch (Throwable var3) { - return false; - } - } - - public static Class getUserClass(Object instance) { - Assert.notNull(instance, "Instance must not be null"); - return getUserClass((Class) instance.getClass()); - } - - public static Class getUserClass(Class clazz) { - if (clazz != null && clazz.getName().contains("$$")) { - Class superClass = clazz.getSuperclass(); - if (superClass != null && !Object.class.equals(superClass)) { - return superClass; - } - } - - return clazz; - } - - public static boolean isCacheSafe(Class clazz, ClassLoader classLoader) { - Assert.notNull(clazz, "Class must not be null"); - ClassLoader target = clazz.getClassLoader(); - if (target == null) { - return false; - } else { - ClassLoader cur = classLoader; - if (classLoader == target) { - return true; - } else { - do { - if (cur == null) { - return false; - } - - cur = cur.getParent(); - } while (cur != target); - - return true; - } - } - } - - public static String getShortName(String className) { - Assert.hasLength(className, "Class name must not be empty"); - int lastDotIndex = className.lastIndexOf(46); - int nameEndIndex = className.indexOf("$$"); - if (nameEndIndex == -1) { - nameEndIndex = className.length(); - } - - String shortName = className.substring(lastDotIndex + 1, nameEndIndex); - shortName = shortName.replace('$', '.'); - return shortName; - } - - public static String getShortName(Class clazz) { - return getShortName(getQualifiedName(clazz)); - } - - public static String getShortNameAsProperty(Class clazz) { - String shortName = getShortName((Class) clazz); - int dotIndex = shortName.lastIndexOf(46); - shortName = dotIndex != -1 ? shortName.substring(dotIndex + 1) : shortName; - return Introspector.decapitalize(shortName); - } - - public static String getClassFileName(Class clazz) { - Assert.notNull(clazz, "Class must not be null"); - String className = clazz.getName(); - int lastDotIndex = className.lastIndexOf(46); - return className.substring(lastDotIndex + 1) + ".class"; - } - - public static String getPackageName(Class clazz) { - Assert.notNull(clazz, "Class must not be null"); - return getPackageName(clazz.getName()); - } - - public static String getPackageName(String fqClassName) { - Assert.notNull(fqClassName, "Class name must not be null"); - int lastDotIndex = fqClassName.lastIndexOf(46); - return lastDotIndex != -1 ? fqClassName.substring(0, lastDotIndex) : ""; - } - - public static String getQualifiedName(Class clazz) { - Assert.notNull(clazz, "Class must not be null"); - return clazz.isArray() ? getQualifiedNameForArray(clazz) : clazz.getName(); - } - - private static String getQualifiedNameForArray(Class clazz) { - StringBuilder result = new StringBuilder(); - - while (clazz.isArray()) { - clazz = clazz.getComponentType(); - result.append("[]"); - } - - result.insert(0, clazz.getName()); - return result.toString(); - } - - public static String getQualifiedMethodName(Method method) { - Assert.notNull(method, "Method must not be null"); - return method.getDeclaringClass().getName() + "." + method.getName(); - } - - public static String getDescriptiveType(Object value) { - if (value == null) { - return null; - } else { - Class clazz = value.getClass(); - if (Proxy.isProxyClass(clazz)) { - StringBuilder result = new StringBuilder(clazz.getName()); - result.append(" implementing "); - Class[] ifcs = clazz.getInterfaces(); - - for (int i = 0; i < ifcs.length; ++i) { - result.append(ifcs[i].getName()); - if (i < ifcs.length - 1) { - result.append(','); - } - } - - return result.toString(); - } else { - return clazz.isArray() ? getQualifiedNameForArray(clazz) : clazz.getName(); - } - } - } - - public static boolean matchesTypeName(Class clazz, String typeName) { - return typeName != null && (typeName.equals(clazz.getName()) || typeName.equals(clazz.getSimpleName()) - || clazz.isArray() && typeName.equals(getQualifiedNameForArray(clazz))); - } - - public static boolean hasConstructor(Class clazz, Class... paramTypes) { - return getConstructorIfAvailable(clazz, paramTypes) != null; - } - - public static Constructor getConstructorIfAvailable(Class clazz, Class... paramTypes) { - Assert.notNull(clazz, "Class must not be null"); - - try { - return clazz.getConstructor(paramTypes); - } catch (NoSuchMethodException var3) { - return null; - } - } - - public static boolean hasMethod(Class clazz, String methodName, Class... paramTypes) { - return getMethodIfAvailable(clazz, methodName, paramTypes) != null; - } - - public static Method getMethod(Class clazz, String methodName, Class... paramTypes) { - Assert.notNull(clazz, "Class must not be null"); - Assert.notNull(methodName, "Method name must not be null"); - if (paramTypes != null) { - try { - return clazz.getMethod(methodName, paramTypes); - } catch (NoSuchMethodException var9) { - throw new IllegalStateException("Expected method not found: " + var9); - } - } else { - HashSet candidates = new HashSet(1); - Method[] methods = clazz.getMethods(); - Method[] arr$ = methods; - int len$ = methods.length; - - for (int i$ = 0; i$ < len$; ++i$) { - Method method = arr$[i$]; - if (methodName.equals(method.getName())) { - candidates.add(method); - } - } - - if (candidates.size() == 1) { - return (Method) candidates.iterator().next(); - } else if (candidates.isEmpty()) { - throw new IllegalStateException("Expected method not found: " + clazz + "." + methodName); - } else { - throw new IllegalStateException("No unique method found: " + clazz + "." + methodName); - } - } - } - - public static Method getMethodIfAvailable(Class clazz, String methodName, Class... paramTypes) { - Assert.notNull(clazz, "Class must not be null"); - Assert.notNull(methodName, "Method name must not be null"); - if (paramTypes != null) { - try { - return clazz.getMethod(methodName, paramTypes); - } catch (NoSuchMethodException var9) { - return null; - } - } else { - HashSet candidates = new HashSet(1); - Method[] methods = clazz.getMethods(); - Method[] arr$ = methods; - int len$ = methods.length; - - for (int i$ = 0; i$ < len$; ++i$) { - Method method = arr$[i$]; - if (methodName.equals(method.getName())) { - candidates.add(method); - } - } - - if (candidates.size() == 1) { - return (Method) candidates.iterator().next(); - } else { - return null; - } - } - } - - public static int getMethodCountForName(Class clazz, String methodName) { - Assert.notNull(clazz, "Class must not be null"); - Assert.notNull(methodName, "Method name must not be null"); - int count = 0; - Method[] declaredMethods = clazz.getDeclaredMethods(); - Method[] ifcs = declaredMethods; - int arr$ = declaredMethods.length; - - int len$; - for (len$ = 0; len$ < arr$; ++len$) { - Method i$ = ifcs[len$]; - if (methodName.equals(i$.getName())) { - ++count; - } - } - - Class[] var9 = clazz.getInterfaces(); - Class[] var10 = var9; - len$ = var9.length; - - for (int var11 = 0; var11 < len$; ++var11) { - Class ifc = var10[var11]; - count += getMethodCountForName(ifc, methodName); - } - - if (clazz.getSuperclass() != null) { - count += getMethodCountForName(clazz.getSuperclass(), methodName); - } - - return count; - } - - public static boolean hasAtLeastOneMethodWithName(Class clazz, String methodName) { - Assert.notNull(clazz, "Class must not be null"); - Assert.notNull(methodName, "Method name must not be null"); - Method[] declaredMethods = clazz.getDeclaredMethods(); - Method[] ifcs = declaredMethods; - int arr$ = declaredMethods.length; - - int len$; - for (len$ = 0; len$ < arr$; ++len$) { - Method i$ = ifcs[len$]; - if (i$.getName().equals(methodName)) { - return true; - } - } - - Class[] var8 = clazz.getInterfaces(); - Class[] var9 = var8; - len$ = var8.length; - - for (int var10 = 0; var10 < len$; ++var10) { - Class ifc = var9[var10]; - if (hasAtLeastOneMethodWithName(ifc, methodName)) { - return true; - } - } - - return clazz.getSuperclass() != null && hasAtLeastOneMethodWithName(clazz.getSuperclass(), methodName); - } - - public static Method getMostSpecificMethod(Method method, Class targetClass) { - if (method != null && isOverridable(method, targetClass) && targetClass != null - && !targetClass.equals(method.getDeclaringClass())) { - try { - if (Modifier.isPublic(method.getModifiers())) { - try { - return targetClass.getMethod(method.getName(), method.getParameterTypes()); - } catch (NoSuchMethodException var3) { - return method; - } - } - - Method ex = ReflectionUtils.findMethod(targetClass, method.getName(), method.getParameterTypes()); - return ex != null ? ex : method; - } catch (AccessControlException var4) { - ; - } - } - - return method; - } - - private static boolean isOverridable(Method method, Class targetClass) { - return Modifier.isPrivate(method.getModifiers()) ? false - : (!Modifier.isPublic(method.getModifiers()) && !Modifier.isProtected(method.getModifiers()) - ? getPackageName((Class) method.getDeclaringClass()) - .equals(getPackageName(targetClass)) - : true); - } - - public static Method getStaticMethod(Class clazz, String methodName, Class... args) { - Assert.notNull(clazz, "Class must not be null"); - Assert.notNull(methodName, "Method name must not be null"); - - try { - Method ex = clazz.getMethod(methodName, args); - return Modifier.isStatic(ex.getModifiers()) ? ex : null; - } catch (NoSuchMethodException var4) { - return null; - } - } - - public static boolean isPrimitiveWrapper(Class clazz) { - Assert.notNull(clazz, "Class must not be null"); - return primitiveWrapperTypeMap.containsKey(clazz); - } - - public static boolean isPrimitiveOrWrapper(Class clazz) { - Assert.notNull(clazz, "Class must not be null"); - return clazz.isPrimitive() || isPrimitiveWrapper(clazz); - } - - public static boolean isPrimitiveArray(Class clazz) { - Assert.notNull(clazz, "Class must not be null"); - return clazz.isArray() && clazz.getComponentType().isPrimitive(); - } - - public static boolean isPrimitiveWrapperArray(Class clazz) { - Assert.notNull(clazz, "Class must not be null"); - return clazz.isArray() && isPrimitiveWrapper(clazz.getComponentType()); - } - - public static Class resolvePrimitiveIfNecessary(Class clazz) { - Assert.notNull(clazz, "Class must not be null"); - return clazz.isPrimitive() && clazz != Void.TYPE ? (Class) primitiveTypeToWrapperMap.get(clazz) : clazz; - } - - public static boolean isAssignable(Class lhsType, Class rhsType) { - Assert.notNull(lhsType, "Left-hand side opType must not be null"); - Assert.notNull(rhsType, "Right-hand side opType must not be null"); - if (lhsType.isAssignableFrom(rhsType)) { - return true; - } else { - Class resolvedWrapper; - if (lhsType.isPrimitive()) { - resolvedWrapper = primitiveWrapperTypeMap.get(rhsType); - if (resolvedWrapper != null && lhsType.equals(resolvedWrapper)) { - return true; - } - } else { - resolvedWrapper = (Class) primitiveTypeToWrapperMap.get(rhsType); - if (resolvedWrapper != null && lhsType.isAssignableFrom(resolvedWrapper)) { - return true; - } - } - - return false; - } - } - - public static boolean isAssignableValue(Class type, Object value) { - Assert.notNull(type, "Type must not be null"); - return value != null ? isAssignable(type, value.getClass()) : !type.isPrimitive(); - } - - public static String convertResourcePathToClassName(String resourcePath) { - Assert.notNull(resourcePath, "Resource path must not be null"); - return resourcePath.replace('/', '.'); - } - - public static String convertClassNameToResourcePath(String className) { - Assert.notNull(className, "Class name must not be null"); - return className.replace('.', '/'); - } - - public static String addResourcePathToPackagePath(Class clazz, String resourceName) { - Assert.notNull(resourceName, "Resource name must not be null"); - return !resourceName.startsWith("/") ? classPackageAsResourcePath(clazz) + "/" + resourceName - : classPackageAsResourcePath(clazz) + resourceName; - } - - public static String classPackageAsResourcePath(Class clazz) { - if (clazz == null) { - return ""; - } else { - String className = clazz.getName(); - int packageEndIndex = className.lastIndexOf(46); - if (packageEndIndex == -1) { - return ""; - } else { - String packageName = className.substring(0, packageEndIndex); - return packageName.replace('.', '/'); - } - } - } - - public static String classNamesToString(Class... classes) { - return classNamesToString((Collection) Arrays.asList(classes)); - } - - public static String classNamesToString(Collection classes) { - if (CollectionUtils.isEmpty(classes)) { - return "[]"; - } else { - StringBuilder sb = new StringBuilder("["); - Iterator it = classes.iterator(); - - while (it.hasNext()) { - Class clazz = (Class) it.next(); - sb.append(clazz.getName()); - if (it.hasNext()) { - sb.append(", "); - } - } - - sb.append("]"); - return sb.toString(); - } - } - - public static Class[] toClassArray(Collection> collection) { - return collection == null ? null : collection.toArray(new Class[collection.size()]); - } - - public static Class[] getAllInterfaces(Object instance) { - Assert.notNull(instance, "Instance must not be null"); - return getAllInterfacesForClass(instance.getClass()); - } - - public static Class[] getAllInterfacesForClass(Class clazz) { - return getAllInterfacesForClass(clazz, null); - } - - public static Class[] getAllInterfacesForClass(Class clazz, ClassLoader classLoader) { - Set ifcs = getAllInterfacesForClassAsSet(clazz, classLoader); - return (Class[]) ifcs.toArray(new Class[ifcs.size()]); - } - - public static Set getAllInterfacesAsSet(Object instance) { - Assert.notNull(instance, "Instance must not be null"); - return getAllInterfacesForClassAsSet(instance.getClass()); - } - - public static Set getAllInterfacesForClassAsSet(Class clazz) { - return getAllInterfacesForClassAsSet(clazz, null); - } - - public static Set getAllInterfacesForClassAsSet(Class clazz, ClassLoader classLoader) { - Assert.notNull(clazz, "Class must not be null"); - if (clazz.isInterface() && isVisible(clazz, classLoader)) { - return Collections.singleton(clazz); - } else { - LinkedHashSet interfaces; - for (interfaces = new LinkedHashSet(); clazz != null; clazz = clazz.getSuperclass()) { - Class[] ifcs = clazz.getInterfaces(); - Class[] arr$ = ifcs; - int len$ = ifcs.length; - - for (int i$ = 0; i$ < len$; ++i$) { - Class ifc = arr$[i$]; - interfaces.addAll(getAllInterfacesForClassAsSet(ifc, classLoader)); - } - } - - return interfaces; - } - } - - public static Class createCompositeInterface(Class[] interfaces, ClassLoader classLoader) { - Assert.notEmpty(interfaces, "Interfaces must not be empty"); - Assert.notNull(classLoader, "ClassLoader must not be null"); - return Proxy.getProxyClass(classLoader, interfaces); - } - - public static boolean isVisible(Class clazz, ClassLoader classLoader) { - if (classLoader == null) { - return true; - } else { - try { - Class ex = classLoader.loadClass(clazz.getName()); - return clazz == ex; - } catch (ClassNotFoundException var3) { - return false; - } - } - } - - public static boolean isCglibProxy(Object object) { - return isCglibProxyClass(object.getClass()); - } - - public static boolean isCglibProxyClass(Class clazz) { - return clazz != null && isCglibProxyClassName(clazz.getName()); - } - - public static boolean isCglibProxyClassName(String className) { - return className != null && className.contains("$$"); - } - - static { - primitiveWrapperTypeMap.put(Boolean.class, Boolean.TYPE); - primitiveWrapperTypeMap.put(Byte.class, Byte.TYPE); - primitiveWrapperTypeMap.put(Character.class, Character.TYPE); - primitiveWrapperTypeMap.put(Double.class, Double.TYPE); - primitiveWrapperTypeMap.put(Float.class, Float.TYPE); - primitiveWrapperTypeMap.put(Integer.class, Integer.TYPE); - primitiveWrapperTypeMap.put(Long.class, Long.TYPE); - primitiveWrapperTypeMap.put(Short.class, Short.TYPE); - Iterator, Class>> primitiveTypes = primitiveWrapperTypeMap.entrySet().iterator(); - - while (primitiveTypes.hasNext()) { - Map.Entry i$ = primitiveTypes.next(); - primitiveTypeToWrapperMap.put(i$.getValue(), i$.getKey()); - registerCommonClasses(new Class[] {(Class) i$.getKey()}); - } - - HashSet primitiveTypes1 = new HashSet(32); - primitiveTypes1.addAll(primitiveWrapperTypeMap.values()); - primitiveTypes1.addAll(Arrays.asList(new Class[] {boolean[].class, byte[].class, char[].class, double[].class, - float[].class, int[].class, long[].class, short[].class})); - primitiveTypes1.add(Void.TYPE); - Iterator i$1 = primitiveTypes1.iterator(); - - while (i$1.hasNext()) { - Class primitiveType = (Class) i$1.next(); - primitiveTypeNameMap.put(primitiveType.getName(), primitiveType); - } - - registerCommonClasses(new Class[] {Boolean[].class, Byte[].class, Character[].class, Double[].class, - Float[].class, Integer[].class, Long[].class, Short[].class}); - registerCommonClasses(new Class[] {Number.class, Number[].class, String.class, String[].class, Object.class, - Object[].class, Class.class, Class[].class}); - registerCommonClasses(new Class[] {Throwable.class, Exception.class, RuntimeException.class, Error.class, - StackTraceElement.class, StackTraceElement[].class}); - } -} diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ReflectionUtils.java b/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ReflectionUtils.java index 4c78819b9..0371e24dd 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ReflectionUtils.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ReflectionUtils.java @@ -21,6 +21,7 @@ import java.sql.SQLException; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; +import java.util.Objects; import java.util.regex.Pattern; public abstract class ReflectionUtils { @@ -398,6 +399,40 @@ public abstract class ReflectionUtils { } } + /** + * Create a new instance of the specified {@link Class} by invoking + * the constructor whose argument list matches the types of the supplied + * arguments. + * + *

Provided class must have a public constructor.

+ * + * @param clazz the class to instantiate; never {@code null} + * @param args the arguments to pass to the constructor, none of which may + * be {@code null} + * @return the new instance; never {@code null} + */ + public static T newInstance(Class clazz, Object... args) { + Objects.requireNonNull(clazz, "Class must not be null"); + Objects.requireNonNull(args, "Argument array must not be null"); + if (Arrays.asList(args).contains(null)) { + throw new RuntimeException("Individual arguments must not be null"); + } + + try { + Class[] parameterTypes = Arrays.stream(args).map(Object::getClass).toArray(Class[]::new); + Constructor constructor = clazz.getDeclaredConstructor(parameterTypes); + + if (!Modifier.isPublic(constructor.getModifiers())) { + throw new IllegalArgumentException(String.format( + "Class [%s] must have public constructor in order to be instantiated.", clazz.getName())); + } + + return constructor.newInstance(args); + } catch (Throwable instantiationException) { + throw new RuntimeException(instantiationException); + } + } + public interface FieldFilter { boolean matches(Field var1); } diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ResourceUtils.java b/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ResourceUtils.java index 198da03cf..83d420008 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ResourceUtils.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/common/io/ResourceUtils.java @@ -16,9 +16,12 @@ package org.nd4j.common.io; +import org.nd4j.common.config.ND4JClassLoading; + import java.io.File; import java.io.FileNotFoundException; import java.net.*; +import java.util.Objects; public abstract class ResourceUtils { @@ -54,7 +57,7 @@ public abstract class ResourceUtils { Assert.notNull(resourceLocation, "Resource location must not be null"); if (resourceLocation.startsWith("classpath:")) { String ex = resourceLocation.substring("classpath:".length()); - URL ex2 = ClassUtils.getDefaultClassLoader().getResource(ex); + URL ex2 = ND4JClassLoading.getNd4jClassloader().getResource(ex); if (ex2 == null) { String description = "class path resource [" + ex + "]"; throw new FileNotFoundException(description + " cannot be resolved to URL because it does not exist"); @@ -80,7 +83,7 @@ public abstract class ResourceUtils { if (resourceLocation.startsWith("classpath:")) { String ex = resourceLocation.substring("classpath:".length()); String description = "class path resource [" + ex + "]"; - URL url = ClassUtils.getDefaultClassLoader().getResource(ex); + URL url = ND4JClassLoading.getNd4jClassloader().getResource(ex); if (url == null) { throw new FileNotFoundException(description + " cannot be resolved to absolute file path " + "because it does not reside in the file system"); @@ -170,4 +173,17 @@ public abstract class ResourceUtils { public static void useCachesIfNecessary(URLConnection con) { con.setUseCaches(con.getClass().getSimpleName().startsWith("JNLP")); } + + public static String classPackageAsResourcePath(Class clazz) { + Objects.requireNonNull(clazz); + + String className = clazz.getName(); + int packageEndIndex = className.lastIndexOf(46); + if (packageEndIndex == -1) { + return ""; + } else { + String packageName = className.substring(0, packageEndIndex); + return packageName.replace('.', '/'); + } + } } diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/common/resources/Resources.java b/nd4j/nd4j-common/src/main/java/org/nd4j/common/resources/Resources.java index 1ba4a842e..a58cca96c 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/common/resources/Resources.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/common/resources/Resources.java @@ -2,6 +2,7 @@ package org.nd4j.common.resources; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.common.resources.strumpf.StrumpfResolver; import java.io.File; @@ -24,15 +25,12 @@ public class Resources { protected final List resolvers; protected Resources() { - - ServiceLoader loader = ServiceLoader.load(Resolver.class); - Iterator iter = loader.iterator(); + ServiceLoader loader = ND4JClassLoading.loadService(Resolver.class); resolvers = new ArrayList<>(); resolvers.add(new StrumpfResolver()); - while (iter.hasNext()) { - Resolver r = iter.next(); - resolvers.add(r); + for (Resolver resolver : loader) { + resolvers.add(resolver); } //Sort resolvers by priority: check resolvers with lower numbers first @@ -42,8 +40,6 @@ public class Resources { return Integer.compare(r1.priority(), r2.priority()); } }); - - } /** diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/common/util/ReflectionUtils.java b/nd4j/nd4j-common/src/main/java/org/nd4j/common/util/ReflectionUtils.java deleted file mode 100644 index b8870d77f..000000000 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/common/util/ReflectionUtils.java +++ /dev/null @@ -1,122 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.common.util; - - -import java.io.PrintWriter; -import java.lang.management.ManagementFactory; -import java.lang.management.ThreadInfo; -import java.lang.management.ThreadMXBean; -import java.lang.reflect.Constructor; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -/** - * General reflection utils - */ -public class ReflectionUtils { - - /** - * Cache of constructors for each class. Pins the classes so they - * can't be garbage collected until ReflectionUtils can be collected. - */ - protected static final Map, Constructor> CONSTRUCTOR_CACHE = new ConcurrentHashMap<>(); - - protected ReflectionUtils() {} - - - - static private ThreadMXBean threadBean = ManagementFactory.getThreadMXBean(); - - public static void setContentionTracing(boolean val) { - threadBean.setThreadContentionMonitoringEnabled(val); - } - - private static String getTaskName(long id, String name) { - if (name == null) { - return Long.toString(id); - } - return id + " (" + name + ")"; - } - - /** - * Print all of the thread's information and stack traces. - * - * @param stream the stream to - * @param title a string title for the stack trace - */ - public static void printThreadInfo(PrintWriter stream, String title) { - final int STACK_DEPTH = 20; - boolean contention = threadBean.isThreadContentionMonitoringEnabled(); - long[] threadIds = threadBean.getAllThreadIds(); - stream.println("Process Thread Dump: " + title); - stream.println(threadIds.length + " active threads"); - for (long tid : threadIds) { - ThreadInfo info = threadBean.getThreadInfo(tid, STACK_DEPTH); - if (info == null) { - stream.println(" Inactive"); - continue; - } - stream.println("Thread " + getTaskName(info.getThreadId(), info.getThreadName()) + ":"); - Thread.State state = info.getThreadState(); - stream.println(" State: " + state); - stream.println(" Blocked count: " + info.getBlockedCount()); - stream.println(" Waited count: " + info.getWaitedCount()); - if (contention) { - stream.println(" Blocked time: " + info.getBlockedTime()); - stream.println(" Waited time: " + info.getWaitedTime()); - } - if (state == Thread.State.WAITING) { - stream.println(" Waiting on " + info.getLockName()); - } else if (state == Thread.State.BLOCKED) { - stream.println(" Blocked on " + info.getLockName()); - stream.println(" Blocked by " + getTaskName(info.getLockOwnerId(), info.getLockOwnerName())); - } - stream.println(" Stack:"); - for (StackTraceElement frame : info.getStackTrace()) { - stream.println(" " + frame.toString()); - } - } - stream.flush(); - } - - private static long previousLogTime = 0; - - - /** - * Return the correctly-typed {@link Class} of the given object. - * - * @param o object whose correctly-typed Class is to be obtained - * @return the correctly typed Class of the given object. - */ - @SuppressWarnings("unchecked") - public static Class getClass(T o) { - return (Class) o.getClass(); - } - - // methods to support testing - static void clearCache() { - CONSTRUCTOR_CACHE.clear(); - } - - static int getCacheSize() { - return CONSTRUCTOR_CACHE.size(); - } - - - -} diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java b/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java index 4fa9cc794..24a0b1ee8 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java @@ -17,11 +17,13 @@ package org.nd4j.jdbc.driverfinder; import lombok.extern.slf4j.Slf4j; +import org.nd4j.common.config.ND4JClassLoading; import java.io.IOException; import java.io.InputStream; import java.sql.Driver; import java.util.HashSet; +import java.util.Objects; import java.util.Properties; import java.util.ServiceLoader; import java.util.Set; @@ -45,22 +47,19 @@ public class DriverFinder { discoverDriverClazz(); try { driver = clazz.newInstance(); - } catch (InstantiationException e) { - log.error("",e); - } catch (IllegalAccessException e) { + } catch (InstantiationException | IllegalAccessException e) { log.error("",e); } } return driver; } - private static void discoverDriverClazz() { //All JDBC4 compliant drivers support ServiceLoader mechanism for discovery - https://stackoverflow.com/a/18297412 - ServiceLoader drivers = ServiceLoader.load(Driver.class); + ServiceLoader drivers = ND4JClassLoading.loadService(Driver.class); Set> driverClasses = new HashSet<>(); - for(Driver d : drivers){ - driverClasses.add(d.getClass()); + for(Driver driver : drivers){ + driverClasses.add(driver.getClass()); } if(driverClasses.isEmpty()){ @@ -79,16 +78,13 @@ public class DriverFinder { throw new RuntimeException(e); } - String clazz = props.getProperty(JDBC_KEY); - if (clazz == null) - throw new IllegalStateException("Unable to find jdbc driver. Please specify a " - + ND4j_JDBC_PROPERTIES + " with the key " + JDBC_KEY); - try { - DriverFinder.clazz = (Class) Class.forName(clazz); - } catch (ClassNotFoundException e) { - throw new IllegalStateException("Unable to find jdbc driver. Please specify a " - + ND4j_JDBC_PROPERTIES + " with the key " + JDBC_KEY); - } + String jdbcKeyClassName = props.getProperty(JDBC_KEY); + Objects.requireNonNull(jdbcKeyClassName, "Unable to find jdbc driver. Please specify a " + + ND4j_JDBC_PROPERTIES + " with the key " + JDBC_KEY); + + DriverFinder.clazz = ND4JClassLoading.loadClassByName(jdbcKeyClassName); + Objects.requireNonNull(DriverFinder.clazz, "Unable to find jdbc driver. Please specify a " + + ND4j_JDBC_PROPERTIES + " with the key " + JDBC_KEY); } } } diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java index b019a2748..96c9ac5e4 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java @@ -21,6 +21,7 @@ import org.hsqldb.jdbc.JDBCDataSource; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -46,13 +47,12 @@ public class HSqlLoaderTest extends BaseND4JTest { @BeforeClass public static void init() throws Exception { hsqlLoader = new HsqlLoader(dataSource(),JDBC_URL,TABLE_NAME,ID_COLUMN_NAME,COLUMN_NAME); - Class.forName("org.hsqldb.jdbc.JDBCDriver"); + ND4JClassLoading.loadClassByName("org.hsqldb.jdbc.JDBCDriver"); // initialize database initDatabase(); } - public static DataSource dataSource() { if (dataSource != null) return dataSource; @@ -65,8 +65,6 @@ public class HSqlLoaderTest extends BaseND4JTest { return dataSource; } - - @AfterClass public static void destroy() throws SQLException { try (Connection connection = getConnection(); Statement statement = connection.createStatement()) { @@ -131,6 +129,4 @@ public class HSqlLoaderTest extends BaseND4JTest { } - - } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/VoidMessage.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/VoidMessage.java index 1108ecfc8..2d2e3a509 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/VoidMessage.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/VoidMessage.java @@ -18,10 +18,11 @@ package org.nd4j.parameterserver.distributed.messages; import org.agrona.concurrent.UnsafeBuffer; import org.apache.commons.io.input.ClassLoaderObjectInputStream; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import org.nd4j.parameterserver.distributed.enums.NodeRole; -import org.nd4j.parameterserver.distributed.logic.completion.Clipboard; import org.nd4j.parameterserver.distributed.logic.Storage; +import org.nd4j.parameterserver.distributed.logic.completion.Clipboard; import org.nd4j.parameterserver.distributed.training.TrainingDriver; import org.nd4j.parameterserver.distributed.transport.Transport; @@ -51,17 +52,16 @@ public interface VoidMessage extends Serializable { UnsafeBuffer asUnsafeBuffer(); + @SuppressWarnings("unchecked") static T fromBytes(byte[] array) { - try { - ObjectInputStream in = new ClassLoaderObjectInputStream(Thread.currentThread().getContextClassLoader(), - new ByteArrayInputStream(array)); + ClassLoader classloader = ND4JClassLoading.getNd4jClassloader(); - T result = (T) in.readObject(); - return result; - } catch (Exception e) { - throw new RuntimeException(e); + try (ByteArrayInputStream bis = new ByteArrayInputStream(array); + ObjectInputStream ois = new ClassLoaderObjectInputStream(classloader, bis)) { + return (T) ois.readObject(); + } catch (Exception objectReadException) { + throw new RuntimeException(objectReadException); } - //return SerializationUtils.deserialize(array); } /** diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/training/TrainerProvider.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/training/TrainerProvider.java index 96be29c97..0e3b14619 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/training/TrainerProvider.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/training/TrainerProvider.java @@ -17,6 +17,7 @@ package org.nd4j.parameterserver.distributed.training; import lombok.NonNull; +import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import org.nd4j.parameterserver.distributed.logic.Storage; @@ -52,13 +53,14 @@ public class TrainerProvider { } protected void loadProviders(){ - ServiceLoader serviceLoader = ServiceLoader.load(TrainingDriver.class); - for(TrainingDriver d : serviceLoader){ - trainers.put(d.targetMessageClass(), d); + ServiceLoader serviceLoader = ND4JClassLoading.loadService(TrainingDriver.class); + for (TrainingDriver trainingDriver : serviceLoader){ + trainers.put(trainingDriver.targetMessageClass(), trainingDriver); } - if (trainers.size() < 1) + if (trainers.isEmpty()) { throw new ND4JIllegalStateException("No TrainingDrivers were found via ServiceLoader mechanism"); + } } public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport, @@ -73,8 +75,6 @@ public class TrainerProvider { } } - - @SuppressWarnings("unchecked") protected TrainingDriver getTrainer(T message) { TrainingDriver driver = trainers.get(message.getClass().getSimpleName()); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java index 2641369de..0ff344793 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java @@ -21,6 +21,8 @@ import com.beust.jcommander.Parameter; import com.beust.jcommander.ParameterException; import com.beust.jcommander.Parameters; import lombok.extern.slf4j.Slf4j; +import org.nd4j.common.config.ND4JClassLoading; +import org.nd4j.common.io.ReflectionUtils; import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.shade.jackson.databind.ObjectMapper; @@ -290,12 +292,10 @@ public class ParameterServerSubscriber implements AutoCloseable { case TIME_DELAYED: break; case CUSTOM: - try { - updater = (ParameterServerUpdater) Class.forName(System.getProperty(CUSTOM_UPDATE_TYPE)) - .newInstance(); - } catch (Exception e) { - throw new RuntimeException(e); - } + String parameterServerUpdateType = System.getProperty(CUSTOM_UPDATE_TYPE); + Class updaterClass = ND4JClassLoading + .loadClassByName(parameterServerUpdateType); + updater = ReflectionUtils.newInstance(updaterClass); break; default: throw new IllegalStateException("Illegal opType of updater");