FEATURE: possibility to define global Classload for ND4J (#8972)
Signed-off-by: hosuaby <alexei.klenin@gmail.com>master
parent
f9aebec79e
commit
881a672fa1
|
@ -25,14 +25,13 @@ import org.datavec.api.io.serializers.SerializationFactory;
|
||||||
import org.datavec.api.io.serializers.Serializer;
|
import org.datavec.api.io.serializers.Serializer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.lang.reflect.Constructor;
|
|
||||||
import java.lang.reflect.Method;
|
import java.lang.reflect.Method;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @deprecated Use {@link org.nd4j.common.util.ReflectionUtils}
|
* @deprecated Use {@link org.nd4j.common.io.ReflectionUtils}
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public class ReflectionUtils extends org.nd4j.common.util.ReflectionUtils {
|
public class ReflectionUtils {
|
||||||
|
|
||||||
private static final Class<?>[] EMPTY_ARRAY = new Class[] {};
|
private static final Class<?>[] EMPTY_ARRAY = new Class[] {};
|
||||||
private static SerializationFactory serialFactory = null;
|
private static SerializationFactory serialFactory = null;
|
||||||
|
@ -48,18 +47,7 @@ public class ReflectionUtils extends org.nd4j.common.util.ReflectionUtils {
|
||||||
*/
|
*/
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
public static <T> T newInstance(Class<T> theClass, Configuration conf) {
|
public static <T> T newInstance(Class<T> theClass, Configuration conf) {
|
||||||
T result;
|
T result = org.nd4j.common.io.ReflectionUtils.newInstance(theClass);
|
||||||
try {
|
|
||||||
Constructor<T> meth = (Constructor<T>) CONSTRUCTOR_CACHE.get(theClass);
|
|
||||||
if (meth == null) {
|
|
||||||
meth = theClass.getDeclaredConstructor(EMPTY_ARRAY);
|
|
||||||
meth.setAccessible(true);
|
|
||||||
CONSTRUCTOR_CACHE.put(theClass, meth);
|
|
||||||
}
|
|
||||||
result = meth.newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
setConf(result, conf);
|
setConf(result, conf);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.validation;
|
package org.nd4j.autodiff.validation;
|
||||||
|
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.linalg.api.ops.custom.*;
|
import org.nd4j.linalg.api.ops.custom.*;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
||||||
|
@ -613,14 +614,8 @@ public class OpValidation {
|
||||||
allOps = new ArrayList<>(gradCheckCoverageCountPerClass.keySet());
|
allOps = new ArrayList<>(gradCheckCoverageCountPerClass.keySet());
|
||||||
for (ClassPath.ClassInfo c : info) {
|
for (ClassPath.ClassInfo c : info) {
|
||||||
//Load method: Loads (but doesn't link or initialize) the class.
|
//Load method: Loads (but doesn't link or initialize) the class.
|
||||||
Class<?> clazz;
|
Class<?> clazz = ND4JClassLoading.loadClassByName(c.getName());
|
||||||
try {
|
Objects.requireNonNull(clazz);
|
||||||
clazz = Class.forName(c.getName());
|
|
||||||
} catch (ClassNotFoundException e) {
|
|
||||||
//Should never happen as this was found on the classpath
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || !DifferentialFunction.class.isAssignableFrom(clazz))
|
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || !DifferentialFunction.class.isAssignableFrom(clazz))
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.compression;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -50,7 +51,7 @@ public class BasicNDArrayCompressor {
|
||||||
*/
|
*/
|
||||||
codecs = new ConcurrentHashMap<>();
|
codecs = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
ServiceLoader<NDArrayCompressor> loader = ServiceLoader.load(NDArrayCompressor.class);
|
ServiceLoader<NDArrayCompressor> loader = ND4JClassLoading.loadService(NDArrayCompressor.class);
|
||||||
for (NDArrayCompressor compressor : loader) {
|
for (NDArrayCompressor compressor : loader) {
|
||||||
codecs.put(compressor.getDescriptor().toUpperCase(), compressor);
|
codecs.put(compressor.getDescriptor().toUpperCase(), compressor);
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.dataset.api.preprocessor.serializer;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
|
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
|
@ -215,7 +216,7 @@ public class NormalizerSerializer {
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
* @throws IllegalArgumentException if the data format is invalid
|
* @throws IllegalArgumentException if the data format is invalid
|
||||||
*/
|
*/
|
||||||
private Header parseHeader(InputStream stream) throws IOException, ClassNotFoundException {
|
private Header parseHeader(InputStream stream) throws IOException {
|
||||||
DataInputStream dis = new DataInputStream(stream);
|
DataInputStream dis = new DataInputStream(stream);
|
||||||
// Check if the stream starts with the expected header
|
// Check if the stream starts with the expected header
|
||||||
String header = dis.readUTF();
|
String header = dis.readUTF();
|
||||||
|
@ -237,8 +238,9 @@ public class NormalizerSerializer {
|
||||||
if (type.equals(NormalizerType.CUSTOM)) {
|
if (type.equals(NormalizerType.CUSTOM)) {
|
||||||
// For custom serializers, the next value is a string with the class opName
|
// For custom serializers, the next value is a string with the class opName
|
||||||
String strategyClassName = dis.readUTF();
|
String strategyClassName = dis.readUTF();
|
||||||
//noinspection unchecked
|
Class<? extends NormalizerSerializerStrategy> strategyClass = ND4JClassLoading
|
||||||
return new Header(type, (Class<? extends NormalizerSerializerStrategy>) Class.forName(strategyClassName));
|
.loadClassByName(strategyClassName);
|
||||||
|
return new Header(type, strategyClass);
|
||||||
} else {
|
} else {
|
||||||
return new Header(type, null);
|
return new Header(type, null);
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.factory;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.linalg.factory.ops.*;
|
import org.nd4j.linalg.factory.ops.*;
|
||||||
import org.nd4j.shade.guava.primitives.Ints;
|
import org.nd4j.shade.guava.primitives.Ints;
|
||||||
import org.nd4j.shade.guava.primitives.Longs;
|
import org.nd4j.shade.guava.primitives.Longs;
|
||||||
|
@ -5135,37 +5136,36 @@ public class Nd4j {
|
||||||
compressDebug = pp.toBoolean(COMPRESSION_DEBUG);
|
compressDebug = pp.toBoolean(COMPRESSION_DEBUG);
|
||||||
char ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C);
|
char ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C);
|
||||||
|
|
||||||
Class<? extends BasicAffinityManager> affinityManagerClazz = (Class<? extends BasicAffinityManager>) Class
|
Class<? extends BasicAffinityManager> affinityManagerClazz = ND4JClassLoading
|
||||||
.forName(pp.toString(AFFINITY_MANAGER));
|
.loadClassByName(pp.toString(AFFINITY_MANAGER));
|
||||||
affinityManager = affinityManagerClazz.newInstance();
|
affinityManager = affinityManagerClazz.newInstance();
|
||||||
Class<? extends NDArrayFactory> ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName(
|
Class<? extends NDArrayFactory> ndArrayFactoryClazz = ND4JClassLoading
|
||||||
pp.toString(NDARRAY_FACTORY_CLASS));
|
.loadClassByName(pp.toString(NDARRAY_FACTORY_CLASS));
|
||||||
Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
|
Class<? extends ConvolutionInstance> convolutionInstanceClazz = ND4JClassLoading
|
||||||
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
|
.loadClassByName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
|
||||||
String defaultName = pp.toString(DATA_BUFFER_OPS, "org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory");
|
String defaultName = pp.toString(DATA_BUFFER_OPS, "org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory");
|
||||||
Class<? extends DataBufferFactory> dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class
|
Class<? extends DataBufferFactory> dataBufferFactoryClazz = ND4JClassLoading
|
||||||
.forName(pp.toString(DATA_BUFFER_OPS, defaultName));
|
.loadClassByName(pp.toString(DATA_BUFFER_OPS, defaultName));
|
||||||
Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
|
Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = ND4JClassLoading
|
||||||
.forName(pp.toString(SHAPEINFO_PROVIDER));
|
.loadClassByName(pp.toString(SHAPEINFO_PROVIDER));
|
||||||
|
|
||||||
Class<? extends BasicConstantHandler> constantProviderClazz = (Class<? extends BasicConstantHandler>) Class
|
Class<? extends BasicConstantHandler> constantProviderClazz = ND4JClassLoading
|
||||||
.forName(pp.toString(CONSTANT_PROVIDER));
|
.loadClassByName(pp.toString(CONSTANT_PROVIDER));
|
||||||
|
|
||||||
Class<? extends BasicMemoryManager> memoryManagerClazz = (Class<? extends BasicMemoryManager>) Class
|
Class<? extends BasicMemoryManager> memoryManagerClazz = ND4JClassLoading
|
||||||
.forName(pp.toString(MEMORY_MANAGER));
|
.loadClassByName(pp.toString(MEMORY_MANAGER));
|
||||||
|
|
||||||
allowsOrder = backend.allowsOrder();
|
allowsOrder = backend.allowsOrder();
|
||||||
String rand = pp.toString(RANDOM_PROVIDER, DefaultRandom.class.getName());
|
String rand = pp.toString(RANDOM_PROVIDER, DefaultRandom.class.getName());
|
||||||
Class<? extends org.nd4j.linalg.api.rng.Random> randomClazz = (Class<? extends org.nd4j.linalg.api.rng.Random>) Class.forName(rand);
|
Class<? extends org.nd4j.linalg.api.rng.Random> randomClazz = ND4JClassLoading.loadClassByName(rand);
|
||||||
randomFactory = new RandomFactory(randomClazz);
|
randomFactory = new RandomFactory(randomClazz);
|
||||||
|
|
||||||
Class<? extends MemoryWorkspaceManager> workspaceManagerClazz = (Class<? extends MemoryWorkspaceManager>) Class
|
Class<? extends MemoryWorkspaceManager> workspaceManagerClazz = ND4JClassLoading
|
||||||
.forName(pp.toString(WORKSPACE_MANAGER));
|
.loadClassByName(pp.toString(WORKSPACE_MANAGER));
|
||||||
|
|
||||||
Class<? extends BlasWrapper> blasWrapperClazz = (Class<? extends BlasWrapper>) Class
|
Class<? extends BlasWrapper> blasWrapperClazz = ND4JClassLoading.loadClassByName(pp.toString(BLAS_OPS));
|
||||||
.forName(pp.toString(BLAS_OPS));
|
|
||||||
String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName());
|
String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName());
|
||||||
Class<? extends DistributionFactory> distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName);
|
Class<? extends DistributionFactory> distributionFactoryClazz = ND4JClassLoading.loadClassByName(clazzName);
|
||||||
|
|
||||||
|
|
||||||
memoryManager = memoryManagerClazz.newInstance();
|
memoryManager = memoryManagerClazz.newInstance();
|
||||||
|
@ -5173,8 +5173,8 @@ public class Nd4j {
|
||||||
shapeInfoProvider = shapeInfoProviderClazz.newInstance();
|
shapeInfoProvider = shapeInfoProviderClazz.newInstance();
|
||||||
workspaceManager = workspaceManagerClazz.newInstance();
|
workspaceManager = workspaceManagerClazz.newInstance();
|
||||||
|
|
||||||
Class<? extends OpExecutioner> opExecutionerClazz = (Class<? extends OpExecutioner>) Class
|
Class<? extends OpExecutioner> opExecutionerClazz = ND4JClassLoading
|
||||||
.forName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName()));
|
.loadClassByName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName()));
|
||||||
|
|
||||||
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();
|
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();
|
||||||
Constructor c2 = ndArrayFactoryClazz.getConstructor(DataType.class, char.class);
|
Constructor c2 = ndArrayFactoryClazz.getConstructor(DataType.class, char.class);
|
||||||
|
@ -5197,7 +5197,7 @@ public class Nd4j {
|
||||||
OP_EXECUTIONER_INSTANCE.printEnvironmentInformation();
|
OP_EXECUTIONER_INSTANCE.printEnvironmentInformation();
|
||||||
}
|
}
|
||||||
|
|
||||||
val actions = ServiceLoader.load(EnvironmentalAction.class);
|
val actions = ND4JClassLoading.loadService(EnvironmentalAction.class);
|
||||||
val mappedActions = new HashMap<String, EnvironmentalAction>();
|
val mappedActions = new HashMap<String, EnvironmentalAction>();
|
||||||
for (val a: actions) {
|
for (val a: actions) {
|
||||||
if (!mappedActions.containsKey(a.targetVariable()))
|
if (!mappedActions.containsKey(a.targetVariable()))
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.nd4j.linalg.factory;
|
package org.nd4j.linalg.factory;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.common.config.ND4JEnvironmentVars;
|
import org.nd4j.common.config.ND4JEnvironmentVars;
|
||||||
import org.nd4j.common.config.ND4JSystemProperties;
|
import org.nd4j.common.config.ND4JSystemProperties;
|
||||||
import org.nd4j.context.Nd4jContext;
|
import org.nd4j.context.Nd4jContext;
|
||||||
|
@ -25,6 +26,7 @@ import org.nd4j.common.io.Resource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.net.URLClassLoader;
|
||||||
import java.security.PrivilegedActionException;
|
import java.security.PrivilegedActionException;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
@ -156,14 +158,12 @@ public abstract class Nd4jBackend {
|
||||||
String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true");
|
String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true");
|
||||||
boolean logInit = Boolean.parseBoolean(logInitProperty);
|
boolean logInit = Boolean.parseBoolean(logInitProperty);
|
||||||
|
|
||||||
List<Nd4jBackend> backends = new ArrayList<>(1);
|
List<Nd4jBackend> backends = new ArrayList<>();
|
||||||
ServiceLoader<Nd4jBackend> loader = ServiceLoader.load(Nd4jBackend.class);
|
ServiceLoader<Nd4jBackend> loader = ND4JClassLoading.loadService(Nd4jBackend.class);
|
||||||
try {
|
try {
|
||||||
|
for (Nd4jBackend nd4jBackend : loader) {
|
||||||
Iterator<Nd4jBackend> backendIterator = loader.iterator();
|
backends.add(nd4jBackend);
|
||||||
while (backendIterator.hasNext())
|
}
|
||||||
backends.add(backendIterator.next());
|
|
||||||
|
|
||||||
} catch (ServiceConfigurationError serviceError) {
|
} catch (ServiceConfigurationError serviceError) {
|
||||||
// a fatal error due to a syntax or provider construction error.
|
// a fatal error due to a syntax or provider construction error.
|
||||||
// backends mustn't throw an exception during construction.
|
// backends mustn't throw an exception during construction.
|
||||||
|
@ -240,7 +240,7 @@ public abstract class Nd4jBackend {
|
||||||
public static synchronized void loadLibrary(File jar) throws NoAvailableBackendException {
|
public static synchronized void loadLibrary(File jar) throws NoAvailableBackendException {
|
||||||
try {
|
try {
|
||||||
/*We are using reflection here to circumvent encapsulation; addURL is not public*/
|
/*We are using reflection here to circumvent encapsulation; addURL is not public*/
|
||||||
java.net.URLClassLoader loader = (java.net.URLClassLoader) ClassLoader.getSystemClassLoader();
|
java.net.URLClassLoader loader = (URLClassLoader) ND4JClassLoading.getNd4jClassloader();
|
||||||
java.net.URL url = jar.toURI().toURL();
|
java.net.URL url = jar.toURI().toURL();
|
||||||
/*Disallow if already loaded*/
|
/*Disallow if already loaded*/
|
||||||
for (java.net.URL it : java.util.Arrays.asList(loader.getURLs())) {
|
for (java.net.URL it : java.util.Arrays.asList(loader.getURLs())) {
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.serde.json;
|
package org.nd4j.serde.json;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.shade.jackson.core.JsonParser;
|
import org.nd4j.shade.jackson.core.JsonParser;
|
||||||
import org.nd4j.shade.jackson.databind.DeserializationContext;
|
import org.nd4j.shade.jackson.databind.DeserializationContext;
|
||||||
import org.nd4j.shade.jackson.databind.JsonDeserializer;
|
import org.nd4j.shade.jackson.databind.JsonDeserializer;
|
||||||
|
@ -28,6 +29,7 @@ import java.util.ArrayList;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A base deserialization class used to handle deserializing of a specific class given changes from subtype wrapper
|
* A base deserialization class used to handle deserializing of a specific class given changes from subtype wrapper
|
||||||
|
@ -79,13 +81,9 @@ public abstract class BaseLegacyDeserializer<T> extends JsonDeserializer<T> {
|
||||||
+ "\": legacy class mapping with this name is unknown");
|
+ "\": legacy class mapping with this name is unknown");
|
||||||
}
|
}
|
||||||
|
|
||||||
Class<? extends T> lClass;
|
Class<? extends T> lClass = ND4JClassLoading.loadClassByName(layerClass);
|
||||||
try {
|
Objects.requireNonNull(lClass, "Could not find class for deserialization of \"" + name + "\" of type " +
|
||||||
lClass = (Class<? extends T>) Class.forName(layerClass);
|
getDeserializedType() + ": class " + layerClass + " is not on the classpath?");
|
||||||
} catch (Exception e){
|
|
||||||
throw new RuntimeException("Could not find class for deserialization of \"" + name + "\" of type " +
|
|
||||||
getDeserializedType() + ": class " + layerClass + " is not on the classpath?", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
ObjectMapper m = getLegacyJsonMapper();
|
ObjectMapper m = getLegacyJsonMapper();
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,7 @@ import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.lang3.SystemUtils;
|
import org.apache.commons.lang3.SystemUtils;
|
||||||
import org.apache.commons.lang3.exception.ExceptionUtils;
|
import org.apache.commons.lang3.exception.ExceptionUtils;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
|
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -212,22 +213,22 @@ public class SystemInfo {
|
||||||
|
|
||||||
boolean hasGPUs = false;
|
boolean hasGPUs = false;
|
||||||
|
|
||||||
ServiceLoader<GPUInfoProvider> loader = ServiceLoader.load(GPUInfoProvider.class);
|
ServiceLoader<GPUInfoProvider> loader = ND4JClassLoading.loadService(GPUInfoProvider.class);
|
||||||
Iterator<GPUInfoProvider> iter = loader.iterator();
|
Iterator<GPUInfoProvider> iter = loader.iterator();
|
||||||
if(iter.hasNext()){
|
if (iter.hasNext()) {
|
||||||
List<GPUInfo> gpus = iter.next().getGPUs();
|
List<GPUInfo> gpus = iter.next().getGPUs();
|
||||||
|
|
||||||
sb.append(f("Number of GPUs Detected", gpus.size()));
|
sb.append(f("Number of GPUs Detected", gpus.size()));
|
||||||
|
|
||||||
if(!gpus.isEmpty())
|
if (!gpus.isEmpty()) {
|
||||||
hasGPUs = true;
|
hasGPUs = true;
|
||||||
|
}
|
||||||
|
|
||||||
sb.append(String.format(fGpu, "Name", "CC", "Total Memory", "Used Memory", "Free Memory")).append("\n");
|
sb.append(String.format(fGpu, "Name", "CC", "Total Memory", "Used Memory", "Free Memory")).append("\n");
|
||||||
|
|
||||||
for(GPUInfo gpuInfo : gpus){
|
for (GPUInfo gpuInfo : gpus) {
|
||||||
sb.append(gpuInfo).append("\n");
|
sb.append(gpuInfo).append("\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
sb.append("GPU Provider not found (are you missing nd4j-native?)");
|
sb.append("GPU Provider not found (are you missing nd4j-native?)");
|
||||||
}
|
}
|
||||||
|
@ -327,28 +328,24 @@ public class SystemInfo {
|
||||||
|
|
||||||
appendProperty(sb, "Library Path", "java.library.path");
|
appendProperty(sb, "Library Path", "java.library.path");
|
||||||
|
|
||||||
|
|
||||||
//classpath
|
|
||||||
appendHeader(sb, "Classpath");
|
appendHeader(sb, "Classpath");
|
||||||
ClassLoader cl = ClassLoader.getSystemClassLoader();
|
|
||||||
|
|
||||||
URL[] urls = null;
|
URLClassLoader urlClassLoader = null;
|
||||||
try{
|
|
||||||
urls = ((URLClassLoader)cl).getURLs();
|
if (ND4JClassLoading.getNd4jClassloader() instanceof URLClassLoader) {
|
||||||
} catch (ClassCastException e){
|
urlClassLoader = (URLClassLoader) ND4JClassLoading.getNd4jClassloader();
|
||||||
try {
|
} else if (ClassLoader.getSystemClassLoader() instanceof URLClassLoader) {
|
||||||
urls = ((URLClassLoader) SystemInfo.class.getClassLoader()).getURLs();
|
urlClassLoader = (URLClassLoader) ClassLoader.getSystemClassLoader();
|
||||||
} catch (ClassCastException e1){
|
} else if (SystemInfo.class.getClassLoader() instanceof URLClassLoader) {
|
||||||
try{
|
urlClassLoader = (URLClassLoader) SystemInfo.class.getClassLoader();
|
||||||
urls = ((URLClassLoader) (Thread.currentThread().getContextClassLoader())).getURLs();
|
} else if (Thread.currentThread().getContextClassLoader() instanceof URLClassLoader) {
|
||||||
} catch (ClassCastException e2) {
|
urlClassLoader = (URLClassLoader) Thread.currentThread().getContextClassLoader();
|
||||||
sb.append("Can't cast class loader to URLClassLoader\n");
|
} else {
|
||||||
}
|
sb.append("Can't cast class loader to URLClassLoader\n");
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if(urls != null) {
|
if (urlClassLoader != null) {
|
||||||
for (URL url : urls) {
|
for (URL url : urlClassLoader.getURLs()) {
|
||||||
sb.append(url.getFile()).append("\n");
|
sb.append(url.getFile()).append("\n");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -359,7 +356,6 @@ public class SystemInfo {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//launch command
|
//launch command
|
||||||
appendHeader(sb, "Launch Command");
|
appendHeader(sb, "Launch Command");
|
||||||
|
|
||||||
|
|
|
@ -17,14 +17,27 @@
|
||||||
package org.nd4j.versioncheck;
|
package org.nd4j.versioncheck;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.common.config.ND4JSystemProperties;
|
import org.nd4j.common.config.ND4JSystemProperties;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.net.URL;
|
import java.net.URL;
|
||||||
import java.nio.file.*;
|
import java.nio.file.FileSystem;
|
||||||
|
import java.nio.file.FileSystems;
|
||||||
|
import java.nio.file.FileVisitResult;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
|
import java.nio.file.SimpleFileVisitor;
|
||||||
import java.nio.file.attribute.BasicFileAttributes;
|
import java.nio.file.attribute.BasicFileAttributes;
|
||||||
import java.util.*;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Enumeration;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A runtime version check utility that does 2 things:<br>
|
* A runtime version check utility that does 2 things:<br>
|
||||||
|
@ -92,14 +105,14 @@ public class VersionCheck {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(classExists(ND4J_JBLAS_CLASS)) {
|
if(ND4JClassLoading.classPresentOnClasspath(ND4J_JBLAS_CLASS)) {
|
||||||
//nd4j-jblas is ancient and incompatible
|
//nd4j-jblas is ancient and incompatible
|
||||||
log.error("Found incompatible/obsolete backend and version (nd4j-jblas) on classpath. ND4J is unlikely to"
|
log.error("Found incompatible/obsolete backend and version (nd4j-jblas) on classpath. ND4J is unlikely to"
|
||||||
+ " function correctly with nd4j-jblas on the classpath. JVM will now exit.");
|
+ " function correctly with nd4j-jblas on the classpath. JVM will now exit.");
|
||||||
System.exit(1);
|
System.exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(classExists(CANOVA_CLASS)) {
|
if(ND4JClassLoading.classPresentOnClasspath(CANOVA_CLASS)) {
|
||||||
//Canova is ancient and likely to pull in incompatible dependencies
|
//Canova is ancient and likely to pull in incompatible dependencies
|
||||||
log.error("Found incompatible/obsolete library Canova on classpath. ND4J is unlikely to"
|
log.error("Found incompatible/obsolete library Canova on classpath. ND4J is unlikely to"
|
||||||
+ " function correctly with this library on the classpath. JVM will now exit.");
|
+ " function correctly with this library on the classpath. JVM will now exit.");
|
||||||
|
@ -281,13 +294,13 @@ public class VersionCheck {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if(classExists(ND4J_JBLAS_CLASS)){
|
if(ND4JClassLoading.classPresentOnClasspath(ND4J_JBLAS_CLASS)){
|
||||||
//nd4j-jblas is ancient and incompatible
|
//nd4j-jblas is ancient and incompatible
|
||||||
log.error("Found incompatible/obsolete backend and version (nd4j-jblas) on classpath. ND4J is unlikely to"
|
log.error("Found incompatible/obsolete backend and version (nd4j-jblas) on classpath. ND4J is unlikely to"
|
||||||
+ " function correctly with nd4j-jblas on the classpath.");
|
+ " function correctly with nd4j-jblas on the classpath.");
|
||||||
}
|
}
|
||||||
|
|
||||||
if(classExists(CANOVA_CLASS)){
|
if(ND4JClassLoading.classPresentOnClasspath(CANOVA_CLASS)){
|
||||||
//Canova is anchient and likely to pull in incompatible
|
//Canova is anchient and likely to pull in incompatible
|
||||||
log.error("Found incompatible/obsolete library Canova on classpath. ND4J is unlikely to"
|
log.error("Found incompatible/obsolete library Canova on classpath. ND4J is unlikely to"
|
||||||
+ " function correctly with this library on the classpath.");
|
+ " function correctly with this library on the classpath.");
|
||||||
|
@ -296,16 +309,6 @@ public class VersionCheck {
|
||||||
return repState;
|
return repState;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static boolean classExists(String className){
|
|
||||||
try{
|
|
||||||
Class.forName(className);
|
|
||||||
return true;
|
|
||||||
} catch (ClassNotFoundException e ){
|
|
||||||
//OK - not found
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return A string representation of the version information, with the default (GAV) detail level
|
* @return A string representation of the version information, with the default (GAV) detail level
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -19,8 +19,10 @@ package org.nd4j.nativeblas;
|
||||||
import java.util.Properties;
|
import java.util.Properties;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.bytedeco.javacpp.Loader;
|
import org.bytedeco.javacpp.Loader;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.common.config.ND4JEnvironmentVars;
|
import org.nd4j.common.config.ND4JEnvironmentVars;
|
||||||
import org.nd4j.common.config.ND4JSystemProperties;
|
import org.nd4j.common.config.ND4JSystemProperties;
|
||||||
|
import org.nd4j.common.io.ReflectionUtils;
|
||||||
import org.nd4j.context.Nd4jContext;
|
import org.nd4j.context.Nd4jContext;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
|
@ -82,8 +84,10 @@ public class NativeOpsHolder {
|
||||||
Properties props = Nd4jContext.getInstance().getConf();
|
Properties props = Nd4jContext.getInstance().getConf();
|
||||||
|
|
||||||
String name = System.getProperty(Nd4j.NATIVE_OPS, props.get(Nd4j.NATIVE_OPS).toString());
|
String name = System.getProperty(Nd4j.NATIVE_OPS, props.get(Nd4j.NATIVE_OPS).toString());
|
||||||
Class<? extends NativeOps> nativeOpsClazz = Class.forName(name).asSubclass(NativeOps.class);
|
Class<? extends NativeOps> nativeOpsClass = ND4JClassLoading
|
||||||
deviceNativeOps = nativeOpsClazz.newInstance();
|
.loadClassByName(name)
|
||||||
|
.asSubclass(NativeOps.class);
|
||||||
|
deviceNativeOps = ReflectionUtils.newInstance(nativeOpsClass);
|
||||||
|
|
||||||
deviceNativeOps.initializeDevicesAndFunctions();
|
deviceNativeOps.initializeDevicesAndFunctions();
|
||||||
int numThreads;
|
int numThreads;
|
||||||
|
|
|
@ -16,18 +16,18 @@
|
||||||
|
|
||||||
package org.nd4j.linalg;
|
package org.nd4j.linalg;
|
||||||
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.Parameterized;
|
import org.junit.runners.Parameterized;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
|
import org.nd4j.common.io.ReflectionUtils;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base Nd4j test
|
* Base Nd4j test
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
|
@ -35,6 +35,18 @@ import java.util.*;
|
||||||
@RunWith(Parameterized.class)
|
@RunWith(Parameterized.class)
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseNd4jTest extends BaseND4JTest {
|
public abstract class BaseNd4jTest extends BaseND4JTest {
|
||||||
|
private static List<Nd4jBackend> BACKENDS = new ArrayList<>();
|
||||||
|
static {
|
||||||
|
List<String> backendsToRun = Nd4jTestSuite.backendsToRun();
|
||||||
|
|
||||||
|
ServiceLoader<Nd4jBackend> loadedBackends = ND4JClassLoading.loadService(Nd4jBackend.class);
|
||||||
|
for (Nd4jBackend backend : loadedBackends) {
|
||||||
|
if (backend.canRun() && backendsToRun.contains(backend.getClass().getName())
|
||||||
|
|| backendsToRun.isEmpty()) {
|
||||||
|
BACKENDS.add(backend);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
protected Nd4jBackend backend;
|
protected Nd4jBackend backend;
|
||||||
protected String name;
|
protected String name;
|
||||||
|
@ -57,24 +69,10 @@ public abstract class BaseNd4jTest extends BaseND4JTest {
|
||||||
this(backend.getClass().getName() + UUID.randomUUID().toString(), backend);
|
this(backend.getClass().getName() + UUID.randomUUID().toString(), backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Nd4jBackend> backends;
|
|
||||||
static {
|
|
||||||
ServiceLoader<Nd4jBackend> loadedBackends = ServiceLoader.load(Nd4jBackend.class);
|
|
||||||
Iterator<Nd4jBackend> backendIterator = loadedBackends.iterator();
|
|
||||||
backends = new ArrayList<>();
|
|
||||||
List<String> backendsToRun = Nd4jTestSuite.backendsToRun();
|
|
||||||
|
|
||||||
while (backendIterator.hasNext()) {
|
|
||||||
Nd4jBackend backend = backendIterator.next();
|
|
||||||
if (backend.canRun() && backendsToRun.contains(backend.getClass().getName()) || backendsToRun.isEmpty())
|
|
||||||
backends.add(backend);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Parameterized.Parameters(name = "{index}: backend({0})={1}")
|
@Parameterized.Parameters(name = "{index}: backend({0})={1}")
|
||||||
public static Collection<Object[]> configs() {
|
public static Collection<Object[]> configs() {
|
||||||
List<Object[]> ret = new ArrayList<>();
|
List<Object[]> ret = new ArrayList<>();
|
||||||
for (Nd4jBackend backend : backends)
|
for (Nd4jBackend backend : BACKENDS)
|
||||||
ret.add(new Object[] {backend});
|
ret.add(new Object[] {backend});
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
@ -93,16 +91,11 @@ public abstract class BaseNd4jTest extends BaseND4JTest {
|
||||||
*/
|
*/
|
||||||
public static Nd4jBackend getDefaultBackend() {
|
public static Nd4jBackend getDefaultBackend() {
|
||||||
String cpuBackend = "org.nd4j.linalg.cpu.nativecpu.CpuBackend";
|
String cpuBackend = "org.nd4j.linalg.cpu.nativecpu.CpuBackend";
|
||||||
//String cpuBackend = "org.nd4j.linalg.cpu.CpuBackend";
|
String defaultBackendClass = System.getProperty(DEFAULT_BACKEND, cpuBackend);
|
||||||
String gpuBackend = "org.nd4j.linalg.jcublas.JCublasBackend";
|
|
||||||
String clazz = System.getProperty(DEFAULT_BACKEND, cpuBackend);
|
|
||||||
try {
|
|
||||||
return (Nd4jBackend) Class.forName(clazz).newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
Class<Nd4jBackend> backendClass = ND4JClassLoading.loadClassByName(defaultBackendClass);
|
||||||
|
return ReflectionUtils.newInstance(backendClass);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The ordering for this test
|
* The ordering for this test
|
||||||
|
|
|
@ -17,10 +17,10 @@
|
||||||
package org.nd4j.linalg;
|
package org.nd4j.linalg;
|
||||||
|
|
||||||
import org.junit.runners.BlockJUnit4ClassRunner;
|
import org.junit.runners.BlockJUnit4ClassRunner;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.ServiceLoader;
|
import java.util.ServiceLoader;
|
||||||
|
|
||||||
|
@ -32,19 +32,15 @@ import java.util.ServiceLoader;
|
||||||
*
|
*
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
|
|
||||||
public class Nd4jTestSuite extends BlockJUnit4ClassRunner {
|
public class Nd4jTestSuite extends BlockJUnit4ClassRunner {
|
||||||
//the system property for what backends should run
|
//the system property for what backends should run
|
||||||
public final static String BACKENDS_TO_LOAD = "backends";
|
public final static String BACKENDS_TO_LOAD = "backends";
|
||||||
private static List<Nd4jBackend> backends;
|
private static List<Nd4jBackend> BACKENDS;
|
||||||
static {
|
static {
|
||||||
ServiceLoader<Nd4jBackend> loadedBackends = ServiceLoader.load(Nd4jBackend.class);
|
ServiceLoader<Nd4jBackend> loadedBackends = ND4JClassLoading.loadService(Nd4jBackend.class);
|
||||||
Iterator<Nd4jBackend> backendIterator = loadedBackends.iterator();
|
for (Nd4jBackend backend : loadedBackends) {
|
||||||
backends = new ArrayList<>();
|
BACKENDS.add(backend);
|
||||||
while (backendIterator.hasNext())
|
}
|
||||||
backends.add(backendIterator.next());
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -56,7 +52,6 @@ public class Nd4jTestSuite extends BlockJUnit4ClassRunner {
|
||||||
super(klass);
|
super(klass);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Based on the jvm arguments, an empty list is returned
|
* Based on the jvm arguments, an empty list is returned
|
||||||
* if all backends should be run.
|
* if all backends should be run.
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
package org.nd4j.common.base;
|
package org.nd4j.common.base;
|
||||||
|
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -23,24 +25,20 @@ import java.util.*;
|
||||||
*
|
*
|
||||||
* @author Alex Black
|
* @author Alex Black
|
||||||
*/
|
*/
|
||||||
public class Preconditions {
|
public final class Preconditions {
|
||||||
|
private static final Map<String,PreconditionsFormat> FORMATTERS = new HashMap<>();
|
||||||
private static final Map<String,PreconditionsFormat> formatters = new HashMap<>();
|
|
||||||
|
|
||||||
static {
|
static {
|
||||||
ServiceLoader<PreconditionsFormat> sl = ServiceLoader.load(PreconditionsFormat.class);
|
ServiceLoader<PreconditionsFormat> sl = ND4JClassLoading.loadService(PreconditionsFormat.class);
|
||||||
Iterator<PreconditionsFormat> iter = sl.iterator();
|
for (PreconditionsFormat pf : sl) {
|
||||||
while(iter.hasNext()){
|
|
||||||
PreconditionsFormat pf = iter.next();
|
|
||||||
List<String> formatTags = pf.formatTags();
|
List<String> formatTags = pf.formatTags();
|
||||||
for(String s : formatTags){
|
for(String s : formatTags){
|
||||||
formatters.put(s, pf);
|
FORMATTERS.put(s, pf);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Preconditions(){ }
|
private Preconditions() {
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check the specified boolean argument. Throws an IllegalArgumentException if {@code b} is false
|
* Check the specified boolean argument. Throws an IllegalArgumentException if {@code b} is false
|
||||||
|
@ -664,7 +662,7 @@ public class Preconditions {
|
||||||
|
|
||||||
int nextCustom = -1;
|
int nextCustom = -1;
|
||||||
String nextCustomTag = null;
|
String nextCustomTag = null;
|
||||||
for(String s : formatters.keySet()){
|
for(String s : FORMATTERS.keySet()){
|
||||||
int idxThis = message.indexOf(s, indexOfStart);
|
int idxThis = message.indexOf(s, indexOfStart);
|
||||||
if(idxThis > 0 && (nextCustom < 0 || idxThis < nextCustom)){
|
if(idxThis > 0 && (nextCustom < 0 || idxThis < nextCustom)){
|
||||||
nextCustom = idxThis;
|
nextCustom = idxThis;
|
||||||
|
@ -696,7 +694,7 @@ public class Preconditions {
|
||||||
} else {
|
} else {
|
||||||
//Custom tag
|
//Custom tag
|
||||||
sb.append(message.substring(indexOfStart, nextCustom));
|
sb.append(message.substring(indexOfStart, nextCustom));
|
||||||
String s = formatters.get(nextCustomTag).format(nextCustomTag, args[i]);
|
String s = FORMATTERS.get(nextCustomTag).format(nextCustomTag, args[i]);
|
||||||
sb.append(s);
|
sb.append(s);
|
||||||
indexOfStart = nextCustom + nextCustomTag.length();
|
indexOfStart = nextCustom + nextCustomTag.length();
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
package org.nd4j.common.config;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
import java.util.ServiceLoader;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Global context for class-loading in ND4J.
|
||||||
|
*
|
||||||
|
* @author Alexei KLENIN
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public final class ND4JClassLoading {
|
||||||
|
private static ClassLoader nd4jClassloader = Thread.currentThread().getContextClassLoader();
|
||||||
|
|
||||||
|
private ND4JClassLoading() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ClassLoader getNd4jClassloader() {
|
||||||
|
return ND4JClassLoading.nd4jClassloader;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void setNd4jClassloaderFromClass(Class<?> clazz) {
|
||||||
|
setNd4jClassloader(clazz.getClassLoader());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void setNd4jClassloader(ClassLoader nd4jClassloader) {
|
||||||
|
ND4JClassLoading.nd4jClassloader = nd4jClassloader;
|
||||||
|
log.debug("Global class-loader for ND4J was changed.");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean classPresentOnClasspath(String className) {
|
||||||
|
return classPresentOnClasspath(className, nd4jClassloader);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean classPresentOnClasspath(String className, ClassLoader classLoader) {
|
||||||
|
return loadClassByName(className, false, classLoader) != null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <T> Class<T> loadClassByName(String className) {
|
||||||
|
return loadClassByName(className, true, nd4jClassloader);
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public static <T> Class<T> loadClassByName(String className, boolean initialize, ClassLoader classLoader) {
|
||||||
|
try {
|
||||||
|
return (Class<T>) Class.forName(className, initialize, classLoader);
|
||||||
|
} catch (ClassNotFoundException classNotFoundException) {
|
||||||
|
log.error(String.format("Cannot find class [%s] of provided class-loader.", className));
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <S> ServiceLoader<S> loadService(Class<S> serviceClass) {
|
||||||
|
return loadService(serviceClass, nd4jClassloader);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <S> ServiceLoader<S> loadService(Class<S> serviceClass, ClassLoader classLoader) {
|
||||||
|
return ServiceLoader.load(serviceClass, classLoader);
|
||||||
|
}
|
||||||
|
}
|
|
@ -20,6 +20,7 @@ import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.net.MalformedURLException;
|
import java.net.MalformedURLException;
|
||||||
|
@ -55,7 +56,7 @@ public class ClassPathResource extends AbstractFileResolvingResource {
|
||||||
}
|
}
|
||||||
|
|
||||||
this.path = pathToUse;
|
this.path = pathToUse;
|
||||||
this.classLoader = classLoader != null ? classLoader : ClassUtils.getDefaultClassLoader();
|
this.classLoader = classLoader != null ? classLoader : ND4JClassLoading.getNd4jClassloader();
|
||||||
}
|
}
|
||||||
|
|
||||||
public ClassPathResource(String path, Class<?> clazz) {
|
public ClassPathResource(String path, Class<?> clazz) {
|
||||||
|
@ -283,7 +284,7 @@ public class ClassPathResource extends AbstractFileResolvingResource {
|
||||||
StringBuilder builder = new StringBuilder("class path resource [");
|
StringBuilder builder = new StringBuilder("class path resource [");
|
||||||
String pathToUse = this.path;
|
String pathToUse = this.path;
|
||||||
if (this.clazz != null && !pathToUse.startsWith("/")) {
|
if (this.clazz != null && !pathToUse.startsWith("/")) {
|
||||||
builder.append(ClassUtils.classPackageAsResourcePath(this.clazz));
|
builder.append(ResourceUtils.classPackageAsResourcePath(this.clazz));
|
||||||
builder.append('/');
|
builder.append('/');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,7 +321,7 @@ public class ClassPathResource extends AbstractFileResolvingResource {
|
||||||
private URL getUrl() {
|
private URL getUrl() {
|
||||||
ClassLoader loader = null;
|
ClassLoader loader = null;
|
||||||
try {
|
try {
|
||||||
loader = Thread.currentThread().getContextClassLoader();
|
loader = ND4JClassLoading.getNd4jClassloader();
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
// do nothing
|
// do nothing
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,710 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.common.io;
|
|
||||||
|
|
||||||
|
|
||||||
import java.beans.Introspector;
|
|
||||||
import java.lang.reflect.*;
|
|
||||||
import java.security.AccessControlException;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
|
|
||||||
public abstract class ClassUtils {
|
|
||||||
public static final String ARRAY_SUFFIX = "[]";
|
|
||||||
private static final String INTERNAL_ARRAY_PREFIX = "[";
|
|
||||||
private static final String NON_PRIMITIVE_ARRAY_PREFIX = "[L";
|
|
||||||
private static final char PACKAGE_SEPARATOR = '.';
|
|
||||||
private static final char INNER_CLASS_SEPARATOR = '$';
|
|
||||||
public static final String CGLIB_CLASS_SEPARATOR = "$$";
|
|
||||||
public static final String CLASS_FILE_SUFFIX = ".class";
|
|
||||||
private static final Map<Class<?>, Class<?>> primitiveWrapperTypeMap = new HashMap(8);
|
|
||||||
private static final Map<Object, Object> primitiveTypeToWrapperMap = new HashMap(8);
|
|
||||||
private static final Map<String, Class<?>> primitiveTypeNameMap = new HashMap(32);
|
|
||||||
private static final Map<String, Class<?>> commonClassCache = new HashMap(32);
|
|
||||||
|
|
||||||
public ClassUtils() {}
|
|
||||||
|
|
||||||
private static void registerCommonClasses(Class<?>... commonClasses) {
|
|
||||||
Class[] arr$ = commonClasses;
|
|
||||||
int len$ = commonClasses.length;
|
|
||||||
|
|
||||||
for (int i$ = 0; i$ < len$; ++i$) {
|
|
||||||
Class clazz = arr$[i$];
|
|
||||||
commonClassCache.put(clazz.getName(), clazz);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public static ClassLoader getDefaultClassLoader() {
|
|
||||||
ClassLoader cl = null;
|
|
||||||
|
|
||||||
try {
|
|
||||||
cl = Thread.currentThread().getContextClassLoader();
|
|
||||||
} catch (Throwable var2) {
|
|
||||||
;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cl == null) {
|
|
||||||
cl = ClassUtils.class.getClassLoader();
|
|
||||||
}
|
|
||||||
|
|
||||||
return cl;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static ClassLoader overrideThreadContextClassLoader(ClassLoader classLoaderToUse) {
|
|
||||||
Thread currentThread = Thread.currentThread();
|
|
||||||
ClassLoader threadContextClassLoader = currentThread.getContextClassLoader();
|
|
||||||
if (classLoaderToUse != null && !classLoaderToUse.equals(threadContextClassLoader)) {
|
|
||||||
currentThread.setContextClassLoader(classLoaderToUse);
|
|
||||||
return threadContextClassLoader;
|
|
||||||
} else {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @deprecated */
|
|
||||||
@Deprecated
|
|
||||||
public static Class<?> forName(String name) throws ClassNotFoundException, LinkageError {
|
|
||||||
return forName(name, getDefaultClassLoader());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Class<?> forName(String name, ClassLoader classLoader) throws ClassNotFoundException, LinkageError {
|
|
||||||
Assert.notNull(name, "Name must not be null");
|
|
||||||
Class clazz = resolvePrimitiveClassName(name);
|
|
||||||
if (clazz == null) {
|
|
||||||
clazz = commonClassCache.get(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (clazz != null) {
|
|
||||||
return clazz;
|
|
||||||
} else {
|
|
||||||
Class ex;
|
|
||||||
String classLoaderToUse1;
|
|
||||||
if (name.endsWith("[]")) {
|
|
||||||
classLoaderToUse1 = name.substring(0, name.length() - "[]".length());
|
|
||||||
ex = forName(classLoaderToUse1, classLoader);
|
|
||||||
return Array.newInstance(ex, 0).getClass();
|
|
||||||
} else if (name.startsWith("[L") && name.endsWith(";")) {
|
|
||||||
classLoaderToUse1 = name.substring("[L".length(), name.length() - 1);
|
|
||||||
ex = forName(classLoaderToUse1, classLoader);
|
|
||||||
return Array.newInstance(ex, 0).getClass();
|
|
||||||
} else if (name.startsWith("[")) {
|
|
||||||
classLoaderToUse1 = name.substring("[".length());
|
|
||||||
ex = forName(classLoaderToUse1, classLoader);
|
|
||||||
return Array.newInstance(ex, 0).getClass();
|
|
||||||
} else {
|
|
||||||
ClassLoader classLoaderToUse = classLoader;
|
|
||||||
if (classLoader == null) {
|
|
||||||
classLoaderToUse = getDefaultClassLoader();
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
return classLoaderToUse.loadClass(name);
|
|
||||||
} catch (ClassNotFoundException var9) {
|
|
||||||
int lastDotIndex = name.lastIndexOf(46);
|
|
||||||
if (lastDotIndex != -1) {
|
|
||||||
String innerClassName =
|
|
||||||
name.substring(0, lastDotIndex) + '$' + name.substring(lastDotIndex + 1);
|
|
||||||
|
|
||||||
try {
|
|
||||||
return classLoaderToUse.loadClass(innerClassName);
|
|
||||||
} catch (ClassNotFoundException var8) {
|
|
||||||
;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
throw var9;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Class<?> resolveClassName(String className, ClassLoader classLoader) throws IllegalArgumentException {
|
|
||||||
try {
|
|
||||||
return forName(className, classLoader);
|
|
||||||
} catch (ClassNotFoundException var3) {
|
|
||||||
throw new IllegalArgumentException("Cannot find class [" + className + "]", var3);
|
|
||||||
} catch (LinkageError var4) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Error loading class [" + className + "]: problem with class file or dependent class.",
|
|
||||||
var4);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Class<?> resolvePrimitiveClassName(String name) {
|
|
||||||
Class result = null;
|
|
||||||
if (name != null && name.length() <= 8) {
|
|
||||||
result = primitiveTypeNameMap.get(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @deprecated */
|
|
||||||
@Deprecated
|
|
||||||
public static boolean isPresent(String className) {
|
|
||||||
return isPresent(className, getDefaultClassLoader());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean isPresent(String className, ClassLoader classLoader) {
|
|
||||||
try {
|
|
||||||
forName(className, classLoader);
|
|
||||||
return true;
|
|
||||||
} catch (Throwable var3) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Class<?> getUserClass(Object instance) {
|
|
||||||
Assert.notNull(instance, "Instance must not be null");
|
|
||||||
return getUserClass((Class) instance.getClass());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Class<?> getUserClass(Class<?> clazz) {
|
|
||||||
if (clazz != null && clazz.getName().contains("$$")) {
|
|
||||||
Class superClass = clazz.getSuperclass();
|
|
||||||
if (superClass != null && !Object.class.equals(superClass)) {
|
|
||||||
return superClass;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return clazz;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean isCacheSafe(Class<?> clazz, ClassLoader classLoader) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
ClassLoader target = clazz.getClassLoader();
|
|
||||||
if (target == null) {
|
|
||||||
return false;
|
|
||||||
} else {
|
|
||||||
ClassLoader cur = classLoader;
|
|
||||||
if (classLoader == target) {
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
do {
|
|
||||||
if (cur == null) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
cur = cur.getParent();
|
|
||||||
} while (cur != target);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String getShortName(String className) {
|
|
||||||
Assert.hasLength(className, "Class name must not be empty");
|
|
||||||
int lastDotIndex = className.lastIndexOf(46);
|
|
||||||
int nameEndIndex = className.indexOf("$$");
|
|
||||||
if (nameEndIndex == -1) {
|
|
||||||
nameEndIndex = className.length();
|
|
||||||
}
|
|
||||||
|
|
||||||
String shortName = className.substring(lastDotIndex + 1, nameEndIndex);
|
|
||||||
shortName = shortName.replace('$', '.');
|
|
||||||
return shortName;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String getShortName(Class<?> clazz) {
|
|
||||||
return getShortName(getQualifiedName(clazz));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String getShortNameAsProperty(Class<?> clazz) {
|
|
||||||
String shortName = getShortName((Class) clazz);
|
|
||||||
int dotIndex = shortName.lastIndexOf(46);
|
|
||||||
shortName = dotIndex != -1 ? shortName.substring(dotIndex + 1) : shortName;
|
|
||||||
return Introspector.decapitalize(shortName);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String getClassFileName(Class<?> clazz) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
String className = clazz.getName();
|
|
||||||
int lastDotIndex = className.lastIndexOf(46);
|
|
||||||
return className.substring(lastDotIndex + 1) + ".class";
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String getPackageName(Class<?> clazz) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
return getPackageName(clazz.getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String getPackageName(String fqClassName) {
|
|
||||||
Assert.notNull(fqClassName, "Class name must not be null");
|
|
||||||
int lastDotIndex = fqClassName.lastIndexOf(46);
|
|
||||||
return lastDotIndex != -1 ? fqClassName.substring(0, lastDotIndex) : "";
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String getQualifiedName(Class<?> clazz) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
return clazz.isArray() ? getQualifiedNameForArray(clazz) : clazz.getName();
|
|
||||||
}
|
|
||||||
|
|
||||||
private static String getQualifiedNameForArray(Class<?> clazz) {
|
|
||||||
StringBuilder result = new StringBuilder();
|
|
||||||
|
|
||||||
while (clazz.isArray()) {
|
|
||||||
clazz = clazz.getComponentType();
|
|
||||||
result.append("[]");
|
|
||||||
}
|
|
||||||
|
|
||||||
result.insert(0, clazz.getName());
|
|
||||||
return result.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String getQualifiedMethodName(Method method) {
|
|
||||||
Assert.notNull(method, "Method must not be null");
|
|
||||||
return method.getDeclaringClass().getName() + "." + method.getName();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String getDescriptiveType(Object value) {
|
|
||||||
if (value == null) {
|
|
||||||
return null;
|
|
||||||
} else {
|
|
||||||
Class clazz = value.getClass();
|
|
||||||
if (Proxy.isProxyClass(clazz)) {
|
|
||||||
StringBuilder result = new StringBuilder(clazz.getName());
|
|
||||||
result.append(" implementing ");
|
|
||||||
Class[] ifcs = clazz.getInterfaces();
|
|
||||||
|
|
||||||
for (int i = 0; i < ifcs.length; ++i) {
|
|
||||||
result.append(ifcs[i].getName());
|
|
||||||
if (i < ifcs.length - 1) {
|
|
||||||
result.append(',');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.toString();
|
|
||||||
} else {
|
|
||||||
return clazz.isArray() ? getQualifiedNameForArray(clazz) : clazz.getName();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean matchesTypeName(Class<?> clazz, String typeName) {
|
|
||||||
return typeName != null && (typeName.equals(clazz.getName()) || typeName.equals(clazz.getSimpleName())
|
|
||||||
|| clazz.isArray() && typeName.equals(getQualifiedNameForArray(clazz)));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean hasConstructor(Class<?> clazz, Class<?>... paramTypes) {
|
|
||||||
return getConstructorIfAvailable(clazz, paramTypes) != null;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static <T> Constructor<T> getConstructorIfAvailable(Class<T> clazz, Class<?>... paramTypes) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
|
|
||||||
try {
|
|
||||||
return clazz.getConstructor(paramTypes);
|
|
||||||
} catch (NoSuchMethodException var3) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean hasMethod(Class<?> clazz, String methodName, Class<?>... paramTypes) {
|
|
||||||
return getMethodIfAvailable(clazz, methodName, paramTypes) != null;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Method getMethod(Class<?> clazz, String methodName, Class<?>... paramTypes) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
Assert.notNull(methodName, "Method name must not be null");
|
|
||||||
if (paramTypes != null) {
|
|
||||||
try {
|
|
||||||
return clazz.getMethod(methodName, paramTypes);
|
|
||||||
} catch (NoSuchMethodException var9) {
|
|
||||||
throw new IllegalStateException("Expected method not found: " + var9);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
HashSet candidates = new HashSet(1);
|
|
||||||
Method[] methods = clazz.getMethods();
|
|
||||||
Method[] arr$ = methods;
|
|
||||||
int len$ = methods.length;
|
|
||||||
|
|
||||||
for (int i$ = 0; i$ < len$; ++i$) {
|
|
||||||
Method method = arr$[i$];
|
|
||||||
if (methodName.equals(method.getName())) {
|
|
||||||
candidates.add(method);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (candidates.size() == 1) {
|
|
||||||
return (Method) candidates.iterator().next();
|
|
||||||
} else if (candidates.isEmpty()) {
|
|
||||||
throw new IllegalStateException("Expected method not found: " + clazz + "." + methodName);
|
|
||||||
} else {
|
|
||||||
throw new IllegalStateException("No unique method found: " + clazz + "." + methodName);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Method getMethodIfAvailable(Class<?> clazz, String methodName, Class<?>... paramTypes) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
Assert.notNull(methodName, "Method name must not be null");
|
|
||||||
if (paramTypes != null) {
|
|
||||||
try {
|
|
||||||
return clazz.getMethod(methodName, paramTypes);
|
|
||||||
} catch (NoSuchMethodException var9) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
HashSet candidates = new HashSet(1);
|
|
||||||
Method[] methods = clazz.getMethods();
|
|
||||||
Method[] arr$ = methods;
|
|
||||||
int len$ = methods.length;
|
|
||||||
|
|
||||||
for (int i$ = 0; i$ < len$; ++i$) {
|
|
||||||
Method method = arr$[i$];
|
|
||||||
if (methodName.equals(method.getName())) {
|
|
||||||
candidates.add(method);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (candidates.size() == 1) {
|
|
||||||
return (Method) candidates.iterator().next();
|
|
||||||
} else {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static int getMethodCountForName(Class<?> clazz, String methodName) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
Assert.notNull(methodName, "Method name must not be null");
|
|
||||||
int count = 0;
|
|
||||||
Method[] declaredMethods = clazz.getDeclaredMethods();
|
|
||||||
Method[] ifcs = declaredMethods;
|
|
||||||
int arr$ = declaredMethods.length;
|
|
||||||
|
|
||||||
int len$;
|
|
||||||
for (len$ = 0; len$ < arr$; ++len$) {
|
|
||||||
Method i$ = ifcs[len$];
|
|
||||||
if (methodName.equals(i$.getName())) {
|
|
||||||
++count;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Class[] var9 = clazz.getInterfaces();
|
|
||||||
Class[] var10 = var9;
|
|
||||||
len$ = var9.length;
|
|
||||||
|
|
||||||
for (int var11 = 0; var11 < len$; ++var11) {
|
|
||||||
Class ifc = var10[var11];
|
|
||||||
count += getMethodCountForName(ifc, methodName);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (clazz.getSuperclass() != null) {
|
|
||||||
count += getMethodCountForName(clazz.getSuperclass(), methodName);
|
|
||||||
}
|
|
||||||
|
|
||||||
return count;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean hasAtLeastOneMethodWithName(Class<?> clazz, String methodName) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
Assert.notNull(methodName, "Method name must not be null");
|
|
||||||
Method[] declaredMethods = clazz.getDeclaredMethods();
|
|
||||||
Method[] ifcs = declaredMethods;
|
|
||||||
int arr$ = declaredMethods.length;
|
|
||||||
|
|
||||||
int len$;
|
|
||||||
for (len$ = 0; len$ < arr$; ++len$) {
|
|
||||||
Method i$ = ifcs[len$];
|
|
||||||
if (i$.getName().equals(methodName)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Class[] var8 = clazz.getInterfaces();
|
|
||||||
Class[] var9 = var8;
|
|
||||||
len$ = var8.length;
|
|
||||||
|
|
||||||
for (int var10 = 0; var10 < len$; ++var10) {
|
|
||||||
Class ifc = var9[var10];
|
|
||||||
if (hasAtLeastOneMethodWithName(ifc, methodName)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return clazz.getSuperclass() != null && hasAtLeastOneMethodWithName(clazz.getSuperclass(), methodName);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Method getMostSpecificMethod(Method method, Class<?> targetClass) {
|
|
||||||
if (method != null && isOverridable(method, targetClass) && targetClass != null
|
|
||||||
&& !targetClass.equals(method.getDeclaringClass())) {
|
|
||||||
try {
|
|
||||||
if (Modifier.isPublic(method.getModifiers())) {
|
|
||||||
try {
|
|
||||||
return targetClass.getMethod(method.getName(), method.getParameterTypes());
|
|
||||||
} catch (NoSuchMethodException var3) {
|
|
||||||
return method;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Method ex = ReflectionUtils.findMethod(targetClass, method.getName(), method.getParameterTypes());
|
|
||||||
return ex != null ? ex : method;
|
|
||||||
} catch (AccessControlException var4) {
|
|
||||||
;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return method;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static boolean isOverridable(Method method, Class targetClass) {
|
|
||||||
return Modifier.isPrivate(method.getModifiers()) ? false
|
|
||||||
: (!Modifier.isPublic(method.getModifiers()) && !Modifier.isProtected(method.getModifiers())
|
|
||||||
? getPackageName((Class) method.getDeclaringClass())
|
|
||||||
.equals(getPackageName(targetClass))
|
|
||||||
: true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Method getStaticMethod(Class<?> clazz, String methodName, Class<?>... args) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
Assert.notNull(methodName, "Method name must not be null");
|
|
||||||
|
|
||||||
try {
|
|
||||||
Method ex = clazz.getMethod(methodName, args);
|
|
||||||
return Modifier.isStatic(ex.getModifiers()) ? ex : null;
|
|
||||||
} catch (NoSuchMethodException var4) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean isPrimitiveWrapper(Class<?> clazz) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
return primitiveWrapperTypeMap.containsKey(clazz);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean isPrimitiveOrWrapper(Class<?> clazz) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
return clazz.isPrimitive() || isPrimitiveWrapper(clazz);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean isPrimitiveArray(Class<?> clazz) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
return clazz.isArray() && clazz.getComponentType().isPrimitive();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean isPrimitiveWrapperArray(Class<?> clazz) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
return clazz.isArray() && isPrimitiveWrapper(clazz.getComponentType());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Class<?> resolvePrimitiveIfNecessary(Class<?> clazz) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
return clazz.isPrimitive() && clazz != Void.TYPE ? (Class) primitiveTypeToWrapperMap.get(clazz) : clazz;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean isAssignable(Class<?> lhsType, Class<?> rhsType) {
|
|
||||||
Assert.notNull(lhsType, "Left-hand side opType must not be null");
|
|
||||||
Assert.notNull(rhsType, "Right-hand side opType must not be null");
|
|
||||||
if (lhsType.isAssignableFrom(rhsType)) {
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
Class resolvedWrapper;
|
|
||||||
if (lhsType.isPrimitive()) {
|
|
||||||
resolvedWrapper = primitiveWrapperTypeMap.get(rhsType);
|
|
||||||
if (resolvedWrapper != null && lhsType.equals(resolvedWrapper)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
resolvedWrapper = (Class) primitiveTypeToWrapperMap.get(rhsType);
|
|
||||||
if (resolvedWrapper != null && lhsType.isAssignableFrom(resolvedWrapper)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean isAssignableValue(Class<?> type, Object value) {
|
|
||||||
Assert.notNull(type, "Type must not be null");
|
|
||||||
return value != null ? isAssignable(type, value.getClass()) : !type.isPrimitive();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String convertResourcePathToClassName(String resourcePath) {
|
|
||||||
Assert.notNull(resourcePath, "Resource path must not be null");
|
|
||||||
return resourcePath.replace('/', '.');
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String convertClassNameToResourcePath(String className) {
|
|
||||||
Assert.notNull(className, "Class name must not be null");
|
|
||||||
return className.replace('.', '/');
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String addResourcePathToPackagePath(Class<?> clazz, String resourceName) {
|
|
||||||
Assert.notNull(resourceName, "Resource name must not be null");
|
|
||||||
return !resourceName.startsWith("/") ? classPackageAsResourcePath(clazz) + "/" + resourceName
|
|
||||||
: classPackageAsResourcePath(clazz) + resourceName;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String classPackageAsResourcePath(Class<?> clazz) {
|
|
||||||
if (clazz == null) {
|
|
||||||
return "";
|
|
||||||
} else {
|
|
||||||
String className = clazz.getName();
|
|
||||||
int packageEndIndex = className.lastIndexOf(46);
|
|
||||||
if (packageEndIndex == -1) {
|
|
||||||
return "";
|
|
||||||
} else {
|
|
||||||
String packageName = className.substring(0, packageEndIndex);
|
|
||||||
return packageName.replace('.', '/');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String classNamesToString(Class... classes) {
|
|
||||||
return classNamesToString((Collection) Arrays.asList(classes));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String classNamesToString(Collection<Class> classes) {
|
|
||||||
if (CollectionUtils.isEmpty(classes)) {
|
|
||||||
return "[]";
|
|
||||||
} else {
|
|
||||||
StringBuilder sb = new StringBuilder("[");
|
|
||||||
Iterator it = classes.iterator();
|
|
||||||
|
|
||||||
while (it.hasNext()) {
|
|
||||||
Class clazz = (Class) it.next();
|
|
||||||
sb.append(clazz.getName());
|
|
||||||
if (it.hasNext()) {
|
|
||||||
sb.append(", ");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.append("]");
|
|
||||||
return sb.toString();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Class<?>[] toClassArray(Collection<Class<?>> collection) {
|
|
||||||
return collection == null ? null : collection.toArray(new Class[collection.size()]);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Class<?>[] getAllInterfaces(Object instance) {
|
|
||||||
Assert.notNull(instance, "Instance must not be null");
|
|
||||||
return getAllInterfacesForClass(instance.getClass());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Class<?>[] getAllInterfacesForClass(Class<?> clazz) {
|
|
||||||
return getAllInterfacesForClass(clazz, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Class<?>[] getAllInterfacesForClass(Class<?> clazz, ClassLoader classLoader) {
|
|
||||||
Set ifcs = getAllInterfacesForClassAsSet(clazz, classLoader);
|
|
||||||
return (Class[]) ifcs.toArray(new Class[ifcs.size()]);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Set<Class> getAllInterfacesAsSet(Object instance) {
|
|
||||||
Assert.notNull(instance, "Instance must not be null");
|
|
||||||
return getAllInterfacesForClassAsSet(instance.getClass());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Set<Class> getAllInterfacesForClassAsSet(Class clazz) {
|
|
||||||
return getAllInterfacesForClassAsSet(clazz, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Set<Class> getAllInterfacesForClassAsSet(Class clazz, ClassLoader classLoader) {
|
|
||||||
Assert.notNull(clazz, "Class must not be null");
|
|
||||||
if (clazz.isInterface() && isVisible(clazz, classLoader)) {
|
|
||||||
return Collections.singleton(clazz);
|
|
||||||
} else {
|
|
||||||
LinkedHashSet interfaces;
|
|
||||||
for (interfaces = new LinkedHashSet(); clazz != null; clazz = clazz.getSuperclass()) {
|
|
||||||
Class[] ifcs = clazz.getInterfaces();
|
|
||||||
Class[] arr$ = ifcs;
|
|
||||||
int len$ = ifcs.length;
|
|
||||||
|
|
||||||
for (int i$ = 0; i$ < len$; ++i$) {
|
|
||||||
Class ifc = arr$[i$];
|
|
||||||
interfaces.addAll(getAllInterfacesForClassAsSet(ifc, classLoader));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return interfaces;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Class<?> createCompositeInterface(Class<?>[] interfaces, ClassLoader classLoader) {
|
|
||||||
Assert.notEmpty(interfaces, "Interfaces must not be empty");
|
|
||||||
Assert.notNull(classLoader, "ClassLoader must not be null");
|
|
||||||
return Proxy.getProxyClass(classLoader, interfaces);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean isVisible(Class<?> clazz, ClassLoader classLoader) {
|
|
||||||
if (classLoader == null) {
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
try {
|
|
||||||
Class ex = classLoader.loadClass(clazz.getName());
|
|
||||||
return clazz == ex;
|
|
||||||
} catch (ClassNotFoundException var3) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean isCglibProxy(Object object) {
|
|
||||||
return isCglibProxyClass(object.getClass());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean isCglibProxyClass(Class<?> clazz) {
|
|
||||||
return clazz != null && isCglibProxyClassName(clazz.getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean isCglibProxyClassName(String className) {
|
|
||||||
return className != null && className.contains("$$");
|
|
||||||
}
|
|
||||||
|
|
||||||
static {
|
|
||||||
primitiveWrapperTypeMap.put(Boolean.class, Boolean.TYPE);
|
|
||||||
primitiveWrapperTypeMap.put(Byte.class, Byte.TYPE);
|
|
||||||
primitiveWrapperTypeMap.put(Character.class, Character.TYPE);
|
|
||||||
primitiveWrapperTypeMap.put(Double.class, Double.TYPE);
|
|
||||||
primitiveWrapperTypeMap.put(Float.class, Float.TYPE);
|
|
||||||
primitiveWrapperTypeMap.put(Integer.class, Integer.TYPE);
|
|
||||||
primitiveWrapperTypeMap.put(Long.class, Long.TYPE);
|
|
||||||
primitiveWrapperTypeMap.put(Short.class, Short.TYPE);
|
|
||||||
Iterator<Map.Entry<Class<?>, Class<?>>> primitiveTypes = primitiveWrapperTypeMap.entrySet().iterator();
|
|
||||||
|
|
||||||
while (primitiveTypes.hasNext()) {
|
|
||||||
Map.Entry<?, ?> i$ = primitiveTypes.next();
|
|
||||||
primitiveTypeToWrapperMap.put(i$.getValue(), i$.getKey());
|
|
||||||
registerCommonClasses(new Class[] {(Class) i$.getKey()});
|
|
||||||
}
|
|
||||||
|
|
||||||
HashSet primitiveTypes1 = new HashSet(32);
|
|
||||||
primitiveTypes1.addAll(primitiveWrapperTypeMap.values());
|
|
||||||
primitiveTypes1.addAll(Arrays.asList(new Class[] {boolean[].class, byte[].class, char[].class, double[].class,
|
|
||||||
float[].class, int[].class, long[].class, short[].class}));
|
|
||||||
primitiveTypes1.add(Void.TYPE);
|
|
||||||
Iterator i$1 = primitiveTypes1.iterator();
|
|
||||||
|
|
||||||
while (i$1.hasNext()) {
|
|
||||||
Class primitiveType = (Class) i$1.next();
|
|
||||||
primitiveTypeNameMap.put(primitiveType.getName(), primitiveType);
|
|
||||||
}
|
|
||||||
|
|
||||||
registerCommonClasses(new Class[] {Boolean[].class, Byte[].class, Character[].class, Double[].class,
|
|
||||||
Float[].class, Integer[].class, Long[].class, Short[].class});
|
|
||||||
registerCommonClasses(new Class[] {Number.class, Number[].class, String.class, String[].class, Object.class,
|
|
||||||
Object[].class, Class.class, Class[].class});
|
|
||||||
registerCommonClasses(new Class[] {Throwable.class, Exception.class, RuntimeException.class, Error.class,
|
|
||||||
StackTraceElement.class, StackTraceElement[].class});
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -21,6 +21,7 @@ import java.sql.SQLException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
public abstract class ReflectionUtils {
|
public abstract class ReflectionUtils {
|
||||||
|
@ -398,6 +399,40 @@ public abstract class ReflectionUtils {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a new instance of the specified {@link Class} by invoking
|
||||||
|
* the constructor whose argument list matches the types of the supplied
|
||||||
|
* arguments.
|
||||||
|
*
|
||||||
|
* <p>Provided class must have a public constructor.</p>
|
||||||
|
*
|
||||||
|
* @param clazz the class to instantiate; never {@code null}
|
||||||
|
* @param args the arguments to pass to the constructor, none of which may
|
||||||
|
* be {@code null}
|
||||||
|
* @return the new instance; never {@code null}
|
||||||
|
*/
|
||||||
|
public static <T> T newInstance(Class<T> clazz, Object... args) {
|
||||||
|
Objects.requireNonNull(clazz, "Class must not be null");
|
||||||
|
Objects.requireNonNull(args, "Argument array must not be null");
|
||||||
|
if (Arrays.asList(args).contains(null)) {
|
||||||
|
throw new RuntimeException("Individual arguments must not be null");
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
Class<?>[] parameterTypes = Arrays.stream(args).map(Object::getClass).toArray(Class[]::new);
|
||||||
|
Constructor<T> constructor = clazz.getDeclaredConstructor(parameterTypes);
|
||||||
|
|
||||||
|
if (!Modifier.isPublic(constructor.getModifiers())) {
|
||||||
|
throw new IllegalArgumentException(String.format(
|
||||||
|
"Class [%s] must have public constructor in order to be instantiated.", clazz.getName()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return constructor.newInstance(args);
|
||||||
|
} catch (Throwable instantiationException) {
|
||||||
|
throw new RuntimeException(instantiationException);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public interface FieldFilter {
|
public interface FieldFilter {
|
||||||
boolean matches(Field var1);
|
boolean matches(Field var1);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,9 +16,12 @@
|
||||||
|
|
||||||
package org.nd4j.common.io;
|
package org.nd4j.common.io;
|
||||||
|
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.FileNotFoundException;
|
import java.io.FileNotFoundException;
|
||||||
import java.net.*;
|
import java.net.*;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
|
||||||
public abstract class ResourceUtils {
|
public abstract class ResourceUtils {
|
||||||
|
@ -54,7 +57,7 @@ public abstract class ResourceUtils {
|
||||||
Assert.notNull(resourceLocation, "Resource location must not be null");
|
Assert.notNull(resourceLocation, "Resource location must not be null");
|
||||||
if (resourceLocation.startsWith("classpath:")) {
|
if (resourceLocation.startsWith("classpath:")) {
|
||||||
String ex = resourceLocation.substring("classpath:".length());
|
String ex = resourceLocation.substring("classpath:".length());
|
||||||
URL ex2 = ClassUtils.getDefaultClassLoader().getResource(ex);
|
URL ex2 = ND4JClassLoading.getNd4jClassloader().getResource(ex);
|
||||||
if (ex2 == null) {
|
if (ex2 == null) {
|
||||||
String description = "class path resource [" + ex + "]";
|
String description = "class path resource [" + ex + "]";
|
||||||
throw new FileNotFoundException(description + " cannot be resolved to URL because it does not exist");
|
throw new FileNotFoundException(description + " cannot be resolved to URL because it does not exist");
|
||||||
|
@ -80,7 +83,7 @@ public abstract class ResourceUtils {
|
||||||
if (resourceLocation.startsWith("classpath:")) {
|
if (resourceLocation.startsWith("classpath:")) {
|
||||||
String ex = resourceLocation.substring("classpath:".length());
|
String ex = resourceLocation.substring("classpath:".length());
|
||||||
String description = "class path resource [" + ex + "]";
|
String description = "class path resource [" + ex + "]";
|
||||||
URL url = ClassUtils.getDefaultClassLoader().getResource(ex);
|
URL url = ND4JClassLoading.getNd4jClassloader().getResource(ex);
|
||||||
if (url == null) {
|
if (url == null) {
|
||||||
throw new FileNotFoundException(description + " cannot be resolved to absolute file path "
|
throw new FileNotFoundException(description + " cannot be resolved to absolute file path "
|
||||||
+ "because it does not reside in the file system");
|
+ "because it does not reside in the file system");
|
||||||
|
@ -170,4 +173,17 @@ public abstract class ResourceUtils {
|
||||||
public static void useCachesIfNecessary(URLConnection con) {
|
public static void useCachesIfNecessary(URLConnection con) {
|
||||||
con.setUseCaches(con.getClass().getSimpleName().startsWith("JNLP"));
|
con.setUseCaches(con.getClass().getSimpleName().startsWith("JNLP"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static String classPackageAsResourcePath(Class<?> clazz) {
|
||||||
|
Objects.requireNonNull(clazz);
|
||||||
|
|
||||||
|
String className = clazz.getName();
|
||||||
|
int packageEndIndex = className.lastIndexOf(46);
|
||||||
|
if (packageEndIndex == -1) {
|
||||||
|
return "";
|
||||||
|
} else {
|
||||||
|
String packageName = className.substring(0, packageEndIndex);
|
||||||
|
return packageName.replace('.', '/');
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package org.nd4j.common.resources;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.common.resources.strumpf.StrumpfResolver;
|
import org.nd4j.common.resources.strumpf.StrumpfResolver;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -24,15 +25,12 @@ public class Resources {
|
||||||
protected final List<Resolver> resolvers;
|
protected final List<Resolver> resolvers;
|
||||||
|
|
||||||
protected Resources() {
|
protected Resources() {
|
||||||
|
ServiceLoader<Resolver> loader = ND4JClassLoading.loadService(Resolver.class);
|
||||||
ServiceLoader<Resolver> loader = ServiceLoader.load(Resolver.class);
|
|
||||||
Iterator<Resolver> iter = loader.iterator();
|
|
||||||
|
|
||||||
resolvers = new ArrayList<>();
|
resolvers = new ArrayList<>();
|
||||||
resolvers.add(new StrumpfResolver());
|
resolvers.add(new StrumpfResolver());
|
||||||
while (iter.hasNext()) {
|
for (Resolver resolver : loader) {
|
||||||
Resolver r = iter.next();
|
resolvers.add(resolver);
|
||||||
resolvers.add(r);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//Sort resolvers by priority: check resolvers with lower numbers first
|
//Sort resolvers by priority: check resolvers with lower numbers first
|
||||||
|
@ -42,8 +40,6 @@ public class Resources {
|
||||||
return Integer.compare(r1.priority(), r2.priority());
|
return Integer.compare(r1.priority(), r2.priority());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -1,122 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.common.util;
|
|
||||||
|
|
||||||
|
|
||||||
import java.io.PrintWriter;
|
|
||||||
import java.lang.management.ManagementFactory;
|
|
||||||
import java.lang.management.ThreadInfo;
|
|
||||||
import java.lang.management.ThreadMXBean;
|
|
||||||
import java.lang.reflect.Constructor;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* General reflection utils
|
|
||||||
*/
|
|
||||||
public class ReflectionUtils {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Cache of constructors for each class. Pins the classes so they
|
|
||||||
* can't be garbage collected until ReflectionUtils can be collected.
|
|
||||||
*/
|
|
||||||
protected static final Map<Class<?>, Constructor<?>> CONSTRUCTOR_CACHE = new ConcurrentHashMap<>();
|
|
||||||
|
|
||||||
protected ReflectionUtils() {}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
static private ThreadMXBean threadBean = ManagementFactory.getThreadMXBean();
|
|
||||||
|
|
||||||
public static void setContentionTracing(boolean val) {
|
|
||||||
threadBean.setThreadContentionMonitoringEnabled(val);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static String getTaskName(long id, String name) {
|
|
||||||
if (name == null) {
|
|
||||||
return Long.toString(id);
|
|
||||||
}
|
|
||||||
return id + " (" + name + ")";
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Print all of the thread's information and stack traces.
|
|
||||||
*
|
|
||||||
* @param stream the stream to
|
|
||||||
* @param title a string title for the stack trace
|
|
||||||
*/
|
|
||||||
public static void printThreadInfo(PrintWriter stream, String title) {
|
|
||||||
final int STACK_DEPTH = 20;
|
|
||||||
boolean contention = threadBean.isThreadContentionMonitoringEnabled();
|
|
||||||
long[] threadIds = threadBean.getAllThreadIds();
|
|
||||||
stream.println("Process Thread Dump: " + title);
|
|
||||||
stream.println(threadIds.length + " active threads");
|
|
||||||
for (long tid : threadIds) {
|
|
||||||
ThreadInfo info = threadBean.getThreadInfo(tid, STACK_DEPTH);
|
|
||||||
if (info == null) {
|
|
||||||
stream.println(" Inactive");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
stream.println("Thread " + getTaskName(info.getThreadId(), info.getThreadName()) + ":");
|
|
||||||
Thread.State state = info.getThreadState();
|
|
||||||
stream.println(" State: " + state);
|
|
||||||
stream.println(" Blocked count: " + info.getBlockedCount());
|
|
||||||
stream.println(" Waited count: " + info.getWaitedCount());
|
|
||||||
if (contention) {
|
|
||||||
stream.println(" Blocked time: " + info.getBlockedTime());
|
|
||||||
stream.println(" Waited time: " + info.getWaitedTime());
|
|
||||||
}
|
|
||||||
if (state == Thread.State.WAITING) {
|
|
||||||
stream.println(" Waiting on " + info.getLockName());
|
|
||||||
} else if (state == Thread.State.BLOCKED) {
|
|
||||||
stream.println(" Blocked on " + info.getLockName());
|
|
||||||
stream.println(" Blocked by " + getTaskName(info.getLockOwnerId(), info.getLockOwnerName()));
|
|
||||||
}
|
|
||||||
stream.println(" Stack:");
|
|
||||||
for (StackTraceElement frame : info.getStackTrace()) {
|
|
||||||
stream.println(" " + frame.toString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stream.flush();
|
|
||||||
}
|
|
||||||
|
|
||||||
private static long previousLogTime = 0;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return the correctly-typed {@link Class} of the given object.
|
|
||||||
*
|
|
||||||
* @param o object whose correctly-typed <code>Class</code> is to be obtained
|
|
||||||
* @return the correctly typed <code>Class</code> of the given object.
|
|
||||||
*/
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public static <T> Class<T> getClass(T o) {
|
|
||||||
return (Class<T>) o.getClass();
|
|
||||||
}
|
|
||||||
|
|
||||||
// methods to support testing
|
|
||||||
static void clearCache() {
|
|
||||||
CONSTRUCTOR_CACHE.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
static int getCacheSize() {
|
|
||||||
return CONSTRUCTOR_CACHE.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -17,11 +17,13 @@
|
||||||
package org.nd4j.jdbc.driverfinder;
|
package org.nd4j.jdbc.driverfinder;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.sql.Driver;
|
import java.sql.Driver;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.Properties;
|
import java.util.Properties;
|
||||||
import java.util.ServiceLoader;
|
import java.util.ServiceLoader;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
@ -45,22 +47,19 @@ public class DriverFinder {
|
||||||
discoverDriverClazz();
|
discoverDriverClazz();
|
||||||
try {
|
try {
|
||||||
driver = clazz.newInstance();
|
driver = clazz.newInstance();
|
||||||
} catch (InstantiationException e) {
|
} catch (InstantiationException | IllegalAccessException e) {
|
||||||
log.error("",e);
|
|
||||||
} catch (IllegalAccessException e) {
|
|
||||||
log.error("",e);
|
log.error("",e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return driver;
|
return driver;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private static void discoverDriverClazz() {
|
private static void discoverDriverClazz() {
|
||||||
//All JDBC4 compliant drivers support ServiceLoader mechanism for discovery - https://stackoverflow.com/a/18297412
|
//All JDBC4 compliant drivers support ServiceLoader mechanism for discovery - https://stackoverflow.com/a/18297412
|
||||||
ServiceLoader<Driver> drivers = ServiceLoader.load(Driver.class);
|
ServiceLoader<Driver> drivers = ND4JClassLoading.loadService(Driver.class);
|
||||||
Set<Class<? extends Driver>> driverClasses = new HashSet<>();
|
Set<Class<? extends Driver>> driverClasses = new HashSet<>();
|
||||||
for(Driver d : drivers){
|
for(Driver driver : drivers){
|
||||||
driverClasses.add(d.getClass());
|
driverClasses.add(driver.getClass());
|
||||||
}
|
}
|
||||||
|
|
||||||
if(driverClasses.isEmpty()){
|
if(driverClasses.isEmpty()){
|
||||||
|
@ -79,16 +78,13 @@ public class DriverFinder {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
String clazz = props.getProperty(JDBC_KEY);
|
String jdbcKeyClassName = props.getProperty(JDBC_KEY);
|
||||||
if (clazz == null)
|
Objects.requireNonNull(jdbcKeyClassName, "Unable to find jdbc driver. Please specify a "
|
||||||
throw new IllegalStateException("Unable to find jdbc driver. Please specify a "
|
+ ND4j_JDBC_PROPERTIES + " with the key " + JDBC_KEY);
|
||||||
+ ND4j_JDBC_PROPERTIES + " with the key " + JDBC_KEY);
|
|
||||||
try {
|
DriverFinder.clazz = ND4JClassLoading.loadClassByName(jdbcKeyClassName);
|
||||||
DriverFinder.clazz = (Class<? extends Driver>) Class.forName(clazz);
|
Objects.requireNonNull(DriverFinder.clazz, "Unable to find jdbc driver. Please specify a "
|
||||||
} catch (ClassNotFoundException e) {
|
+ ND4j_JDBC_PROPERTIES + " with the key " + JDBC_KEY);
|
||||||
throw new IllegalStateException("Unable to find jdbc driver. Please specify a "
|
|
||||||
+ ND4j_JDBC_PROPERTIES + " with the key " + JDBC_KEY);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import org.hsqldb.jdbc.JDBCDataSource;
|
||||||
import org.junit.AfterClass;
|
import org.junit.AfterClass;
|
||||||
import org.junit.BeforeClass;
|
import org.junit.BeforeClass;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -46,13 +47,12 @@ public class HSqlLoaderTest extends BaseND4JTest {
|
||||||
@BeforeClass
|
@BeforeClass
|
||||||
public static void init() throws Exception {
|
public static void init() throws Exception {
|
||||||
hsqlLoader = new HsqlLoader(dataSource(),JDBC_URL,TABLE_NAME,ID_COLUMN_NAME,COLUMN_NAME);
|
hsqlLoader = new HsqlLoader(dataSource(),JDBC_URL,TABLE_NAME,ID_COLUMN_NAME,COLUMN_NAME);
|
||||||
Class.forName("org.hsqldb.jdbc.JDBCDriver");
|
ND4JClassLoading.loadClassByName("org.hsqldb.jdbc.JDBCDriver");
|
||||||
|
|
||||||
// initialize database
|
// initialize database
|
||||||
initDatabase();
|
initDatabase();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static DataSource dataSource() {
|
public static DataSource dataSource() {
|
||||||
if (dataSource != null)
|
if (dataSource != null)
|
||||||
return dataSource;
|
return dataSource;
|
||||||
|
@ -65,8 +65,6 @@ public class HSqlLoaderTest extends BaseND4JTest {
|
||||||
return dataSource;
|
return dataSource;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@AfterClass
|
@AfterClass
|
||||||
public static void destroy() throws SQLException {
|
public static void destroy() throws SQLException {
|
||||||
try (Connection connection = getConnection(); Statement statement = connection.createStatement()) {
|
try (Connection connection = getConnection(); Statement statement = connection.createStatement()) {
|
||||||
|
@ -131,6 +129,4 @@ public class HSqlLoaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,10 +18,11 @@ package org.nd4j.parameterserver.distributed.messages;
|
||||||
|
|
||||||
import org.agrona.concurrent.UnsafeBuffer;
|
import org.agrona.concurrent.UnsafeBuffer;
|
||||||
import org.apache.commons.io.input.ClassLoaderObjectInputStream;
|
import org.apache.commons.io.input.ClassLoaderObjectInputStream;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
|
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
|
||||||
import org.nd4j.parameterserver.distributed.enums.NodeRole;
|
import org.nd4j.parameterserver.distributed.enums.NodeRole;
|
||||||
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
|
|
||||||
import org.nd4j.parameterserver.distributed.logic.Storage;
|
import org.nd4j.parameterserver.distributed.logic.Storage;
|
||||||
|
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
|
||||||
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
|
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
|
||||||
import org.nd4j.parameterserver.distributed.transport.Transport;
|
import org.nd4j.parameterserver.distributed.transport.Transport;
|
||||||
|
|
||||||
|
@ -51,17 +52,16 @@ public interface VoidMessage extends Serializable {
|
||||||
|
|
||||||
UnsafeBuffer asUnsafeBuffer();
|
UnsafeBuffer asUnsafeBuffer();
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
static <T extends VoidMessage> T fromBytes(byte[] array) {
|
static <T extends VoidMessage> T fromBytes(byte[] array) {
|
||||||
try {
|
ClassLoader classloader = ND4JClassLoading.getNd4jClassloader();
|
||||||
ObjectInputStream in = new ClassLoaderObjectInputStream(Thread.currentThread().getContextClassLoader(),
|
|
||||||
new ByteArrayInputStream(array));
|
|
||||||
|
|
||||||
T result = (T) in.readObject();
|
try (ByteArrayInputStream bis = new ByteArrayInputStream(array);
|
||||||
return result;
|
ObjectInputStream ois = new ClassLoaderObjectInputStream(classloader, bis)) {
|
||||||
} catch (Exception e) {
|
return (T) ois.readObject();
|
||||||
throw new RuntimeException(e);
|
} catch (Exception objectReadException) {
|
||||||
|
throw new RuntimeException(objectReadException);
|
||||||
}
|
}
|
||||||
//return SerializationUtils.deserialize(array);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.parameterserver.distributed.training;
|
package org.nd4j.parameterserver.distributed.training;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
|
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
|
||||||
import org.nd4j.parameterserver.distributed.logic.Storage;
|
import org.nd4j.parameterserver.distributed.logic.Storage;
|
||||||
|
@ -52,13 +53,14 @@ public class TrainerProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void loadProviders(){
|
protected void loadProviders(){
|
||||||
ServiceLoader<TrainingDriver> serviceLoader = ServiceLoader.load(TrainingDriver.class);
|
ServiceLoader<TrainingDriver> serviceLoader = ND4JClassLoading.loadService(TrainingDriver.class);
|
||||||
for(TrainingDriver d : serviceLoader){
|
for (TrainingDriver trainingDriver : serviceLoader){
|
||||||
trainers.put(d.targetMessageClass(), d);
|
trainers.put(trainingDriver.targetMessageClass(), trainingDriver);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (trainers.size() < 1)
|
if (trainers.isEmpty()) {
|
||||||
throw new ND4JIllegalStateException("No TrainingDrivers were found via ServiceLoader mechanism");
|
throw new ND4JIllegalStateException("No TrainingDrivers were found via ServiceLoader mechanism");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport,
|
public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport,
|
||||||
|
@ -73,8 +75,6 @@ public class TrainerProvider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
protected <T extends TrainingMessage> TrainingDriver<T> getTrainer(T message) {
|
protected <T extends TrainingMessage> TrainingDriver<T> getTrainer(T message) {
|
||||||
TrainingDriver<?> driver = trainers.get(message.getClass().getSimpleName());
|
TrainingDriver<?> driver = trainers.get(message.getClass().getSimpleName());
|
||||||
|
|
|
@ -21,6 +21,8 @@ import com.beust.jcommander.Parameter;
|
||||||
import com.beust.jcommander.ParameterException;
|
import com.beust.jcommander.ParameterException;
|
||||||
import com.beust.jcommander.Parameters;
|
import com.beust.jcommander.Parameters;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
|
import org.nd4j.common.io.ReflectionUtils;
|
||||||
import org.nd4j.shade.guava.primitives.Ints;
|
import org.nd4j.shade.guava.primitives.Ints;
|
||||||
|
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
|
@ -290,12 +292,10 @@ public class ParameterServerSubscriber implements AutoCloseable {
|
||||||
case TIME_DELAYED:
|
case TIME_DELAYED:
|
||||||
break;
|
break;
|
||||||
case CUSTOM:
|
case CUSTOM:
|
||||||
try {
|
String parameterServerUpdateType = System.getProperty(CUSTOM_UPDATE_TYPE);
|
||||||
updater = (ParameterServerUpdater) Class.forName(System.getProperty(CUSTOM_UPDATE_TYPE))
|
Class<ParameterServerUpdater> updaterClass = ND4JClassLoading
|
||||||
.newInstance();
|
.loadClassByName(parameterServerUpdateType);
|
||||||
} catch (Exception e) {
|
updater = ReflectionUtils.newInstance(updaterClass);
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new IllegalStateException("Illegal opType of updater");
|
throw new IllegalStateException("Illegal opType of updater");
|
||||||
|
|
Loading…
Reference in New Issue