FEATURE#8712: add possibility to specify classloader for DL4J (#9115)

Signed-off-by: hosuaby <alexei.klenin@gmail.com>
master
Alexei KLENIN 2020-10-29 06:38:42 -07:00 committed by GitHub
parent 2e000c84ac
commit a722bd5a5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 744 additions and 461 deletions

View File

@ -33,6 +33,12 @@
<artifactId>nd4j-common</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>
@ -43,5 +49,4 @@
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

View File

@ -0,0 +1,133 @@
/*******************************************************************************
* Copyright (c) Eclipse Deeplearning4j Contributors 2020
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.common.config;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.config.ND4JClassLoading;
import java.lang.reflect.InvocationTargetException;
import java.util.Objects;
import java.util.ServiceLoader;
/**
* Global context for class-loading in DL4J.
* <p>Use {@code DL4JClassLoading} to define classloader for Deeplearning4j only! To define classloader used by
* {@code ND4J} use class {@link org.nd4j.common.config.ND4JClassLoading}.
*
* <p>Usage:
* <pre>{@code
* public class Application {
* static {
* DL4JClassLoading.setDl4jClassloaderFromClass(Application.class);
* }
*
* public static void main(String[] args) {
* }
* }
* }</code>
*
* @see org.nd4j.common.config.ND4JClassLoading
*
* @author Alexei KLENIN
*/
@Slf4j
public class DL4JClassLoading {
private static ClassLoader dl4jClassloader = ND4JClassLoading.getNd4jClassloader();
private DL4JClassLoading() {
}
public static ClassLoader getDl4jClassloader() {
return DL4JClassLoading.dl4jClassloader;
}
public static void setDl4jClassloaderFromClass(Class<?> clazz) {
setDl4jClassloader(clazz.getClassLoader());
}
public static void setDl4jClassloader(ClassLoader dl4jClassloader) {
DL4JClassLoading.dl4jClassloader = dl4jClassloader;
log.debug("Global class-loader for DL4J was changed.");
}
public static boolean classPresentOnClasspath(String className) {
return classPresentOnClasspath(className, dl4jClassloader);
}
public static boolean classPresentOnClasspath(String className, ClassLoader classLoader) {
return loadClassByName(className, false, classLoader) != null;
}
public static <T> Class<T> loadClassByName(String className) {
return loadClassByName(className, true, dl4jClassloader);
}
@SuppressWarnings("unchecked")
public static <T> Class<T> loadClassByName(String className, boolean initialize, ClassLoader classLoader) {
try {
return (Class<T>) Class.forName(className, initialize, classLoader);
} catch (ClassNotFoundException classNotFoundException) {
log.error(String.format("Cannot find class [%s] of provided class-loader.", className));
return null;
}
}
public static <T> T createNewInstance(String className, Object... args) {
return createNewInstance(className, Object.class, args);
}
public static <T> T createNewInstance(String className, Class<? super T> superclass) {
return createNewInstance(className, superclass, new Class<?>[]{}, new Object[]{});
}
public static <T> T createNewInstance(String className, Class<? super T> superclass, Object... args) {
Class<?>[] parameterTypes = new Class<?>[args.length];
for (int i = 0; i < args.length; i++) {
Object arg = args[i];
Objects.requireNonNull(arg);
parameterTypes[i] = arg.getClass();
}
return createNewInstance(className, superclass, parameterTypes, args);
}
public static <T> T createNewInstance(
String className,
Class<? super T> superclass,
Class<?>[] parameterTypes,
Object... args) {
try {
return (T) DL4JClassLoading
.loadClassByName(className)
.asSubclass(superclass)
.getDeclaredConstructor(parameterTypes)
.newInstance(args);
} catch (InstantiationException | IllegalAccessException | InvocationTargetException
| NoSuchMethodException instantiationException) {
log.error(String.format("Cannot create instance of class '%s'.", className), instantiationException);
throw new RuntimeException(instantiationException);
}
}
public static <S> ServiceLoader<S> loadService(Class<S> serviceClass) {
return loadService(serviceClass, dl4jClassloader);
}
public static <S> ServiceLoader<S> loadService(Class<S> serviceClass, ClassLoader classLoader) {
return ServiceLoader.load(serviceClass, classLoader);
}
}

View File

@ -0,0 +1,67 @@
package org.deeplearning4j.common.config;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import org.deeplearning4j.common.config.dummies.TestAbstract;
import org.junit.Test;
public class DL4JClassLoadingTest {
private static final String PACKAGE_PREFIX = "org.deeplearning4j.common.config.dummies.";
@Test
public void testCreateNewInstance_constructorWithoutArguments() {
/* Given */
String className = PACKAGE_PREFIX + "TestDummy";
/* When */
Object instance = DL4JClassLoading.createNewInstance(className);
/* Then */
assertNotNull(instance);
assertEquals(className, instance.getClass().getName());
}
@Test
public void testCreateNewInstance_constructorWithArgument_implicitArgumentTypes() {
/* Given */
String className = PACKAGE_PREFIX + "TestColor";
/* When */
TestAbstract instance = DL4JClassLoading.createNewInstance(className, TestAbstract.class, "white");
/* Then */
assertNotNull(instance);
assertEquals(className, instance.getClass().getName());
}
@Test
public void testCreateNewInstance_constructorWithArgument_explicitArgumentTypes() {
/* Given */
String colorClassName = PACKAGE_PREFIX + "TestColor";
String rectangleClassName = PACKAGE_PREFIX + "TestRectangle";
/* When */
TestAbstract color = DL4JClassLoading.createNewInstance(
colorClassName,
Object.class,
new Class<?>[]{ int.class, int.class, int.class },
45, 175, 200);
TestAbstract rectangle = DL4JClassLoading.createNewInstance(
rectangleClassName,
Object.class,
new Class<?>[]{ int.class, int.class, TestAbstract.class },
10, 15, color);
/* Then */
assertNotNull(color);
assertEquals(colorClassName, color.getClass().getName());
assertNotNull(rectangle);
assertEquals(rectangleClassName, rectangle.getClass().getName());
}
}

View File

@ -0,0 +1,4 @@
package org.deeplearning4j.common.config.dummies;
public abstract class TestAbstract {
}

View File

@ -0,0 +1,9 @@
package org.deeplearning4j.common.config.dummies;
public class TestColor extends TestAbstract {
public TestColor(String color) {
}
public TestColor(int r, int g, int b) {
}
}

View File

@ -0,0 +1,6 @@
package org.deeplearning4j.common.config.dummies;
public class TestDummy {
public TestDummy() {
}
}

View File

@ -0,0 +1,6 @@
package org.deeplearning4j.common.config.dummies;
public class TestRectangle extends TestAbstract {
public TestRectangle(int width, int height, TestAbstract color) {
}
}

View File

@ -102,7 +102,6 @@ public class AsyncIterator<T extends Object> implements Iterator<T> {
}
}
private class ReaderThread<T> extends Thread implements Runnable {
private BlockingQueue<T> buffer;
private Iterator<T> iterator;

View File

@ -19,6 +19,7 @@ package org.deeplearning4j;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
@ -66,23 +67,23 @@ public class LayerHelperValidationUtil {
public static void disableCppHelpers(){
try {
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
Method m = c.getMethod("getInstance");
Object instance = m.invoke(null);
Method m2 = c.getMethod("allowHelpers", boolean.class);
m2.invoke(instance, false);
Class<?> clazz = DL4JClassLoading.loadClassByName("org.nd4j.nativeblas.Nd4jCpu$Environment");
Method getInstance = clazz.getMethod("getInstance");
Object instance = getInstance.invoke(null);
Method allowHelpers = clazz.getMethod("allowHelpers", boolean.class);
allowHelpers.invoke(instance, false);
} catch (Throwable t){
throw new RuntimeException(t);
}
}
public static void enableCppHelpers(){
try{
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
Method m = c.getMethod("getInstance");
Object instance = m.invoke(null);
Method m2 = c.getMethod("allowHelpers", boolean.class);
m2.invoke(instance, true);
try {
Class<?> clazz = DL4JClassLoading.loadClassByName("org.nd4j.nativeblas.Nd4jCpu$Environment");
Method getInstance = clazz.getMethod("getInstance");
Object instance = getInstance.invoke(null);
Method allowHelpers = clazz.getMethod("allowHelpers", boolean.class);
allowHelpers.invoke(instance, true);
} catch (Throwable t){
throw new RuntimeException(t);
}

View File

@ -16,28 +16,96 @@
package org.deeplearning4j.nn.dtypes;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.preprocessor.*;
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.dropout.AlphaDropout;
import org.deeplearning4j.nn.conf.dropout.GaussianDropout;
import org.deeplearning4j.nn.conf.dropout.GaussianNoise;
import org.deeplearning4j.nn.conf.dropout.SpatialDropout;
import org.deeplearning4j.nn.conf.graph.*;
import org.deeplearning4j.nn.conf.graph.AttentionVertex;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.graph.FrozenVertex;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.L2NormalizeVertex;
import org.deeplearning4j.nn.conf.graph.L2Vertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.PoolHelperVertex;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.graph.ReshapeVertex;
import org.deeplearning4j.nn.conf.graph.ScaleVertex;
import org.deeplearning4j.nn.conf.graph.ShiftVertex;
import org.deeplearning4j.nn.conf.graph.StackVertex;
import org.deeplearning4j.nn.conf.graph.SubsetVertex;
import org.deeplearning4j.nn.conf.graph.UnstackVertex;
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.AutoEncoder;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.CapsuleLayer;
import org.deeplearning4j.nn.conf.layers.CapsuleStrengthLayer;
import org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer;
import org.deeplearning4j.nn.conf.layers.Cnn3DLossLayer;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.Convolution1D;
import org.deeplearning4j.nn.conf.layers.Convolution2D;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
import org.deeplearning4j.nn.conf.layers.Deconvolution3D;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.LearnedSelfAttentionLayer;
import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization;
import org.deeplearning4j.nn.conf.layers.LocallyConnected1D;
import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PReLULayer;
import org.deeplearning4j.nn.conf.layers.Pooling1D;
import org.deeplearning4j.nn.conf.layers.Pooling2D;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
import org.deeplearning4j.nn.conf.layers.RecurrentAttentionLayer;
import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.SelfAttentionLayer;
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
import org.deeplearning4j.nn.conf.layers.SpaceToBatchLayer;
import org.deeplearning4j.nn.conf.layers.SpaceToDepthLayer;
import org.deeplearning4j.nn.conf.layers.Subsampling1DLayer;
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.Upsampling1D;
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.conf.layers.Upsampling3D;
import org.deeplearning4j.nn.conf.layers.ZeroPadding1DLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D;
@ -49,16 +117,24 @@ import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.util.MaskLayer;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.util.IdentityLayer;
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
@ -77,12 +153,17 @@ import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import java.io.IOException;
import java.lang.reflect.Modifier;
import java.util.*;
import static org.junit.Assert.*;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@Slf4j
public class DTypeTests extends BaseDL4JTest {
@ -120,20 +201,17 @@ public class DTypeTests extends BaseDL4JTest {
Set<Class<?>> preprocClasses = new HashSet<>();
Set<Class<?>> vertexClasses = new HashSet<>();
for (ClassPath.ClassInfo ci : info) {
Class<?> clazz;
try {
clazz = Class.forName(ci.getName());
} catch (ClassNotFoundException e) {
//Should never happen as this was found on the classpath
throw new RuntimeException(e);
}
Class<?> clazz = DL4JClassLoading.loadClassByName(ci.getName());
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) { //Skip TFOpLayer here - dtype depends on imported model dtype
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) {
// Skip TFOpLayer here - dtype depends on imported model dtype
continue;
}
if (clazz.getName().toLowerCase().contains("custom") || clazz.getName().contains("samediff.testlayers")
|| clazz.getName().toLowerCase().contains("test") || ignoreClasses.contains(clazz)) {
if (clazz.getName().toLowerCase().contains("custom")
|| clazz.getName().contains("samediff.testlayers")
|| clazz.getName().toLowerCase().contains("test")
|| ignoreClasses.contains(clazz)) {
continue;
}

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.recurrent;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -167,7 +168,7 @@ public class GravesLSTMTest extends BaseDL4JTest {
actHelper.setAccessible(true);
//Call activateHelper with both forBackprop == true, and forBackprop == false and compare
Class<?> innerClass = Class.forName("org.deeplearning4j.nn.layers.recurrent.FwdPassReturn");
Class<?> innerClass = DL4JClassLoading.loadClassByName("org.deeplearning4j.nn.layers.recurrent.FwdPassReturn");
Object oFalse = actHelper.invoke(lstm, false, null, null, false, LayerWorkspaceMgr.noWorkspacesImmutable()); //GravesLSTM.FwdPassReturn object; want fwdPassOutput INDArray
Object oTrue = actHelper.invoke(lstm, false, null, null, true, LayerWorkspaceMgr.noWorkspacesImmutable()); //want fwdPassOutputAsArrays object

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
@ -110,7 +111,7 @@ public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
org.nd4j.shade.protobuf.ByteString serialized = graphDef.toByteString();
byte[] graphBytes = serialized.toByteArray();
ServiceLoader<TFGraphRunnerService> sl = ServiceLoader.load(TFGraphRunnerService.class);
ServiceLoader<TFGraphRunnerService> sl = DL4JClassLoading.loadService(TFGraphRunnerService.class);
Iterator<TFGraphRunnerService> iter = sl.iterator();
if (!iter.hasNext()){
throw new RuntimeException("The model contains a Tensorflow Op, which requires the nd4j-tensorflow dependency to execute.");

View File

@ -61,14 +61,12 @@ public abstract class Model {
/**
*
*
* @param path
* @return
* @return
* @throws Exception
*/
public static Model load(Class<? extends Model> c, InputStream is) throws Exception {
Model model = c.newInstance();
return model.loadModel(is);
public static Model load(Class<? extends Model> modelClass, InputStream inputStream) throws Exception {
return modelClass
.getDeclaredConstructor()
.newInstance()
.loadModel(inputStream);
}
/**

View File

@ -5,6 +5,7 @@ import org.ansj.dic.impl.Jar2Stream;
import org.ansj.dic.impl.Jdbc2Stream;
import org.ansj.dic.impl.Url2Stream;
import org.ansj.exception.LibraryException;
import org.deeplearning4j.common.config.DL4JClassLoading;
import java.io.InputStream;
@ -25,7 +26,8 @@ public abstract class PathToStream {
} else if (path.startsWith("jar://")) {
return new Jar2Stream().toStream(path);
} else if (path.startsWith("class://")) {
((PathToStream) Class.forName(path.substring(8).split("\\|")[0]).newInstance()).toStream(path);
// Probably unused
return loadClass(path);
} else if (path.startsWith("http://") || path.startsWith("https://")) {
return new Url2Stream().toStream(path);
} else {
@ -34,9 +36,17 @@ public abstract class PathToStream {
} catch (Exception e) {
throw new LibraryException(e);
}
throw new LibraryException("not find method type in path " + path);
}
public abstract InputStream toStream(String path);
static InputStream loadClass(String path) {
String className = path
.substring("class://".length())
.split("\\|")[0];
return DL4JClassLoading
.createNewInstance(className, PathToStream.class)
.toStream(path);
}
}

View File

@ -3,6 +3,7 @@ package org.ansj.dic.impl;
import org.ansj.dic.DicReader;
import org.ansj.dic.PathToStream;
import org.ansj.exception.LibraryException;
import org.deeplearning4j.common.config.DL4JClassLoading;
import java.io.InputStream;
@ -17,12 +18,16 @@ public class Jar2Stream extends PathToStream {
@Override
public InputStream toStream(String path) {
if (path.contains("|")) {
String[] split = path.split("\\|");
try {
return Class.forName(split[0].substring(6)).getResourceAsStream(split[1].trim());
} catch (ClassNotFoundException e) {
throw new LibraryException(e);
String[] tokens = path.split("\\|");
String className = tokens[0].substring(6);
String resourceName = tokens[1].trim();
Class<Object> resourceClass = DL4JClassLoading.loadClassByName(className);
if (resourceClass == null) {
throw new LibraryException(String.format("Class '%s' was not found.", className));
}
return resourceClass.getResourceAsStream(resourceName);
} else {
return DicReader.getInputStream(path.substring(6));
}

View File

@ -2,6 +2,7 @@ package org.ansj.dic.impl;
import org.ansj.dic.PathToStream;
import org.ansj.exception.LibraryException;
import org.deeplearning4j.common.config.DL4JClassLoading;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
@ -22,20 +23,26 @@ public class Jdbc2Stream extends PathToStream {
private static final byte[] LINE = "\n".getBytes();
private static final String[] JDBC_DRIVERS = {
"org.h2.Driver",
"com.ibm.db2.jcc.DB2Driver",
"org.hsqldb.jdbcDriver",
"org.gjt.mm.mysql.Driver",
"oracle.jdbc.OracleDriver",
"org.postgresql.Driver",
"net.sourceforge.jtds.jdbc.Driver",
"com.microsoft.sqlserver.jdbc.SQLServerDriver",
"org.sqlite.JDBC",
"com.mysql.jdbc.Driver"
};
static {
String[] drivers = {"org.h2.Driver", "com.ibm.db2.jcc.DB2Driver", "org.hsqldb.jdbcDriver",
"org.gjt.mm.mysql.Driver", "oracle.jdbc.OracleDriver", "org.postgresql.Driver",
"net.sourceforge.jtds.jdbc.Driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver",
"org.sqlite.JDBC", "com.mysql.jdbc.Driver"};
for (String driverClassName : drivers) {
try {
try {
Thread.currentThread().getContextClassLoader().loadClass(driverClassName);
} catch (ClassNotFoundException e) {
Class.forName(driverClassName);
}
} catch (Throwable e) {
loadJdbcDrivers();
}
static void loadJdbcDrivers() {
for (String driverClassName : JDBC_DRIVERS) {
DL4JClassLoading.loadClassByName(driverClassName);
}
}

View File

@ -17,44 +17,20 @@
package org.deeplearning4j.models.embeddings.loader;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.GZIPInputStream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.compress.compressors.gzip.GzipUtils;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.common.util.DL4JFileUtils;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
@ -94,12 +70,37 @@ import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.storage.CompressedRamStorage;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.GZIPInputStream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
/**
* This is utility class, providing various methods for WordVectors serialization
@ -2676,26 +2677,23 @@ public class WordVectorSerializer {
}
protected static TokenizerFactory getTokenizerFactory(VectorsConfiguration configuration) {
if (configuration == null)
if (configuration == null) {
return null;
}
if (configuration.getTokenizerFactory() != null && !configuration.getTokenizerFactory().isEmpty()) {
try {
TokenizerFactory factory =
(TokenizerFactory) Class.forName(configuration.getTokenizerFactory()).newInstance();
String tokenizerFactoryClassName = configuration.getTokenizerFactory();
if (StringUtils.isNotEmpty(tokenizerFactoryClassName)) {
TokenizerFactory factory = DL4JClassLoading.createNewInstance(tokenizerFactoryClassName);
if (configuration.getTokenPreProcessor() != null && !configuration.getTokenPreProcessor().isEmpty()) {
TokenPreProcess preProcessor =
(TokenPreProcess) Class.forName(configuration.getTokenPreProcessor()).newInstance();
String tokenPreProcessorClassName = configuration.getTokenPreProcessor();
if (StringUtils.isNotEmpty(tokenPreProcessorClassName)) {
TokenPreProcess preProcessor = DL4JClassLoading.createNewInstance(tokenizerFactoryClassName);
factory.setTokenPreProcessor(preProcessor);
}
return factory;
}
} catch (Exception e) {
log.error("Can't instantiate saved TokenizerFactory: {}", configuration.getTokenizerFactory());
}
}
return null;
}

View File

@ -16,6 +16,8 @@
package org.deeplearning4j.models.sequencevectors;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.Getter;
@ -494,15 +496,17 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
this.useHierarchicSoftmax = configuration.isUseHierarchicSoftmax();
this.preciseMode = configuration.isPreciseMode();
if (configuration.getModelUtils() != null && !configuration.getModelUtils().isEmpty()) {
String modelUtilsClassName = configuration.getModelUtils();
if (StringUtils.isNotEmpty(modelUtilsClassName)) {
try {
this.modelUtils = (ModelUtils<T>) Class.forName(configuration.getModelUtils()).newInstance();
} catch (Exception e) {
log.error("Got {} trying to instantiate ModelUtils, falling back to BasicModelUtils instead");
this.modelUtils = DL4JClassLoading.createNewInstance(modelUtilsClassName);
} catch (Exception instantiationException) {
log.error(
"Got '{}' trying to instantiate ModelUtils, falling back to BasicModelUtils instead",
instantiationException.getMessage(),
instantiationException);
this.modelUtils = new BasicModelUtils<>();
}
}
if (configuration.getElementsLearningAlgorithm() != null
@ -551,12 +555,7 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
* @return
*/
public Builder<T> sequenceLearningAlgorithm(@NonNull String algoName) {
try {
Class clazz = Class.forName(algoName);
sequenceLearningAlgorithm = (SequenceLearningAlgorithm<T>) clazz.newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(algoName);
return this;
}
@ -578,13 +577,9 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
* @return
*/
public Builder<T> elementsLearningAlgorithm(@NonNull String algoName) {
try {
Class clazz = Class.forName(algoName);
elementsLearningAlgorithm = (ElementsLearningAlgorithm<T>) clazz.newInstance();
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(algoName);
this.configuration.setElementsLearningAlgorithm(elementsLearningAlgorithm.getClass().getCanonicalName());
} catch (Exception e) {
throw new RuntimeException(e);
}
return this;
}
@ -943,31 +938,23 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
.lr(learningRate).seed(seed).build();
}
if (this.configuration.getElementsLearningAlgorithm() != null) {
try {
elementsLearningAlgorithm = (ElementsLearningAlgorithm<T>) Class
.forName(this.configuration.getElementsLearningAlgorithm()).newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
String elementsLearningAlgorithm = this.configuration.getElementsLearningAlgorithm();
if (StringUtils.isNotEmpty(elementsLearningAlgorithm)) {
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
}
if (this.configuration.getSequenceLearningAlgorithm() != null) {
try {
sequenceLearningAlgorithm = (SequenceLearningAlgorithm<T>) Class
.forName(this.configuration.getSequenceLearningAlgorithm()).newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
String sequenceLearningAlgorithm = this.configuration.getSequenceLearningAlgorithm();
if (StringUtils.isNotEmpty(sequenceLearningAlgorithm)) {
this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(sequenceLearningAlgorithm);
}
if (trainElementsVectors && elementsLearningAlgorithm == null) {
if (trainElementsVectors && this.elementsLearningAlgorithm == null) {
// create default implementation of ElementsLearningAlgorithm
elementsLearningAlgorithm = new SkipGram<>();
this.elementsLearningAlgorithm = new SkipGram<>();
}
if (trainSequenceVectors && sequenceLearningAlgorithm == null) {
sequenceLearningAlgorithm = new DBOW<>();
if (trainSequenceVectors && this.sequenceLearningAlgorithm == null) {
this.sequenceLearningAlgorithm = new DBOW<>();
}
this.modelUtils.init(lookupTable);

View File

@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.base.Preconditions;
@ -132,25 +133,20 @@ public class Dropout implements IDropout {
*/
protected void initializeHelper(DataType dataType){
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
if("CUDA".equalsIgnoreCase(backend)) {
try {
helper = Class.forName("org.deeplearning4j.cuda.dropout.CudnnDropoutHelper")
.asSubclass(DropoutHelper.class).getConstructor(DataType.class).newInstance(dataType);
helper = DL4JClassLoading.createNewInstance(
"org.deeplearning4j.cuda.dropout.CudnnDropoutHelper",
DropoutHelper.class,
dataType);
log.debug("CudnnDropoutHelper successfully initialized");
if (!helper.checkSupported()) {
helper = null;
}
} catch (Throwable t) {
if (!(t instanceof ClassNotFoundException)) {
log.warn("Could not initialize CudnnDropoutHelper", t);
}
//Unlike other layers, don't warn here about CuDNN not found - if the user has any other layers that can
// benefit from them cudnn, they will get a warning from those
}
}
initializedHelper = true;
}
initializedHelper = true;
}
@Override
public INDArray applyDropout(INDArray inputActivations, INDArray output, int iteration, int epoch, LayerWorkspaceMgr workspaceMgr) {

View File

@ -43,6 +43,7 @@ import org.nd4j.shade.jackson.databind.deser.std.StdDeserializer;
import org.nd4j.shade.jackson.databind.node.ObjectNode;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
@ -268,14 +269,17 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
//Changed after 0.7.1 from "activationFunction" : "softmax" to "activationFn" : <object>
protected void handleActivationBackwardCompatibility(BaseLayer baseLayer, ObjectNode on){
if(baseLayer.getActivationFn() == null && on.has("activationFunction")){
String afn = on.get("activationFunction").asText();
IActivation a = null;
try {
a = getMap().get(afn.toLowerCase()).newInstance();
} catch (InstantiationException | IllegalAccessException e){
//Ignore
a = getMap()
.get(afn.toLowerCase())
.getDeclaredConstructor()
.newInstance();
} catch (InstantiationException | IllegalAccessException | NoSuchMethodException
| InvocationTargetException instantiationException){
log.error(instantiationException.getMessage());
}
baseLayer.setActivationFn(a);
}

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.convolution;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat;
@ -73,26 +74,19 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
void initializeHelper() {
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
if("CUDA".equalsIgnoreCase(backend)) {
try {
helper = Class.forName("org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper")
.asSubclass(ConvolutionHelper.class).getConstructor(DataType.class).newInstance(dataType);
helper = DL4JClassLoading.createNewInstance(
"org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper",
ConvolutionHelper.class,
dataType);
log.debug("CudnnConvolutionHelper successfully initialized");
if (!helper.checkSupported()) {
helper = null;
}
} catch (Throwable t) {
if (!(t instanceof ClassNotFoundException)) {
log.warn("Could not initialize CudnnConvolutionHelper", t);
} else {
OneTimeLogger.info(log, "cuDNN not found: "
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
}
}
} else if("CPU".equalsIgnoreCase(backend)){
helper = new MKLDNNConvHelper(dataType);
log.trace("Created MKLDNNConvHelper, layer {}", layerConf().getLayerName());
}
if (helper != null && !helper.checkSupported()) {
log.debug("Removed helper {} as not supported", helper.getClass());
helper = null;

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.layers.convolution.subsampling;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat;
@ -30,17 +31,15 @@ import org.deeplearning4j.nn.layers.mkldnn.MKLDNNSubsamplingHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.OneTimeLogger;
import java.util.Arrays;
/**
* Subsampling layer.
*
@ -64,27 +63,21 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
void initializeHelper() {
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
if("CUDA".equalsIgnoreCase(backend)) {
try {
helper = Class.forName("org.deeplearning4j.cuda.convolution.subsampling.CudnnSubsamplingHelper")
.asSubclass(SubsamplingHelper.class).getConstructor(DataType.class).newInstance(dataType);
helper = DL4JClassLoading.createNewInstance(
"org.deeplearning4j.cuda.convolution.subsampling.CudnnSubsamplingHelper",
SubsamplingHelper.class,
dataType);
log.debug("CudnnSubsamplingHelper successfully initialized");
if (!helper.checkSupported()) {
helper = null;
}
} catch (Throwable t) {
if (!(t instanceof ClassNotFoundException)) {
log.warn("Could not initialize CudnnSubsamplingHelper", t);
} else {
OneTimeLogger.info(log, "cuDNN not found: "
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
}
}
} else if("CPU".equalsIgnoreCase(backend) ){
helper = new MKLDNNSubsamplingHelper(dataType);
log.trace("Created MKL-DNN helper: MKLDNNSubsamplingHelper, layer {}", layerConf().getLayerName());
}
if (helper != null && !helper.checkSupported()) {
log.debug("Removed helper {} as not supported", helper.getClass());
helper = null;

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.mkldnn;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.nd4j.linalg.factory.Nd4j;
import java.lang.reflect.Method;
@ -47,12 +48,11 @@ public class BaseMKLDNNHelper {
}
try{
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
Method m = c.getMethod("getInstance");
Object instance = m.invoke(null);
Method m2 = c.getMethod("isUseMKLDNN");
boolean b = (Boolean)m2.invoke(instance);
return b;
Class<?> clazz = DL4JClassLoading.loadClassByName("org.nd4j.nativeblas.Nd4jCpu$Environment");
Method getInstance = clazz.getMethod("getInstance");
Object instance = getInstance.invoke(null);
Method isUseMKLDNNMethod = clazz.getMethod("isUseMKLDNN");
return (boolean) isUseMKLDNNMethod.invoke(instance);
} catch (Throwable t ){
FAILED_CHECK = new AtomicBoolean(true);
return false;

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.normalization;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
@ -75,24 +76,18 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
void initializeHelper() {
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
if("CUDA".equalsIgnoreCase(backend)) {
try {
helper = Class.forName("org.deeplearning4j.cuda.normalization.CudnnBatchNormalizationHelper")
.asSubclass(BatchNormalizationHelper.class).getConstructor(DataType.class).newInstance(dataType);
if ("CUDA".equalsIgnoreCase(backend)) {
helper = DL4JClassLoading.createNewInstance(
"org.deeplearning4j.cuda.normalization.CudnnBatchNormalizationHelper",
BatchNormalizationHelper.class,
dataType);
log.debug("CudnnBatchNormalizationHelper successfully initialized");
} catch (Throwable t) {
if (!(t instanceof ClassNotFoundException)) {
log.warn("Could not initialize CudnnBatchNormalizationHelper", t);
} else {
OneTimeLogger.info(log, "cuDNN not found: "
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
}
}
} else if("CPU".equalsIgnoreCase(backend)){
} else if ("CPU".equalsIgnoreCase(backend)){
helper = new MKLDNNBatchNormHelper(dataType);
log.trace("Created MKLDNNBatchNormHelper, layer {}", layerConf().getLayerName());
}
if (helper != null && !helper.checkSupported(layerConf().getEps(), layerConf().isLockGammaBeta())) {
log.debug("Removed helper {} as not supported with epsilon {}, lockGammaBeta={}", helper.getClass(), layerConf().getEps(), layerConf().isLockGammaBeta());
helper = null;

View File

@ -16,7 +16,9 @@
package org.deeplearning4j.nn.layers.normalization;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -36,9 +38,6 @@ import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.common.primitives.Triple;
import org.nd4j.common.util.OneTimeLogger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
@ -65,10 +64,9 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
* <p>
* Created by nyghtowl on 10/29/15.
*/
@Slf4j
public class LocalResponseNormalization
extends AbstractLayer<org.deeplearning4j.nn.conf.layers.LocalResponseNormalization> {
protected static final Logger log =
LoggerFactory.getLogger(org.deeplearning4j.nn.conf.layers.LocalResponseNormalization.class);
protected LocalResponseNormalizationHelper helper = null;
protected int helperCountFail = 0;
@ -86,19 +84,11 @@ public class LocalResponseNormalization
void initializeHelper() {
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
if("CUDA".equalsIgnoreCase(backend)) {
try {
helper = Class.forName("org.deeplearning4j.cuda.normalization.CudnnLocalResponseNormalizationHelper")
.asSubclass(LocalResponseNormalizationHelper.class).getConstructor(DataType.class).newInstance(dataType);
helper = DL4JClassLoading.createNewInstance(
"org.deeplearning4j.cuda.normalization.CudnnLocalResponseNormalizationHelper",
LocalResponseNormalizationHelper.class,
dataType);
log.debug("CudnnLocalResponseNormalizationHelper successfully initialized");
} catch (Throwable t) {
if (!(t instanceof ClassNotFoundException)) {
log.warn("Could not initialize CudnnLocalResponseNormalizationHelper", t);
} else {
OneTimeLogger.info(log, "cuDNN not found: "
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
}
}
}
//2019-03-09 AB - MKL-DNN helper disabled: https://github.com/deeplearning4j/deeplearning4j/issues/7272
// else if("CPU".equalsIgnoreCase(backend)){

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.layers.recurrent;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -58,22 +59,14 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
void initializeHelper() {
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
if("CUDA".equalsIgnoreCase(backend)) {
try {
helper = Class.forName("org.deeplearning4j.cuda.recurrent.CudnnLSTMHelper")
.asSubclass(LSTMHelper.class).getConstructor(DataType.class).newInstance(dataType);
helper = DL4JClassLoading.createNewInstance(
"org.deeplearning4j.cuda.recurrent.CudnnLSTMHelper",
LSTMHelper.class,
dataType);
log.debug("CudnnLSTMHelper successfully initialized");
if (!helper.checkSupported(layerConf().getGateActivationFn(), layerConf().getActivationFn(), false)) {
helper = null;
}
} catch (Throwable t) {
if (!(t instanceof ClassNotFoundException)) {
log.warn("Could not initialize CudnnLSTMHelper", t);
} else {
OneTimeLogger.info(log, "cuDNN not found: "
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
}
}
}
/*
//Disabled pending: https://github.com/eclipse/deeplearning4j/issues/8331

View File

@ -21,6 +21,7 @@ import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter;
import org.deeplearning4j.nn.api.Model;
@ -126,48 +127,44 @@ public class ParallelWrapperMain {
.build();
if (dataSetIteratorFactoryClazz != null) {
DataSetIteratorProviderFactory dataSetIteratorProviderFactory =
(DataSetIteratorProviderFactory) Class.forName(dataSetIteratorFactoryClazz).newInstance();
DataSetIteratorProviderFactory dataSetIteratorProviderFactory = DL4JClassLoading
.createNewInstance(dataSetIteratorFactoryClazz);
DataSetIterator dataSetIterator = dataSetIteratorProviderFactory.create();
if (uiUrl != null) {
// it's important that the UI can report results from parallel training
// there's potential for StatsListener to fail if certain properties aren't set in the model
StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
TrainingListener l;
try {
l = (TrainingListener) Class.forName("org.deeplearning4j.ui.model.stats.StatsListener").getConstructor(StatsStorageRouter.class)
.newInstance(new Object[]{null});
} catch (ClassNotFoundException e){
throw new IllegalStateException("deeplearning4j-ui module must be on the classpath to use ParallelWrapperMain with the UI", e);
}
wrapper.setListeners(remoteUIRouter, l);
TrainingListener trainingListener = DL4JClassLoading.createNewInstance(
"org.deeplearning4j.ui.model.stats.StatsListener",
StatsStorageRouter.class,
new Class[] { StatsStorageRouter.class },
new Object[] { null });
wrapper.setListeners(remoteUIRouter, trainingListener);
}
wrapper.fit(dataSetIterator);
ModelSerializer.writeModel(model, new File(modelOutputPath), true);
} else if (multiDataSetIteratorFactoryClazz != null) {
MultiDataSetProviderFactory multiDataSetProviderFactory =
(MultiDataSetProviderFactory) Class.forName(multiDataSetIteratorFactoryClazz).newInstance();
MultiDataSetProviderFactory multiDataSetProviderFactory = DL4JClassLoading
.createNewInstance(multiDataSetIteratorFactoryClazz);
MultiDataSetIterator iterator = multiDataSetProviderFactory.create();
if (uiUrl != null) {
// it's important that the UI can report results from parallel training
// there's potential for StatsListener to fail if certain properties aren't set in the model
remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
TrainingListener l;
try {
l = (TrainingListener) Class.forName("org.deeplearning4j.ui.model.stats.StatsListener").getConstructor(StatsStorageRouter.class)
.newInstance(new Object[]{null});
} catch (ClassNotFoundException e){
throw new IllegalStateException("deeplearning4j-ui module must be on the classpath to use ParallelWrapperMain with the UI", e);
}
wrapper.setListeners(remoteUIRouter, l);
TrainingListener trainingListener = DL4JClassLoading
.createNewInstance(
"org.deeplearning4j.ui.model.stats.StatsListener",
TrainingListener.class,
new Class[]{ StatsStorageRouter.class },
new Object[]{ null });
wrapper.setListeners(remoteUIRouter, trainingListener);
}
wrapper.fit(iterator);
ModelSerializer.writeModel(model, new File(modelOutputPath), true);
} else {
throw new IllegalStateException("Please provide a datasetiteraator or multi datasetiterator class");
}

View File

@ -25,6 +25,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
@ -161,14 +162,9 @@ public class SparkSequenceVectors<T extends SequenceElement> extends SequenceVec
validateConfiguration();
if (ela == null) {
try {
ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm())
.newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
String className = configuration.getElementsLearningAlgorithm();
ela = DL4JClassLoading.createNewInstance(className);
}
}
if (workers > 1) {
log.info("Repartitioning corpus to {} parts...", workers);

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.spark.models.sequencevectors.functions;
import lombok.NonNull;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
@ -42,24 +43,17 @@ public abstract class BaseTokenizerFunction implements Serializable {
String tpClassName = this.configurationBroadcast.getValue().getTokenPreProcessor();
if (tfClassName != null && !tfClassName.isEmpty()) {
try {
tokenizerFactory = (TokenizerFactory) Class.forName(tfClassName).newInstance();
tokenizerFactory = DL4JClassLoading.createNewInstance(tfClassName);
if (tpClassName != null && !tpClassName.isEmpty()) {
try {
tokenPreprocessor = (TokenPreProcess) Class.forName(tpClassName).newInstance();
} catch (Exception e) {
throw new RuntimeException("Unable to instantiate TokenPreProcessor.", e);
}
tokenPreprocessor = DL4JClassLoading.createNewInstance(tpClassName);
}
if (tokenPreprocessor != null) {
tokenizerFactory.setTokenPreProcessor(tokenPreprocessor);
}
} catch (Exception e) {
throw new RuntimeException("Unable to instantiate TokenizerFactory.", e);
}
} else
} else {
throw new RuntimeException("TokenizerFactory wasn't defined.");
}
}
}

View File

@ -21,6 +21,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
@ -65,13 +66,8 @@ public class CountFunction<T extends SequenceElement> implements Function<Sequen
long seqLen = 0;
if (ela == null) {
try {
ela = (SparkElementsLearningAlgorithm) Class
.forName(vectorsConfigurationBroadcast.getValue().getElementsLearningAlgorithm())
.newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
String elementsLearningAlgorithm = vectorsConfigurationBroadcast.getValue().getElementsLearningAlgorithm();
ela = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
}
driver = ela.getTrainingDriver();

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.spark.models.sequencevectors.functions;
import lombok.NonNull;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
@ -74,19 +75,15 @@ public class PartitionTrainingFunction<T extends SequenceElement> implements Voi
if (vectorsConfiguration == null)
vectorsConfiguration = configurationBroadcast.getValue();
String elementsLearningAlgorithm = vectorsConfiguration.getElementsLearningAlgorithm();
if (paramServer == null) {
paramServer = VoidParameterServer.getInstance();
if (elementsLearningAlgorithm == null) {
try {
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
if (this.elementsLearningAlgorithm == null) {
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
}
driver = elementsLearningAlgorithm.getTrainingDriver();
driver = this.elementsLearningAlgorithm.getTrainingDriver();
// FIXME: init line should probably be removed, basically init happens in VocabRddFunction
paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
@ -95,33 +92,24 @@ public class PartitionTrainingFunction<T extends SequenceElement> implements Voi
if (shallowVocabCache == null)
shallowVocabCache = vocabCacheBroadcast.getValue();
if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
if (this.elementsLearningAlgorithm == null && elementsLearningAlgorithm != null) {
// TODO: do ELA initialization
try {
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
}
if (elementsLearningAlgorithm != null)
elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
if (this.elementsLearningAlgorithm != null)
this.elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
String sequenceLearningAlgorithm = vectorsConfiguration.getSequenceLearningAlgorithm();
if (this.sequenceLearningAlgorithm == null && sequenceLearningAlgorithm != null) {
// TODO: do SLA initialization
try {
sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
} catch (Exception e) {
throw new RuntimeException(e);
this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(sequenceLearningAlgorithm);
this.sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
}
}
if (sequenceLearningAlgorithm != null)
sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
if (this.sequenceLearningAlgorithm != null)
this.sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
if (this.elementsLearningAlgorithm == null && this.sequenceLearningAlgorithm == null) {
throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
}
@ -142,7 +130,7 @@ public class PartitionTrainingFunction<T extends SequenceElement> implements Voi
}
// do the same with labels, transfer them, if any
if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
if (this.sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
for (T label : sequence.getSequenceLabels()) {
ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());

View File

@ -20,6 +20,7 @@ import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
@ -73,19 +74,15 @@ public class TrainingFunction<T extends SequenceElement> implements VoidFunction
if (vectorsConfiguration == null)
vectorsConfiguration = configurationBroadcast.getValue();
String elementsLearningAlgorithm = vectorsConfiguration.getElementsLearningAlgorithm();
if (paramServer == null) {
paramServer = VoidParameterServer.getInstance();
if (elementsLearningAlgorithm == null) {
try {
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
if (this.elementsLearningAlgorithm == null) {
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
}
driver = elementsLearningAlgorithm.getTrainingDriver();
driver = this.elementsLearningAlgorithm.getTrainingDriver();
// FIXME: init line should probably be removed, basically init happens in VocabRddFunction
paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
@ -98,33 +95,23 @@ public class TrainingFunction<T extends SequenceElement> implements VoidFunction
shallowVocabCache = vocabCacheBroadcast.getValue();
if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
if (this.elementsLearningAlgorithm == null && elementsLearningAlgorithm != null) {
// TODO: do ELA initialization
try {
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
} catch (Exception e) {
throw new RuntimeException(e);
}
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
this.elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
}
if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
String sequenceLearningAlgorithm = vectorsConfiguration.getSequenceLearningAlgorithm();
if (this.sequenceLearningAlgorithm == null && sequenceLearningAlgorithm != null) {
// TODO: do SLA initialization
try {
sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
} catch (Exception e) {
throw new RuntimeException(e);
}
this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(sequenceLearningAlgorithm);
this.sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
}
if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
if (this.elementsLearningAlgorithm == null && this.sequenceLearningAlgorithm == null) {
throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
}
/*
at this moment we should have everything ready for actual initialization
the only limitation we have - our sequence is detached from actual vocabulary, so we need to merge it back virtually
@ -139,7 +126,7 @@ public class TrainingFunction<T extends SequenceElement> implements VoidFunction
}
// do the same with labels, transfer them, if any
if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
if (this.sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
for (T label : sequence.getSequenceLabels()) {
ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
@ -157,7 +144,7 @@ public class TrainingFunction<T extends SequenceElement> implements VoidFunction
// FIXME: temporary hook
if (sequence.size() > 0)
paramServer.execDistributed(
elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
this.elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
else
log.warn("Skipping empty sequence...");

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.spark.models.sequencevectors.functions;
import lombok.NonNull;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
@ -56,12 +57,8 @@ public class VocabRddFunctionFlat<T extends SequenceElement> implements FlatMapF
configuration = vectorsConfigurationBroadcast.getValue();
if (ela == null) {
try {
ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm())
.newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
String className = configuration.getElementsLearningAlgorithm();
ela = DL4JClassLoading.createNewInstance(className);
}
driver = ela.getTrainingDriver();

View File

@ -17,12 +17,14 @@
package org.deeplearning4j.spark.text.functions;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.api.java.function.Function;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.NGramTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
@ -44,35 +46,33 @@ public class TokenizerFunction implements Function<String, List<String>> {
}
@Override
public List<String> call(String v1) throws Exception {
if (tokenizerFactory == null)
public List<String> call(String str) {
if (tokenizerFactory == null) {
tokenizerFactory = getTokenizerFactory();
if (v1.isEmpty())
return Arrays.asList("");
return tokenizerFactory.create(v1).getTokens();
}
if (str.isEmpty()) {
return Collections.singletonList("");
}
return tokenizerFactory.create(str).getTokens();
}
private TokenizerFactory getTokenizerFactory() {
try {
TokenPreProcess tokenPreProcessInst = null;
// token preprocess CAN be undefined
if (tokenizerPreprocessorClazz != null && !tokenizerPreprocessorClazz.isEmpty()) {
Class<? extends TokenPreProcess> clazz =
(Class<? extends TokenPreProcess>) Class.forName(tokenizerPreprocessorClazz);
tokenPreProcessInst = clazz.newInstance();
if (StringUtils.isNotEmpty(tokenizerPreprocessorClazz)) {
tokenPreProcessInst = DL4JClassLoading.createNewInstance(tokenizerPreprocessorClazz);
}
Class<? extends TokenizerFactory> clazz2 =
(Class<? extends TokenizerFactory>) Class.forName(tokenizerFactoryClazz);
tokenizerFactory = clazz2.newInstance();
tokenizerFactory = DL4JClassLoading.createNewInstance(tokenizerFactoryClazz);
if (tokenPreProcessInst != null)
tokenizerFactory.setTokenPreProcessor(tokenPreProcessInst);
if (nGrams > 1) {
tokenizerFactory = new NGramTokenizerFactory(tokenizerFactory, nGrams, nGrams);
}
} catch (Exception e) {
log.error("",e);
}
return tokenizerFactory;
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.spark.time;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.common.config.DL4JSystemProperties;
import java.lang.reflect.Method;
@ -62,9 +63,9 @@ public class TimeSourceProvider {
*/
public static TimeSource getInstance(String className) {
try {
Class<?> c = Class.forName(className);
Method m = c.getMethod("getInstance");
return (TimeSource) m.invoke(null);
Class<?> clazz = DL4JClassLoading.loadClassByName(className);
Method getInstance = clazz.getMethod("getInstance");
return (TimeSource) getInstance.invoke(null);
} catch (Exception e) {
throw new RuntimeException("Error getting TimeSource instance for class \"" + className + "\"", e);
}

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.ui.model.stats;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.core.storage.StorageMetaData;
import org.deeplearning4j.core.storage.listener.RoutingIterationListener;
@ -696,11 +697,14 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
return devPointers.get(device);
}
try {
Class<?> c = Class.forName("org.nd4j.jita.allocator.pointers.CudaPointer");
Constructor<?> constructor = c.getConstructor(long.class);
Pointer p = (Pointer) constructor.newInstance((long) device);
devPointers.put(device, p);
return p;
Pointer pointer = DL4JClassLoading.createNewInstance(
"org.nd4j.jita.allocator.pointers.CudaPointer",
Pointer.class,
new Class[] { long.class },
(long) device);
devPointers.put(device, pointer);
return pointer;
} catch (Throwable t) {
devPointers.put(device, null); //Stops attempting the failure again later...
return null;
@ -711,9 +715,9 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
ModelInfo modelInfo = getModelInfo(model);
int examplesThisMinibatch = 0;
if (model instanceof MultiLayerNetwork) {
examplesThisMinibatch = ((MultiLayerNetwork) model).batchSize();
examplesThisMinibatch = model.batchSize();
} else if (model instanceof ComputationGraph) {
examplesThisMinibatch = ((ComputationGraph) model).batchSize();
examplesThisMinibatch = model.batchSize();
} else if (model instanceof Layer) {
examplesThisMinibatch = ((Layer) model).getInputMiniBatchSize();
}

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.ui.model.storage.mapdb;
import lombok.Data;
import lombok.NonNull;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.core.storage.*;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
@ -318,26 +319,18 @@ public class MapDBStatsStorage extends BaseCollectionStatsStorage {
}
@Override
@SuppressWarnings("unchecked")
public T deserialize(@NonNull DataInput2 input, int available) throws IOException {
int classIdx = input.readInt();
String className = getClassForInt(classIdx);
Class<?> clazz;
try {
clazz = Class.forName(className);
} catch (ClassNotFoundException e) {
throw new RuntimeException(e); //Shouldn't normally happen...
}
Persistable p;
try {
p = (Persistable) clazz.newInstance();
} catch (InstantiationException | IllegalAccessException e) {
throw new RuntimeException(e);
}
int remainingLength = available - 4; //-4 for int class index
Persistable persistable = DL4JClassLoading.createNewInstance(className);
int remainingLength = available - 4; // -4 for int class index
byte[] temp = new byte[remainingLength];
input.readFully(temp);
p.decode(temp);
return (T) p;
persistable.decode(temp);
return (T) persistable;
}
@Override

View File

@ -32,32 +32,42 @@ import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.common.config.DL4JSystemProperties;
import org.deeplearning4j.common.util.DL4JFileUtils;
import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.core.storage.StatsStorageEvent;
import org.deeplearning4j.core.storage.StatsStorageListener;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.common.config.DL4JSystemProperties;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.i18n.I18NProvider;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.model.storage.impl.QueueStatsStorageListener;
import org.deeplearning4j.ui.module.SameDiffModule;
import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule;
import org.deeplearning4j.ui.module.defaultModule.DefaultModule;
import org.deeplearning4j.ui.module.remote.RemoteReceiverModule;
import org.deeplearning4j.ui.module.train.TrainModule;
import org.deeplearning4j.ui.module.tsne.TsneModule;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.model.storage.impl.QueueStatsStorageListener;
import org.deeplearning4j.common.util.DL4JFileUtils;
import org.nd4j.common.function.Function;
import org.nd4j.common.primitives.Pair;
import java.io.File;
import java.util.*;
import java.util.concurrent.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
@Slf4j
@ -402,8 +412,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
}
private void modulesViaServiceLoader(List<UIModule> uiModules) {
ServiceLoader<UIModule> sl = ServiceLoader.load(UIModule.class);
ServiceLoader<UIModule> sl = DL4JClassLoading.loadService(UIModule.class);
Iterator<UIModule> iter = sl.iterator();
if (!iter.hasNext()) {
@ -411,19 +420,19 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
}
while (iter.hasNext()) {
UIModule m = iter.next();
Class<?> c = m.getClass();
UIModule module = iter.next();
Class<?> moduleClass = module.getClass();
boolean foundExisting = false;
for (UIModule mExisting : uiModules) {
if (mExisting.getClass() == c) {
if (mExisting.getClass() == moduleClass) {
foundExisting = true;
break;
}
}
if (!foundExisting) {
log.debug("Loaded UI module via service loader: {}", m.getClass());
uiModules.add(m);
log.debug("Loaded UI module via service loader: {}", module.getClass());
uiModules.add(module);
}
}
}

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.ui.i18n;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.ui.api.I18N;
import org.deeplearning4j.ui.api.UIModule;
@ -100,13 +101,13 @@ public class DefaultI18N implements I18N {
}
private synchronized void loadLanguages(){
ServiceLoader<UIModule> sl = ServiceLoader.load(UIModule.class);
ServiceLoader<UIModule> loadedModules = DL4JClassLoading.loadService(UIModule.class);
for(UIModule m : sl){
List<I18NResource> resources = m.getInternationalizationResources();
for(I18NResource r : resources){
for (UIModule module : loadedModules){
List<I18NResource> resources = module.getInternationalizationResources();
for(I18NResource resource : resources){
try {
String path = r.getResource();
String path = resource.getResource();
int idxLast = path.lastIndexOf('.');
if (idxLast < 0) {
log.warn("Skipping language resource file: cannot infer language: {}", path);
@ -116,9 +117,9 @@ public class DefaultI18N implements I18N {
String langCode = path.substring(idxLast + 1).toLowerCase();
Map<String, String> map = messagesByLanguage.computeIfAbsent(langCode, k -> new HashMap<>());
parseFile(r, map);
parseFile(resource, map);
} catch (Throwable t){
log.warn("Error parsing UI I18N content file; skipping: {}", r.getResource(), t);
log.warn("Error parsing UI I18N content file; skipping: {}", resource.getResource(), t);
languageLoadingException = t;
}
}

View File

@ -22,6 +22,7 @@ import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.RoutingContext;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.core.storage.*;
import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter;
import org.deeplearning4j.ui.api.HttpMethod;
@ -154,9 +155,12 @@ public class RemoteReceiverModule implements UIModule {
private StorageMetaData getMetaData(String dataClass, String content) {
StorageMetaData meta;
try {
Class<?> c = Class.forName(dataClass);
if (StorageMetaData.class.isAssignableFrom(c)) {
meta = (StorageMetaData) c.newInstance();
Class<?> clazz = DL4JClassLoading.loadClassByName(dataClass);
if (StorageMetaData.class.isAssignableFrom(clazz)) {
meta = clazz
.asSubclass(StorageMetaData.class)
.getDeclaredConstructor()
.newInstance();
} else {
log.warn("Skipping invalid remote data: class {} in not an instance of {}", dataClass,
StorageMetaData.class.getName());
@ -179,11 +183,14 @@ public class RemoteReceiverModule implements UIModule {
}
private Persistable getPersistable(String dataClass, String content) {
Persistable p;
Persistable persistable;
try {
Class<?> c = Class.forName(dataClass);
if (Persistable.class.isAssignableFrom(c)) {
p = (Persistable) c.newInstance();
Class<?> clazz = DL4JClassLoading.loadClassByName(dataClass);
if (Persistable.class.isAssignableFrom(clazz)) {
persistable = clazz
.asSubclass(Persistable.class)
.getDeclaredConstructor()
.newInstance();
} else {
log.warn("Skipping invalid remote data: class {} in not an instance of {}", dataClass,
Persistable.class.getName());
@ -196,12 +203,12 @@ public class RemoteReceiverModule implements UIModule {
try {
byte[] bytes = DatatypeConverter.parseBase64Binary(content);
p.decode(bytes);
persistable.decode(bytes);
} catch (Exception e) {
log.warn("Skipping invalid remote data: exception encountered when deserializing data", e);
return null;
}
return p;
return persistable;
}
}

View File

@ -22,6 +22,7 @@ import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
import org.deeplearning4j.nn.api.Model;
@ -127,7 +128,7 @@ public class IntegrationTestRunner {
}
for (ClassPath.ClassInfo c : info) {
Class<?> clazz = Class.forName(c.getName());
Class<?> clazz = DL4JClassLoading.loadClassByName(c.getName());
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface())
continue;

View File

@ -1,3 +1,19 @@
/*******************************************************************************
* Copyright (c) Eclipse Deeplearning4j Contributors 2020
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.common.config;
import lombok.extern.slf4j.Slf4j;
@ -6,6 +22,22 @@ import java.util.ServiceLoader;
/**
* Global context for class-loading in ND4J.
* <p>Use {@code ND4JClassLoading} to define classloader for ND4J only! To define classloader used by
* {@code Deeplearning4j} use class {@link org.deeplearning4j.common.config.DL4JClassLoading}.
*
* <p>Usage:
* <pre>{@code
* public class Application {
* static {
* ND4JClassLoading.setNd4jClassloaderFromClass(Application.class);
* }
*
* public static void main(String[] args) {
* }
* }
* }</code>
*
* @see org.deeplearning4j.common.config.DL4JClassLoading
*
* @author Alexei KLENIN
*/