FEATURE#8712: add possibility to specify classloader for DL4J (#9115)
Signed-off-by: hosuaby <alexei.klenin@gmail.com>master
parent
2e000c84ac
commit
a722bd5a5b
|
@ -33,6 +33,12 @@
|
|||
<artifactId>nd4j-common</artifactId>
|
||||
<version>${nd4j.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
@ -43,5 +49,4 @@
|
|||
<id>test-nd4j-cuda-11.0</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
|
||||
</project>
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) Eclipse Deeplearning4j Contributors 2020
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.common.config;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.common.config.ND4JClassLoading;
|
||||
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.Objects;
|
||||
import java.util.ServiceLoader;
|
||||
|
||||
/**
|
||||
* Global context for class-loading in DL4J.
|
||||
* <p>Use {@code DL4JClassLoading} to define classloader for Deeplearning4j only! To define classloader used by
|
||||
* {@code ND4J} use class {@link org.nd4j.common.config.ND4JClassLoading}.
|
||||
*
|
||||
* <p>Usage:
|
||||
* <pre>{@code
|
||||
* public class Application {
|
||||
* static {
|
||||
* DL4JClassLoading.setDl4jClassloaderFromClass(Application.class);
|
||||
* }
|
||||
*
|
||||
* public static void main(String[] args) {
|
||||
* }
|
||||
* }
|
||||
* }</code>
|
||||
*
|
||||
* @see org.nd4j.common.config.ND4JClassLoading
|
||||
*
|
||||
* @author Alexei KLENIN
|
||||
*/
|
||||
@Slf4j
|
||||
public class DL4JClassLoading {
|
||||
private static ClassLoader dl4jClassloader = ND4JClassLoading.getNd4jClassloader();
|
||||
|
||||
private DL4JClassLoading() {
|
||||
}
|
||||
|
||||
public static ClassLoader getDl4jClassloader() {
|
||||
return DL4JClassLoading.dl4jClassloader;
|
||||
}
|
||||
|
||||
public static void setDl4jClassloaderFromClass(Class<?> clazz) {
|
||||
setDl4jClassloader(clazz.getClassLoader());
|
||||
}
|
||||
|
||||
public static void setDl4jClassloader(ClassLoader dl4jClassloader) {
|
||||
DL4JClassLoading.dl4jClassloader = dl4jClassloader;
|
||||
log.debug("Global class-loader for DL4J was changed.");
|
||||
}
|
||||
|
||||
public static boolean classPresentOnClasspath(String className) {
|
||||
return classPresentOnClasspath(className, dl4jClassloader);
|
||||
}
|
||||
|
||||
public static boolean classPresentOnClasspath(String className, ClassLoader classLoader) {
|
||||
return loadClassByName(className, false, classLoader) != null;
|
||||
}
|
||||
|
||||
public static <T> Class<T> loadClassByName(String className) {
|
||||
return loadClassByName(className, true, dl4jClassloader);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T> Class<T> loadClassByName(String className, boolean initialize, ClassLoader classLoader) {
|
||||
try {
|
||||
return (Class<T>) Class.forName(className, initialize, classLoader);
|
||||
} catch (ClassNotFoundException classNotFoundException) {
|
||||
log.error(String.format("Cannot find class [%s] of provided class-loader.", className));
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static <T> T createNewInstance(String className, Object... args) {
|
||||
return createNewInstance(className, Object.class, args);
|
||||
}
|
||||
|
||||
public static <T> T createNewInstance(String className, Class<? super T> superclass) {
|
||||
return createNewInstance(className, superclass, new Class<?>[]{}, new Object[]{});
|
||||
}
|
||||
|
||||
public static <T> T createNewInstance(String className, Class<? super T> superclass, Object... args) {
|
||||
Class<?>[] parameterTypes = new Class<?>[args.length];
|
||||
for (int i = 0; i < args.length; i++) {
|
||||
Object arg = args[i];
|
||||
Objects.requireNonNull(arg);
|
||||
parameterTypes[i] = arg.getClass();
|
||||
}
|
||||
|
||||
return createNewInstance(className, superclass, parameterTypes, args);
|
||||
}
|
||||
|
||||
public static <T> T createNewInstance(
|
||||
String className,
|
||||
Class<? super T> superclass,
|
||||
Class<?>[] parameterTypes,
|
||||
Object... args) {
|
||||
try {
|
||||
return (T) DL4JClassLoading
|
||||
.loadClassByName(className)
|
||||
.asSubclass(superclass)
|
||||
.getDeclaredConstructor(parameterTypes)
|
||||
.newInstance(args);
|
||||
} catch (InstantiationException | IllegalAccessException | InvocationTargetException
|
||||
| NoSuchMethodException instantiationException) {
|
||||
log.error(String.format("Cannot create instance of class '%s'.", className), instantiationException);
|
||||
throw new RuntimeException(instantiationException);
|
||||
}
|
||||
}
|
||||
|
||||
public static <S> ServiceLoader<S> loadService(Class<S> serviceClass) {
|
||||
return loadService(serviceClass, dl4jClassloader);
|
||||
}
|
||||
|
||||
public static <S> ServiceLoader<S> loadService(Class<S> serviceClass, ClassLoader classLoader) {
|
||||
return ServiceLoader.load(serviceClass, classLoader);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
package org.deeplearning4j.common.config;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
import org.deeplearning4j.common.config.dummies.TestAbstract;
|
||||
import org.junit.Test;
|
||||
|
||||
public class DL4JClassLoadingTest {
|
||||
private static final String PACKAGE_PREFIX = "org.deeplearning4j.common.config.dummies.";
|
||||
|
||||
@Test
|
||||
public void testCreateNewInstance_constructorWithoutArguments() {
|
||||
|
||||
/* Given */
|
||||
String className = PACKAGE_PREFIX + "TestDummy";
|
||||
|
||||
/* When */
|
||||
Object instance = DL4JClassLoading.createNewInstance(className);
|
||||
|
||||
/* Then */
|
||||
assertNotNull(instance);
|
||||
assertEquals(className, instance.getClass().getName());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateNewInstance_constructorWithArgument_implicitArgumentTypes() {
|
||||
|
||||
/* Given */
|
||||
String className = PACKAGE_PREFIX + "TestColor";
|
||||
|
||||
/* When */
|
||||
TestAbstract instance = DL4JClassLoading.createNewInstance(className, TestAbstract.class, "white");
|
||||
|
||||
/* Then */
|
||||
assertNotNull(instance);
|
||||
assertEquals(className, instance.getClass().getName());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateNewInstance_constructorWithArgument_explicitArgumentTypes() {
|
||||
|
||||
/* Given */
|
||||
String colorClassName = PACKAGE_PREFIX + "TestColor";
|
||||
String rectangleClassName = PACKAGE_PREFIX + "TestRectangle";
|
||||
|
||||
/* When */
|
||||
TestAbstract color = DL4JClassLoading.createNewInstance(
|
||||
colorClassName,
|
||||
Object.class,
|
||||
new Class<?>[]{ int.class, int.class, int.class },
|
||||
45, 175, 200);
|
||||
|
||||
TestAbstract rectangle = DL4JClassLoading.createNewInstance(
|
||||
rectangleClassName,
|
||||
Object.class,
|
||||
new Class<?>[]{ int.class, int.class, TestAbstract.class },
|
||||
10, 15, color);
|
||||
|
||||
/* Then */
|
||||
assertNotNull(color);
|
||||
assertEquals(colorClassName, color.getClass().getName());
|
||||
|
||||
assertNotNull(rectangle);
|
||||
assertEquals(rectangleClassName, rectangle.getClass().getName());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
package org.deeplearning4j.common.config.dummies;
|
||||
|
||||
public abstract class TestAbstract {
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
package org.deeplearning4j.common.config.dummies;
|
||||
|
||||
public class TestColor extends TestAbstract {
|
||||
public TestColor(String color) {
|
||||
}
|
||||
|
||||
public TestColor(int r, int g, int b) {
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
package org.deeplearning4j.common.config.dummies;
|
||||
|
||||
public class TestDummy {
|
||||
public TestDummy() {
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
package org.deeplearning4j.common.config.dummies;
|
||||
|
||||
public class TestRectangle extends TestAbstract {
|
||||
public TestRectangle(int width, int height, TestAbstract color) {
|
||||
}
|
||||
}
|
|
@ -102,7 +102,6 @@ public class AsyncIterator<T extends Object> implements Iterator<T> {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
private class ReaderThread<T> extends Thread implements Runnable {
|
||||
private BlockingQueue<T> buffer;
|
||||
private Iterator<T> iterator;
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j;
|
|||
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
|
||||
|
@ -66,23 +67,23 @@ public class LayerHelperValidationUtil {
|
|||
|
||||
public static void disableCppHelpers(){
|
||||
try {
|
||||
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||
Method m = c.getMethod("getInstance");
|
||||
Object instance = m.invoke(null);
|
||||
Method m2 = c.getMethod("allowHelpers", boolean.class);
|
||||
m2.invoke(instance, false);
|
||||
Class<?> clazz = DL4JClassLoading.loadClassByName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||
Method getInstance = clazz.getMethod("getInstance");
|
||||
Object instance = getInstance.invoke(null);
|
||||
Method allowHelpers = clazz.getMethod("allowHelpers", boolean.class);
|
||||
allowHelpers.invoke(instance, false);
|
||||
} catch (Throwable t){
|
||||
throw new RuntimeException(t);
|
||||
}
|
||||
}
|
||||
|
||||
public static void enableCppHelpers(){
|
||||
try{
|
||||
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||
Method m = c.getMethod("getInstance");
|
||||
Object instance = m.invoke(null);
|
||||
Method m2 = c.getMethod("allowHelpers", boolean.class);
|
||||
m2.invoke(instance, true);
|
||||
try {
|
||||
Class<?> clazz = DL4JClassLoading.loadClassByName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||
Method getInstance = clazz.getMethod("getInstance");
|
||||
Object instance = getInstance.invoke(null);
|
||||
Method allowHelpers = clazz.getMethod("allowHelpers", boolean.class);
|
||||
allowHelpers.invoke(instance, true);
|
||||
} catch (Throwable t){
|
||||
throw new RuntimeException(t);
|
||||
}
|
||||
|
|
|
@ -16,28 +16,96 @@
|
|||
|
||||
package org.deeplearning4j.nn.dtypes;
|
||||
|
||||
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.*;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
|
||||
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||
import org.nd4j.shade.guava.collect.ImmutableSet;
|
||||
import org.nd4j.shade.guava.reflect.ClassPath;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.TestUtils;
|
||||
import org.deeplearning4j.nn.conf.*;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
||||
import org.deeplearning4j.nn.conf.dropout.AlphaDropout;
|
||||
import org.deeplearning4j.nn.conf.dropout.GaussianDropout;
|
||||
import org.deeplearning4j.nn.conf.dropout.GaussianNoise;
|
||||
import org.deeplearning4j.nn.conf.dropout.SpatialDropout;
|
||||
import org.deeplearning4j.nn.conf.graph.*;
|
||||
import org.deeplearning4j.nn.conf.graph.AttentionVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.FrozenVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.L2NormalizeVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.L2Vertex;
|
||||
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.PoolHelperVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.ReshapeVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.ScaleVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.ShiftVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.StackVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.SubsetVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.UnstackVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.AutoEncoder;
|
||||
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
|
||||
import org.deeplearning4j.nn.conf.layers.CapsuleLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.CapsuleStrengthLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.Cnn3DLossLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||
import org.deeplearning4j.nn.conf.layers.Convolution2D;
|
||||
import org.deeplearning4j.nn.conf.layers.Convolution3D;
|
||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
|
||||
import org.deeplearning4j.nn.conf.layers.Deconvolution3D;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
|
||||
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
|
||||
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||
import org.deeplearning4j.nn.conf.layers.LearnedSelfAttentionLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization;
|
||||
import org.deeplearning4j.nn.conf.layers.LocallyConnected1D;
|
||||
import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;
|
||||
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.PReLULayer;
|
||||
import org.deeplearning4j.nn.conf.layers.Pooling1D;
|
||||
import org.deeplearning4j.nn.conf.layers.Pooling2D;
|
||||
import org.deeplearning4j.nn.conf.layers.PoolingType;
|
||||
import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
|
||||
import org.deeplearning4j.nn.conf.layers.RecurrentAttentionLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.SelfAttentionLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
|
||||
import org.deeplearning4j.nn.conf.layers.SpaceToBatchLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.SpaceToDepthLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.Subsampling1DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling1D;
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling3D;
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPadding1DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
|
||||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
|
||||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D;
|
||||
|
@ -49,16 +117,24 @@ import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
|
|||
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
|
||||
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
|
||||
import org.deeplearning4j.nn.conf.layers.util.MaskLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
||||
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
|
||||
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.layers.util.IdentityLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInitDistribution;
|
||||
|
@ -77,12 +153,17 @@ import org.nd4j.linalg.learning.config.Nesterovs;
|
|||
import org.nd4j.linalg.learning.config.NoOp;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
|
||||
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||
import org.nd4j.shade.guava.collect.ImmutableSet;
|
||||
import org.nd4j.shade.guava.reflect.ClassPath;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
@Slf4j
|
||||
public class DTypeTests extends BaseDL4JTest {
|
||||
|
@ -120,20 +201,17 @@ public class DTypeTests extends BaseDL4JTest {
|
|||
Set<Class<?>> preprocClasses = new HashSet<>();
|
||||
Set<Class<?>> vertexClasses = new HashSet<>();
|
||||
for (ClassPath.ClassInfo ci : info) {
|
||||
Class<?> clazz;
|
||||
try {
|
||||
clazz = Class.forName(ci.getName());
|
||||
} catch (ClassNotFoundException e) {
|
||||
//Should never happen as this was found on the classpath
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
Class<?> clazz = DL4JClassLoading.loadClassByName(ci.getName());
|
||||
|
||||
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) { //Skip TFOpLayer here - dtype depends on imported model dtype
|
||||
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) {
|
||||
// Skip TFOpLayer here - dtype depends on imported model dtype
|
||||
continue;
|
||||
}
|
||||
|
||||
if (clazz.getName().toLowerCase().contains("custom") || clazz.getName().contains("samediff.testlayers")
|
||||
|| clazz.getName().toLowerCase().contains("test") || ignoreClasses.contains(clazz)) {
|
||||
if (clazz.getName().toLowerCase().contains("custom")
|
||||
|| clazz.getName().contains("samediff.testlayers")
|
||||
|| clazz.getName().toLowerCase().contains("test")
|
||||
|| ignoreClasses.contains(clazz)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.recurrent;
|
|||
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
|
@ -167,7 +168,7 @@ public class GravesLSTMTest extends BaseDL4JTest {
|
|||
actHelper.setAccessible(true);
|
||||
|
||||
//Call activateHelper with both forBackprop == true, and forBackprop == false and compare
|
||||
Class<?> innerClass = Class.forName("org.deeplearning4j.nn.layers.recurrent.FwdPassReturn");
|
||||
Class<?> innerClass = DL4JClassLoading.loadClassByName("org.deeplearning4j.nn.layers.recurrent.FwdPassReturn");
|
||||
|
||||
Object oFalse = actHelper.invoke(lstm, false, null, null, false, LayerWorkspaceMgr.noWorkspacesImmutable()); //GravesLSTM.FwdPassReturn object; want fwdPassOutput INDArray
|
||||
Object oTrue = actHelper.invoke(lstm, false, null, null, true, LayerWorkspaceMgr.noWorkspacesImmutable()); //want fwdPassOutputAsArrays object
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers;
|
|||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.nn.layers.AbstractLayer;
|
||||
|
@ -110,7 +111,7 @@ public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
|
|||
org.nd4j.shade.protobuf.ByteString serialized = graphDef.toByteString();
|
||||
byte[] graphBytes = serialized.toByteArray();
|
||||
|
||||
ServiceLoader<TFGraphRunnerService> sl = ServiceLoader.load(TFGraphRunnerService.class);
|
||||
ServiceLoader<TFGraphRunnerService> sl = DL4JClassLoading.loadService(TFGraphRunnerService.class);
|
||||
Iterator<TFGraphRunnerService> iter = sl.iterator();
|
||||
if (!iter.hasNext()){
|
||||
throw new RuntimeException("The model contains a Tensorflow Op, which requires the nd4j-tensorflow dependency to execute.");
|
||||
|
|
|
@ -61,14 +61,12 @@ public abstract class Model {
|
|||
/**
|
||||
* 模型读取
|
||||
*
|
||||
* @param path
|
||||
* @return
|
||||
* @return
|
||||
* @throws Exception
|
||||
*/
|
||||
public static Model load(Class<? extends Model> c, InputStream is) throws Exception {
|
||||
Model model = c.newInstance();
|
||||
return model.loadModel(is);
|
||||
public static Model load(Class<? extends Model> modelClass, InputStream inputStream) throws Exception {
|
||||
return modelClass
|
||||
.getDeclaredConstructor()
|
||||
.newInstance()
|
||||
.loadModel(inputStream);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -5,6 +5,7 @@ import org.ansj.dic.impl.Jar2Stream;
|
|||
import org.ansj.dic.impl.Jdbc2Stream;
|
||||
import org.ansj.dic.impl.Url2Stream;
|
||||
import org.ansj.exception.LibraryException;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
|
||||
import java.io.InputStream;
|
||||
|
||||
|
@ -25,7 +26,8 @@ public abstract class PathToStream {
|
|||
} else if (path.startsWith("jar://")) {
|
||||
return new Jar2Stream().toStream(path);
|
||||
} else if (path.startsWith("class://")) {
|
||||
((PathToStream) Class.forName(path.substring(8).split("\\|")[0]).newInstance()).toStream(path);
|
||||
// Probably unused
|
||||
return loadClass(path);
|
||||
} else if (path.startsWith("http://") || path.startsWith("https://")) {
|
||||
return new Url2Stream().toStream(path);
|
||||
} else {
|
||||
|
@ -34,9 +36,17 @@ public abstract class PathToStream {
|
|||
} catch (Exception e) {
|
||||
throw new LibraryException(e);
|
||||
}
|
||||
throw new LibraryException("not find method type in path " + path);
|
||||
}
|
||||
|
||||
public abstract InputStream toStream(String path);
|
||||
|
||||
static InputStream loadClass(String path) {
|
||||
String className = path
|
||||
.substring("class://".length())
|
||||
.split("\\|")[0];
|
||||
|
||||
return DL4JClassLoading
|
||||
.createNewInstance(className, PathToStream.class)
|
||||
.toStream(path);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package org.ansj.dic.impl;
|
|||
import org.ansj.dic.DicReader;
|
||||
import org.ansj.dic.PathToStream;
|
||||
import org.ansj.exception.LibraryException;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
|
||||
import java.io.InputStream;
|
||||
|
||||
|
@ -17,12 +18,16 @@ public class Jar2Stream extends PathToStream {
|
|||
@Override
|
||||
public InputStream toStream(String path) {
|
||||
if (path.contains("|")) {
|
||||
String[] split = path.split("\\|");
|
||||
try {
|
||||
return Class.forName(split[0].substring(6)).getResourceAsStream(split[1].trim());
|
||||
} catch (ClassNotFoundException e) {
|
||||
throw new LibraryException(e);
|
||||
String[] tokens = path.split("\\|");
|
||||
String className = tokens[0].substring(6);
|
||||
String resourceName = tokens[1].trim();
|
||||
|
||||
Class<Object> resourceClass = DL4JClassLoading.loadClassByName(className);
|
||||
if (resourceClass == null) {
|
||||
throw new LibraryException(String.format("Class '%s' was not found.", className));
|
||||
}
|
||||
|
||||
return resourceClass.getResourceAsStream(resourceName);
|
||||
} else {
|
||||
return DicReader.getInputStream(path.substring(6));
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package org.ansj.dic.impl;
|
|||
|
||||
import org.ansj.dic.PathToStream;
|
||||
import org.ansj.exception.LibraryException;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
|
@ -22,20 +23,26 @@ public class Jdbc2Stream extends PathToStream {
|
|||
|
||||
private static final byte[] LINE = "\n".getBytes();
|
||||
|
||||
private static final String[] JDBC_DRIVERS = {
|
||||
"org.h2.Driver",
|
||||
"com.ibm.db2.jcc.DB2Driver",
|
||||
"org.hsqldb.jdbcDriver",
|
||||
"org.gjt.mm.mysql.Driver",
|
||||
"oracle.jdbc.OracleDriver",
|
||||
"org.postgresql.Driver",
|
||||
"net.sourceforge.jtds.jdbc.Driver",
|
||||
"com.microsoft.sqlserver.jdbc.SQLServerDriver",
|
||||
"org.sqlite.JDBC",
|
||||
"com.mysql.jdbc.Driver"
|
||||
};
|
||||
|
||||
static {
|
||||
String[] drivers = {"org.h2.Driver", "com.ibm.db2.jcc.DB2Driver", "org.hsqldb.jdbcDriver",
|
||||
"org.gjt.mm.mysql.Driver", "oracle.jdbc.OracleDriver", "org.postgresql.Driver",
|
||||
"net.sourceforge.jtds.jdbc.Driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver",
|
||||
"org.sqlite.JDBC", "com.mysql.jdbc.Driver"};
|
||||
for (String driverClassName : drivers) {
|
||||
try {
|
||||
try {
|
||||
Thread.currentThread().getContextClassLoader().loadClass(driverClassName);
|
||||
} catch (ClassNotFoundException e) {
|
||||
Class.forName(driverClassName);
|
||||
}
|
||||
} catch (Throwable e) {
|
||||
loadJdbcDrivers();
|
||||
}
|
||||
|
||||
static void loadJdbcDrivers() {
|
||||
for (String driverClassName : JDBC_DRIVERS) {
|
||||
DL4JClassLoading.loadClassByName(driverClassName);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,44 +17,20 @@
|
|||
|
||||
package org.deeplearning4j.models.embeddings.loader;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.DataInputStream;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.FileReader;
|
||||
import java.io.FileWriter;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.io.OutputStreamWriter;
|
||||
import java.io.PrintWriter;
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.zip.GZIPInputStream;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipFile;
|
||||
import java.util.zip.ZipInputStream;
|
||||
import java.util.zip.ZipOutputStream;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.apache.commons.codec.binary.Base64;
|
||||
import org.apache.commons.compress.compressors.gzip.GzipUtils;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.apache.commons.io.LineIterator;
|
||||
import org.apache.commons.io.output.CloseShieldOutputStream;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.common.util.DL4JFileUtils;
|
||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||
|
@ -94,12 +70,37 @@ import org.nd4j.shade.jackson.databind.ObjectMapper;
|
|||
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
||||
import org.nd4j.storage.CompressedRamStorage;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.DataInputStream;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.FileReader;
|
||||
import java.io.FileWriter;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.io.OutputStreamWriter;
|
||||
import java.io.PrintWriter;
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.zip.GZIPInputStream;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipFile;
|
||||
import java.util.zip.ZipInputStream;
|
||||
import java.util.zip.ZipOutputStream;
|
||||
|
||||
/**
|
||||
* This is utility class, providing various methods for WordVectors serialization
|
||||
|
@ -2676,26 +2677,23 @@ public class WordVectorSerializer {
|
|||
}
|
||||
|
||||
protected static TokenizerFactory getTokenizerFactory(VectorsConfiguration configuration) {
|
||||
if (configuration == null)
|
||||
if (configuration == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (configuration.getTokenizerFactory() != null && !configuration.getTokenizerFactory().isEmpty()) {
|
||||
try {
|
||||
TokenizerFactory factory =
|
||||
(TokenizerFactory) Class.forName(configuration.getTokenizerFactory()).newInstance();
|
||||
String tokenizerFactoryClassName = configuration.getTokenizerFactory();
|
||||
if (StringUtils.isNotEmpty(tokenizerFactoryClassName)) {
|
||||
TokenizerFactory factory = DL4JClassLoading.createNewInstance(tokenizerFactoryClassName);
|
||||
|
||||
if (configuration.getTokenPreProcessor() != null && !configuration.getTokenPreProcessor().isEmpty()) {
|
||||
TokenPreProcess preProcessor =
|
||||
(TokenPreProcess) Class.forName(configuration.getTokenPreProcessor()).newInstance();
|
||||
String tokenPreProcessorClassName = configuration.getTokenPreProcessor();
|
||||
if (StringUtils.isNotEmpty(tokenPreProcessorClassName)) {
|
||||
TokenPreProcess preProcessor = DL4JClassLoading.createNewInstance(tokenizerFactoryClassName);
|
||||
factory.setTokenPreProcessor(preProcessor);
|
||||
}
|
||||
|
||||
return factory;
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Can't instantiate saved TokenizerFactory: {}", configuration.getTokenizerFactory());
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
package org.deeplearning4j.models.sequencevectors;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.nd4j.shade.guava.primitives.Ints;
|
||||
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
|
||||
import lombok.Getter;
|
||||
|
@ -494,15 +496,17 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
|
|||
this.useHierarchicSoftmax = configuration.isUseHierarchicSoftmax();
|
||||
this.preciseMode = configuration.isPreciseMode();
|
||||
|
||||
if (configuration.getModelUtils() != null && !configuration.getModelUtils().isEmpty()) {
|
||||
|
||||
String modelUtilsClassName = configuration.getModelUtils();
|
||||
if (StringUtils.isNotEmpty(modelUtilsClassName)) {
|
||||
try {
|
||||
this.modelUtils = (ModelUtils<T>) Class.forName(configuration.getModelUtils()).newInstance();
|
||||
} catch (Exception e) {
|
||||
log.error("Got {} trying to instantiate ModelUtils, falling back to BasicModelUtils instead");
|
||||
this.modelUtils = DL4JClassLoading.createNewInstance(modelUtilsClassName);
|
||||
} catch (Exception instantiationException) {
|
||||
log.error(
|
||||
"Got '{}' trying to instantiate ModelUtils, falling back to BasicModelUtils instead",
|
||||
instantiationException.getMessage(),
|
||||
instantiationException);
|
||||
this.modelUtils = new BasicModelUtils<>();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (configuration.getElementsLearningAlgorithm() != null
|
||||
|
@ -551,12 +555,7 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
|
|||
* @return
|
||||
*/
|
||||
public Builder<T> sequenceLearningAlgorithm(@NonNull String algoName) {
|
||||
try {
|
||||
Class clazz = Class.forName(algoName);
|
||||
sequenceLearningAlgorithm = (SequenceLearningAlgorithm<T>) clazz.newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(algoName);
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -578,13 +577,9 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
|
|||
* @return
|
||||
*/
|
||||
public Builder<T> elementsLearningAlgorithm(@NonNull String algoName) {
|
||||
try {
|
||||
Class clazz = Class.forName(algoName);
|
||||
elementsLearningAlgorithm = (ElementsLearningAlgorithm<T>) clazz.newInstance();
|
||||
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(algoName);
|
||||
this.configuration.setElementsLearningAlgorithm(elementsLearningAlgorithm.getClass().getCanonicalName());
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -943,31 +938,23 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
|
|||
.lr(learningRate).seed(seed).build();
|
||||
}
|
||||
|
||||
if (this.configuration.getElementsLearningAlgorithm() != null) {
|
||||
try {
|
||||
elementsLearningAlgorithm = (ElementsLearningAlgorithm<T>) Class
|
||||
.forName(this.configuration.getElementsLearningAlgorithm()).newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
String elementsLearningAlgorithm = this.configuration.getElementsLearningAlgorithm();
|
||||
if (StringUtils.isNotEmpty(elementsLearningAlgorithm)) {
|
||||
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
|
||||
}
|
||||
|
||||
if (this.configuration.getSequenceLearningAlgorithm() != null) {
|
||||
try {
|
||||
sequenceLearningAlgorithm = (SequenceLearningAlgorithm<T>) Class
|
||||
.forName(this.configuration.getSequenceLearningAlgorithm()).newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
String sequenceLearningAlgorithm = this.configuration.getSequenceLearningAlgorithm();
|
||||
if (StringUtils.isNotEmpty(sequenceLearningAlgorithm)) {
|
||||
this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(sequenceLearningAlgorithm);
|
||||
}
|
||||
|
||||
if (trainElementsVectors && elementsLearningAlgorithm == null) {
|
||||
if (trainElementsVectors && this.elementsLearningAlgorithm == null) {
|
||||
// create default implementation of ElementsLearningAlgorithm
|
||||
elementsLearningAlgorithm = new SkipGram<>();
|
||||
this.elementsLearningAlgorithm = new SkipGram<>();
|
||||
}
|
||||
|
||||
if (trainSequenceVectors && sequenceLearningAlgorithm == null) {
|
||||
sequenceLearningAlgorithm = new DBOW<>();
|
||||
if (trainSequenceVectors && this.sequenceLearningAlgorithm == null) {
|
||||
this.sequenceLearningAlgorithm = new DBOW<>();
|
||||
}
|
||||
|
||||
this.modelUtils.init(lookupTable);
|
||||
|
|
|
@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
@ -132,25 +133,20 @@ public class Dropout implements IDropout {
|
|||
*/
|
||||
protected void initializeHelper(DataType dataType){
|
||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||
|
||||
if("CUDA".equalsIgnoreCase(backend)) {
|
||||
try {
|
||||
helper = Class.forName("org.deeplearning4j.cuda.dropout.CudnnDropoutHelper")
|
||||
.asSubclass(DropoutHelper.class).getConstructor(DataType.class).newInstance(dataType);
|
||||
helper = DL4JClassLoading.createNewInstance(
|
||||
"org.deeplearning4j.cuda.dropout.CudnnDropoutHelper",
|
||||
DropoutHelper.class,
|
||||
dataType);
|
||||
log.debug("CudnnDropoutHelper successfully initialized");
|
||||
if (!helper.checkSupported()) {
|
||||
helper = null;
|
||||
}
|
||||
} catch (Throwable t) {
|
||||
if (!(t instanceof ClassNotFoundException)) {
|
||||
log.warn("Could not initialize CudnnDropoutHelper", t);
|
||||
}
|
||||
//Unlike other layers, don't warn here about CuDNN not found - if the user has any other layers that can
|
||||
// benefit from them cudnn, they will get a warning from those
|
||||
}
|
||||
}
|
||||
initializedHelper = true;
|
||||
}
|
||||
|
||||
initializedHelper = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray applyDropout(INDArray inputActivations, INDArray output, int iteration, int epoch, LayerWorkspaceMgr workspaceMgr) {
|
||||
|
|
|
@ -43,6 +43,7 @@ import org.nd4j.shade.jackson.databind.deser.std.StdDeserializer;
|
|||
import org.nd4j.shade.jackson.databind.node.ObjectNode;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -268,14 +269,17 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
|
|||
|
||||
//Changed after 0.7.1 from "activationFunction" : "softmax" to "activationFn" : <object>
|
||||
protected void handleActivationBackwardCompatibility(BaseLayer baseLayer, ObjectNode on){
|
||||
|
||||
if(baseLayer.getActivationFn() == null && on.has("activationFunction")){
|
||||
String afn = on.get("activationFunction").asText();
|
||||
IActivation a = null;
|
||||
try {
|
||||
a = getMap().get(afn.toLowerCase()).newInstance();
|
||||
} catch (InstantiationException | IllegalAccessException e){
|
||||
//Ignore
|
||||
a = getMap()
|
||||
.get(afn.toLowerCase())
|
||||
.getDeclaredConstructor()
|
||||
.newInstance();
|
||||
} catch (InstantiationException | IllegalAccessException | NoSuchMethodException
|
||||
| InvocationTargetException instantiationException){
|
||||
log.error(instantiationException.getMessage());
|
||||
}
|
||||
baseLayer.setActivationFn(a);
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.convolution;
|
|||
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||
import org.deeplearning4j.nn.api.MaskState;
|
||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||
|
@ -73,26 +74,19 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
|||
void initializeHelper() {
|
||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||
if("CUDA".equalsIgnoreCase(backend)) {
|
||||
try {
|
||||
helper = Class.forName("org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper")
|
||||
.asSubclass(ConvolutionHelper.class).getConstructor(DataType.class).newInstance(dataType);
|
||||
helper = DL4JClassLoading.createNewInstance(
|
||||
"org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper",
|
||||
ConvolutionHelper.class,
|
||||
dataType);
|
||||
log.debug("CudnnConvolutionHelper successfully initialized");
|
||||
if (!helper.checkSupported()) {
|
||||
helper = null;
|
||||
}
|
||||
} catch (Throwable t) {
|
||||
if (!(t instanceof ClassNotFoundException)) {
|
||||
log.warn("Could not initialize CudnnConvolutionHelper", t);
|
||||
} else {
|
||||
OneTimeLogger.info(log, "cuDNN not found: "
|
||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
||||
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
||||
}
|
||||
}
|
||||
} else if("CPU".equalsIgnoreCase(backend)){
|
||||
helper = new MKLDNNConvHelper(dataType);
|
||||
log.trace("Created MKLDNNConvHelper, layer {}", layerConf().getLayerName());
|
||||
}
|
||||
|
||||
if (helper != null && !helper.checkSupported()) {
|
||||
log.debug("Removed helper {} as not supported", helper.getClass());
|
||||
helper = null;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.nn.layers.convolution.subsampling;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||
import org.deeplearning4j.nn.api.MaskState;
|
||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||
|
@ -30,17 +31,15 @@ import org.deeplearning4j.nn.layers.mkldnn.MKLDNNSubsamplingHelper;
|
|||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.util.ConvolutionUtils;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.exception.ND4JOpProfilerException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.common.util.OneTimeLogger;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
|
||||
/**
|
||||
* Subsampling layer.
|
||||
*
|
||||
|
@ -64,27 +63,21 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
|
|||
|
||||
void initializeHelper() {
|
||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||
|
||||
if("CUDA".equalsIgnoreCase(backend)) {
|
||||
try {
|
||||
helper = Class.forName("org.deeplearning4j.cuda.convolution.subsampling.CudnnSubsamplingHelper")
|
||||
.asSubclass(SubsamplingHelper.class).getConstructor(DataType.class).newInstance(dataType);
|
||||
helper = DL4JClassLoading.createNewInstance(
|
||||
"org.deeplearning4j.cuda.convolution.subsampling.CudnnSubsamplingHelper",
|
||||
SubsamplingHelper.class,
|
||||
dataType);
|
||||
log.debug("CudnnSubsamplingHelper successfully initialized");
|
||||
if (!helper.checkSupported()) {
|
||||
helper = null;
|
||||
}
|
||||
} catch (Throwable t) {
|
||||
if (!(t instanceof ClassNotFoundException)) {
|
||||
log.warn("Could not initialize CudnnSubsamplingHelper", t);
|
||||
} else {
|
||||
OneTimeLogger.info(log, "cuDNN not found: "
|
||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
||||
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
||||
}
|
||||
}
|
||||
} else if("CPU".equalsIgnoreCase(backend) ){
|
||||
helper = new MKLDNNSubsamplingHelper(dataType);
|
||||
log.trace("Created MKL-DNN helper: MKLDNNSubsamplingHelper, layer {}", layerConf().getLayerName());
|
||||
}
|
||||
|
||||
if (helper != null && !helper.checkSupported()) {
|
||||
log.debug("Removed helper {} as not supported", helper.getClass());
|
||||
helper = null;
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.nn.layers.mkldnn;
|
||||
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
|
@ -47,12 +48,11 @@ public class BaseMKLDNNHelper {
|
|||
}
|
||||
|
||||
try{
|
||||
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||
Method m = c.getMethod("getInstance");
|
||||
Object instance = m.invoke(null);
|
||||
Method m2 = c.getMethod("isUseMKLDNN");
|
||||
boolean b = (Boolean)m2.invoke(instance);
|
||||
return b;
|
||||
Class<?> clazz = DL4JClassLoading.loadClassByName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||
Method getInstance = clazz.getMethod("getInstance");
|
||||
Object instance = getInstance.invoke(null);
|
||||
Method isUseMKLDNNMethod = clazz.getMethod("isUseMKLDNN");
|
||||
return (boolean) isUseMKLDNNMethod.invoke(instance);
|
||||
} catch (Throwable t ){
|
||||
FAILED_CHECK = new AtomicBoolean(true);
|
||||
return false;
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.normalization;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||
|
@ -75,24 +76,18 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
|
|||
|
||||
void initializeHelper() {
|
||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||
if("CUDA".equalsIgnoreCase(backend)) {
|
||||
try {
|
||||
helper = Class.forName("org.deeplearning4j.cuda.normalization.CudnnBatchNormalizationHelper")
|
||||
.asSubclass(BatchNormalizationHelper.class).getConstructor(DataType.class).newInstance(dataType);
|
||||
|
||||
if ("CUDA".equalsIgnoreCase(backend)) {
|
||||
helper = DL4JClassLoading.createNewInstance(
|
||||
"org.deeplearning4j.cuda.normalization.CudnnBatchNormalizationHelper",
|
||||
BatchNormalizationHelper.class,
|
||||
dataType);
|
||||
log.debug("CudnnBatchNormalizationHelper successfully initialized");
|
||||
} catch (Throwable t) {
|
||||
if (!(t instanceof ClassNotFoundException)) {
|
||||
log.warn("Could not initialize CudnnBatchNormalizationHelper", t);
|
||||
} else {
|
||||
OneTimeLogger.info(log, "cuDNN not found: "
|
||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
||||
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
||||
}
|
||||
}
|
||||
} else if("CPU".equalsIgnoreCase(backend)){
|
||||
} else if ("CPU".equalsIgnoreCase(backend)){
|
||||
helper = new MKLDNNBatchNormHelper(dataType);
|
||||
log.trace("Created MKLDNNBatchNormHelper, layer {}", layerConf().getLayerName());
|
||||
}
|
||||
|
||||
if (helper != null && !helper.checkSupported(layerConf().getEps(), layerConf().isLockGammaBeta())) {
|
||||
log.debug("Removed helper {} as not supported with epsilon {}, lockGammaBeta={}", helper.getClass(), layerConf().getEps(), layerConf().isLockGammaBeta());
|
||||
helper = null;
|
||||
|
|
|
@ -16,7 +16,9 @@
|
|||
|
||||
package org.deeplearning4j.nn.layers.normalization;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
|
@ -36,9 +38,6 @@ import org.nd4j.common.primitives.Pair;
|
|||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.nd4j.common.primitives.Triple;
|
||||
import org.nd4j.common.util.OneTimeLogger;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
||||
|
||||
|
@ -65,10 +64,9 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
|||
* <p>
|
||||
* Created by nyghtowl on 10/29/15.
|
||||
*/
|
||||
@Slf4j
|
||||
public class LocalResponseNormalization
|
||||
extends AbstractLayer<org.deeplearning4j.nn.conf.layers.LocalResponseNormalization> {
|
||||
protected static final Logger log =
|
||||
LoggerFactory.getLogger(org.deeplearning4j.nn.conf.layers.LocalResponseNormalization.class);
|
||||
|
||||
protected LocalResponseNormalizationHelper helper = null;
|
||||
protected int helperCountFail = 0;
|
||||
|
@ -86,19 +84,11 @@ public class LocalResponseNormalization
|
|||
void initializeHelper() {
|
||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||
if("CUDA".equalsIgnoreCase(backend)) {
|
||||
try {
|
||||
helper = Class.forName("org.deeplearning4j.cuda.normalization.CudnnLocalResponseNormalizationHelper")
|
||||
.asSubclass(LocalResponseNormalizationHelper.class).getConstructor(DataType.class).newInstance(dataType);
|
||||
helper = DL4JClassLoading.createNewInstance(
|
||||
"org.deeplearning4j.cuda.normalization.CudnnLocalResponseNormalizationHelper",
|
||||
LocalResponseNormalizationHelper.class,
|
||||
dataType);
|
||||
log.debug("CudnnLocalResponseNormalizationHelper successfully initialized");
|
||||
} catch (Throwable t) {
|
||||
if (!(t instanceof ClassNotFoundException)) {
|
||||
log.warn("Could not initialize CudnnLocalResponseNormalizationHelper", t);
|
||||
} else {
|
||||
OneTimeLogger.info(log, "cuDNN not found: "
|
||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
||||
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
||||
}
|
||||
}
|
||||
}
|
||||
//2019-03-09 AB - MKL-DNN helper disabled: https://github.com/deeplearning4j/deeplearning4j/issues/7272
|
||||
// else if("CPU".equalsIgnoreCase(backend)){
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.nn.layers.recurrent;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.nn.api.MaskState;
|
||||
import org.deeplearning4j.nn.conf.CacheMode;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
|
@ -58,22 +59,14 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
|
|||
void initializeHelper() {
|
||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||
if("CUDA".equalsIgnoreCase(backend)) {
|
||||
try {
|
||||
helper = Class.forName("org.deeplearning4j.cuda.recurrent.CudnnLSTMHelper")
|
||||
.asSubclass(LSTMHelper.class).getConstructor(DataType.class).newInstance(dataType);
|
||||
helper = DL4JClassLoading.createNewInstance(
|
||||
"org.deeplearning4j.cuda.recurrent.CudnnLSTMHelper",
|
||||
LSTMHelper.class,
|
||||
dataType);
|
||||
log.debug("CudnnLSTMHelper successfully initialized");
|
||||
if (!helper.checkSupported(layerConf().getGateActivationFn(), layerConf().getActivationFn(), false)) {
|
||||
helper = null;
|
||||
}
|
||||
} catch (Throwable t) {
|
||||
if (!(t instanceof ClassNotFoundException)) {
|
||||
log.warn("Could not initialize CudnnLSTMHelper", t);
|
||||
} else {
|
||||
OneTimeLogger.info(log, "cuDNN not found: "
|
||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
||||
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
||||
}
|
||||
}
|
||||
}
|
||||
/*
|
||||
//Disabled pending: https://github.com/eclipse/deeplearning4j/issues/8331
|
||||
|
|
|
@ -21,6 +21,7 @@ import com.beust.jcommander.Parameter;
|
|||
import com.beust.jcommander.ParameterException;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.core.storage.StatsStorageRouter;
|
||||
import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
|
@ -126,48 +127,44 @@ public class ParallelWrapperMain {
|
|||
.build();
|
||||
|
||||
if (dataSetIteratorFactoryClazz != null) {
|
||||
DataSetIteratorProviderFactory dataSetIteratorProviderFactory =
|
||||
(DataSetIteratorProviderFactory) Class.forName(dataSetIteratorFactoryClazz).newInstance();
|
||||
DataSetIteratorProviderFactory dataSetIteratorProviderFactory = DL4JClassLoading
|
||||
.createNewInstance(dataSetIteratorFactoryClazz);
|
||||
|
||||
DataSetIterator dataSetIterator = dataSetIteratorProviderFactory.create();
|
||||
if (uiUrl != null) {
|
||||
// it's important that the UI can report results from parallel training
|
||||
// there's potential for StatsListener to fail if certain properties aren't set in the model
|
||||
StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
|
||||
TrainingListener l;
|
||||
try {
|
||||
l = (TrainingListener) Class.forName("org.deeplearning4j.ui.model.stats.StatsListener").getConstructor(StatsStorageRouter.class)
|
||||
.newInstance(new Object[]{null});
|
||||
} catch (ClassNotFoundException e){
|
||||
throw new IllegalStateException("deeplearning4j-ui module must be on the classpath to use ParallelWrapperMain with the UI", e);
|
||||
}
|
||||
wrapper.setListeners(remoteUIRouter, l);
|
||||
TrainingListener trainingListener = DL4JClassLoading.createNewInstance(
|
||||
"org.deeplearning4j.ui.model.stats.StatsListener",
|
||||
StatsStorageRouter.class,
|
||||
new Class[] { StatsStorageRouter.class },
|
||||
new Object[] { null });
|
||||
wrapper.setListeners(remoteUIRouter, trainingListener);
|
||||
|
||||
}
|
||||
wrapper.fit(dataSetIterator);
|
||||
ModelSerializer.writeModel(model, new File(modelOutputPath), true);
|
||||
|
||||
|
||||
} else if (multiDataSetIteratorFactoryClazz != null) {
|
||||
MultiDataSetProviderFactory multiDataSetProviderFactory =
|
||||
(MultiDataSetProviderFactory) Class.forName(multiDataSetIteratorFactoryClazz).newInstance();
|
||||
MultiDataSetProviderFactory multiDataSetProviderFactory = DL4JClassLoading
|
||||
.createNewInstance(multiDataSetIteratorFactoryClazz);
|
||||
|
||||
MultiDataSetIterator iterator = multiDataSetProviderFactory.create();
|
||||
if (uiUrl != null) {
|
||||
// it's important that the UI can report results from parallel training
|
||||
// there's potential for StatsListener to fail if certain properties aren't set in the model
|
||||
remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
|
||||
TrainingListener l;
|
||||
try {
|
||||
l = (TrainingListener) Class.forName("org.deeplearning4j.ui.model.stats.StatsListener").getConstructor(StatsStorageRouter.class)
|
||||
.newInstance(new Object[]{null});
|
||||
} catch (ClassNotFoundException e){
|
||||
throw new IllegalStateException("deeplearning4j-ui module must be on the classpath to use ParallelWrapperMain with the UI", e);
|
||||
}
|
||||
wrapper.setListeners(remoteUIRouter, l);
|
||||
TrainingListener trainingListener = DL4JClassLoading
|
||||
.createNewInstance(
|
||||
"org.deeplearning4j.ui.model.stats.StatsListener",
|
||||
TrainingListener.class,
|
||||
new Class[]{ StatsStorageRouter.class },
|
||||
new Object[]{ null });
|
||||
wrapper.setListeners(remoteUIRouter, trainingListener);
|
||||
|
||||
}
|
||||
wrapper.fit(iterator);
|
||||
ModelSerializer.writeModel(model, new File(modelOutputPath), true);
|
||||
|
||||
} else {
|
||||
throw new IllegalStateException("Please provide a datasetiteraator or multi datasetiterator class");
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.apache.spark.api.java.JavaRDD;
|
|||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.broadcast.Broadcast;
|
||||
import org.apache.spark.storage.StorageLevel;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
||||
|
@ -161,14 +162,9 @@ public class SparkSequenceVectors<T extends SequenceElement> extends SequenceVec
|
|||
validateConfiguration();
|
||||
|
||||
if (ela == null) {
|
||||
try {
|
||||
ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm())
|
||||
.newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
String className = configuration.getElementsLearningAlgorithm();
|
||||
ela = DL4JClassLoading.createNewInstance(className);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (workers > 1) {
|
||||
log.info("Repartitioning corpus to {} parts...", workers);
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.spark.models.sequencevectors.functions;
|
|||
|
||||
import lombok.NonNull;
|
||||
import org.apache.spark.broadcast.Broadcast;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
|
@ -42,24 +43,17 @@ public abstract class BaseTokenizerFunction implements Serializable {
|
|||
String tpClassName = this.configurationBroadcast.getValue().getTokenPreProcessor();
|
||||
|
||||
if (tfClassName != null && !tfClassName.isEmpty()) {
|
||||
try {
|
||||
tokenizerFactory = (TokenizerFactory) Class.forName(tfClassName).newInstance();
|
||||
tokenizerFactory = DL4JClassLoading.createNewInstance(tfClassName);
|
||||
|
||||
if (tpClassName != null && !tpClassName.isEmpty()) {
|
||||
try {
|
||||
tokenPreprocessor = (TokenPreProcess) Class.forName(tpClassName).newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Unable to instantiate TokenPreProcessor.", e);
|
||||
}
|
||||
tokenPreprocessor = DL4JClassLoading.createNewInstance(tpClassName);
|
||||
}
|
||||
|
||||
if (tokenPreprocessor != null) {
|
||||
tokenizerFactory.setTokenPreProcessor(tokenPreprocessor);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Unable to instantiate TokenizerFactory.", e);
|
||||
}
|
||||
} else
|
||||
} else {
|
||||
throw new RuntimeException("TokenizerFactory wasn't defined.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.apache.spark.Accumulator;
|
||||
import org.apache.spark.api.java.function.Function;
|
||||
import org.apache.spark.broadcast.Broadcast;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
|
@ -65,13 +66,8 @@ public class CountFunction<T extends SequenceElement> implements Function<Sequen
|
|||
long seqLen = 0;
|
||||
|
||||
if (ela == null) {
|
||||
try {
|
||||
ela = (SparkElementsLearningAlgorithm) Class
|
||||
.forName(vectorsConfigurationBroadcast.getValue().getElementsLearningAlgorithm())
|
||||
.newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
String elementsLearningAlgorithm = vectorsConfigurationBroadcast.getValue().getElementsLearningAlgorithm();
|
||||
ela = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
|
||||
}
|
||||
driver = ela.getTrainingDriver();
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.spark.models.sequencevectors.functions;
|
|||
import lombok.NonNull;
|
||||
import org.apache.spark.api.java.function.VoidFunction;
|
||||
import org.apache.spark.broadcast.Broadcast;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
|
@ -74,19 +75,15 @@ public class PartitionTrainingFunction<T extends SequenceElement> implements Voi
|
|||
if (vectorsConfiguration == null)
|
||||
vectorsConfiguration = configurationBroadcast.getValue();
|
||||
|
||||
String elementsLearningAlgorithm = vectorsConfiguration.getElementsLearningAlgorithm();
|
||||
if (paramServer == null) {
|
||||
paramServer = VoidParameterServer.getInstance();
|
||||
|
||||
if (elementsLearningAlgorithm == null) {
|
||||
try {
|
||||
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
|
||||
.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
if (this.elementsLearningAlgorithm == null) {
|
||||
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
|
||||
}
|
||||
|
||||
driver = elementsLearningAlgorithm.getTrainingDriver();
|
||||
driver = this.elementsLearningAlgorithm.getTrainingDriver();
|
||||
|
||||
// FIXME: init line should probably be removed, basically init happens in VocabRddFunction
|
||||
paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
|
||||
|
@ -95,33 +92,24 @@ public class PartitionTrainingFunction<T extends SequenceElement> implements Voi
|
|||
if (shallowVocabCache == null)
|
||||
shallowVocabCache = vocabCacheBroadcast.getValue();
|
||||
|
||||
if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
|
||||
if (this.elementsLearningAlgorithm == null && elementsLearningAlgorithm != null) {
|
||||
// TODO: do ELA initialization
|
||||
try {
|
||||
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
|
||||
.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
|
||||
}
|
||||
|
||||
if (elementsLearningAlgorithm != null)
|
||||
elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||
if (this.elementsLearningAlgorithm != null)
|
||||
this.elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||
|
||||
if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
|
||||
String sequenceLearningAlgorithm = vectorsConfiguration.getSequenceLearningAlgorithm();
|
||||
if (this.sequenceLearningAlgorithm == null && sequenceLearningAlgorithm != null) {
|
||||
// TODO: do SLA initialization
|
||||
try {
|
||||
sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
|
||||
.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
|
||||
sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(sequenceLearningAlgorithm);
|
||||
this.sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||
}
|
||||
}
|
||||
if (sequenceLearningAlgorithm != null)
|
||||
sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||
if (this.sequenceLearningAlgorithm != null)
|
||||
this.sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||
|
||||
if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
|
||||
if (this.elementsLearningAlgorithm == null && this.sequenceLearningAlgorithm == null) {
|
||||
throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
|
||||
}
|
||||
|
||||
|
@ -142,7 +130,7 @@ public class PartitionTrainingFunction<T extends SequenceElement> implements Voi
|
|||
}
|
||||
|
||||
// do the same with labels, transfer them, if any
|
||||
if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
|
||||
if (this.sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
|
||||
for (T label : sequence.getSequenceLabels()) {
|
||||
ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import lombok.NonNull;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.spark.api.java.function.VoidFunction;
|
||||
import org.apache.spark.broadcast.Broadcast;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
|
@ -73,19 +74,15 @@ public class TrainingFunction<T extends SequenceElement> implements VoidFunction
|
|||
if (vectorsConfiguration == null)
|
||||
vectorsConfiguration = configurationBroadcast.getValue();
|
||||
|
||||
String elementsLearningAlgorithm = vectorsConfiguration.getElementsLearningAlgorithm();
|
||||
if (paramServer == null) {
|
||||
paramServer = VoidParameterServer.getInstance();
|
||||
|
||||
if (elementsLearningAlgorithm == null) {
|
||||
try {
|
||||
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
|
||||
.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
if (this.elementsLearningAlgorithm == null) {
|
||||
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
|
||||
}
|
||||
|
||||
driver = elementsLearningAlgorithm.getTrainingDriver();
|
||||
driver = this.elementsLearningAlgorithm.getTrainingDriver();
|
||||
|
||||
// FIXME: init line should probably be removed, basically init happens in VocabRddFunction
|
||||
paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
|
||||
|
@ -98,33 +95,23 @@ public class TrainingFunction<T extends SequenceElement> implements VoidFunction
|
|||
shallowVocabCache = vocabCacheBroadcast.getValue();
|
||||
|
||||
|
||||
if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
|
||||
if (this.elementsLearningAlgorithm == null && elementsLearningAlgorithm != null) {
|
||||
// TODO: do ELA initialization
|
||||
try {
|
||||
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
|
||||
.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
|
||||
elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
|
||||
this.elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||
}
|
||||
|
||||
if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
|
||||
String sequenceLearningAlgorithm = vectorsConfiguration.getSequenceLearningAlgorithm();
|
||||
if (this.sequenceLearningAlgorithm == null && sequenceLearningAlgorithm != null) {
|
||||
// TODO: do SLA initialization
|
||||
try {
|
||||
sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
|
||||
.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
|
||||
sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(sequenceLearningAlgorithm);
|
||||
this.sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||
}
|
||||
|
||||
if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
|
||||
if (this.elementsLearningAlgorithm == null && this.sequenceLearningAlgorithm == null) {
|
||||
throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
at this moment we should have everything ready for actual initialization
|
||||
the only limitation we have - our sequence is detached from actual vocabulary, so we need to merge it back virtually
|
||||
|
@ -139,7 +126,7 @@ public class TrainingFunction<T extends SequenceElement> implements VoidFunction
|
|||
}
|
||||
|
||||
// do the same with labels, transfer them, if any
|
||||
if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
|
||||
if (this.sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
|
||||
for (T label : sequence.getSequenceLabels()) {
|
||||
ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
|
||||
|
||||
|
@ -157,7 +144,7 @@ public class TrainingFunction<T extends SequenceElement> implements VoidFunction
|
|||
// FIXME: temporary hook
|
||||
if (sequence.size() > 0)
|
||||
paramServer.execDistributed(
|
||||
elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
|
||||
this.elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
|
||||
else
|
||||
log.warn("Skipping empty sequence...");
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.spark.models.sequencevectors.functions;
|
|||
import lombok.NonNull;
|
||||
import org.apache.spark.api.java.function.FlatMapFunction;
|
||||
import org.apache.spark.broadcast.Broadcast;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
|
@ -56,12 +57,8 @@ public class VocabRddFunctionFlat<T extends SequenceElement> implements FlatMapF
|
|||
configuration = vectorsConfigurationBroadcast.getValue();
|
||||
|
||||
if (ela == null) {
|
||||
try {
|
||||
ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm())
|
||||
.newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
String className = configuration.getElementsLearningAlgorithm();
|
||||
ela = DL4JClassLoading.createNewInstance(className);
|
||||
}
|
||||
driver = ela.getTrainingDriver();
|
||||
|
||||
|
|
|
@ -17,12 +17,14 @@
|
|||
package org.deeplearning4j.spark.text.functions;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.spark.api.java.function.Function;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.NGramTokenizerFactory;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
|
@ -44,35 +46,33 @@ public class TokenizerFunction implements Function<String, List<String>> {
|
|||
}
|
||||
|
||||
@Override
|
||||
public List<String> call(String v1) throws Exception {
|
||||
if (tokenizerFactory == null)
|
||||
public List<String> call(String str) {
|
||||
if (tokenizerFactory == null) {
|
||||
tokenizerFactory = getTokenizerFactory();
|
||||
if (v1.isEmpty())
|
||||
return Arrays.asList("");
|
||||
return tokenizerFactory.create(v1).getTokens();
|
||||
}
|
||||
|
||||
if (str.isEmpty()) {
|
||||
return Collections.singletonList("");
|
||||
}
|
||||
|
||||
return tokenizerFactory.create(str).getTokens();
|
||||
}
|
||||
|
||||
private TokenizerFactory getTokenizerFactory() {
|
||||
try {
|
||||
TokenPreProcess tokenPreProcessInst = null;
|
||||
// token preprocess CAN be undefined
|
||||
if (tokenizerPreprocessorClazz != null && !tokenizerPreprocessorClazz.isEmpty()) {
|
||||
Class<? extends TokenPreProcess> clazz =
|
||||
(Class<? extends TokenPreProcess>) Class.forName(tokenizerPreprocessorClazz);
|
||||
tokenPreProcessInst = clazz.newInstance();
|
||||
|
||||
if (StringUtils.isNotEmpty(tokenizerPreprocessorClazz)) {
|
||||
tokenPreProcessInst = DL4JClassLoading.createNewInstance(tokenizerPreprocessorClazz);
|
||||
}
|
||||
|
||||
Class<? extends TokenizerFactory> clazz2 =
|
||||
(Class<? extends TokenizerFactory>) Class.forName(tokenizerFactoryClazz);
|
||||
tokenizerFactory = clazz2.newInstance();
|
||||
tokenizerFactory = DL4JClassLoading.createNewInstance(tokenizerFactoryClazz);
|
||||
|
||||
if (tokenPreProcessInst != null)
|
||||
tokenizerFactory.setTokenPreProcessor(tokenPreProcessInst);
|
||||
if (nGrams > 1) {
|
||||
tokenizerFactory = new NGramTokenizerFactory(tokenizerFactory, nGrams, nGrams);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("",e);
|
||||
}
|
||||
|
||||
return tokenizerFactory;
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.spark.time;
|
||||
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.common.config.DL4JSystemProperties;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
|
@ -62,9 +63,9 @@ public class TimeSourceProvider {
|
|||
*/
|
||||
public static TimeSource getInstance(String className) {
|
||||
try {
|
||||
Class<?> c = Class.forName(className);
|
||||
Method m = c.getMethod("getInstance");
|
||||
return (TimeSource) m.invoke(null);
|
||||
Class<?> clazz = DL4JClassLoading.loadClassByName(className);
|
||||
Method getInstance = clazz.getMethod("getInstance");
|
||||
return (TimeSource) getInstance.invoke(null);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Error getting TimeSource instance for class \"" + className + "\"", e);
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.ui.model.stats;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.core.storage.StatsStorageRouter;
|
||||
import org.deeplearning4j.core.storage.StorageMetaData;
|
||||
import org.deeplearning4j.core.storage.listener.RoutingIterationListener;
|
||||
|
@ -696,11 +697,14 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
|
|||
return devPointers.get(device);
|
||||
}
|
||||
try {
|
||||
Class<?> c = Class.forName("org.nd4j.jita.allocator.pointers.CudaPointer");
|
||||
Constructor<?> constructor = c.getConstructor(long.class);
|
||||
Pointer p = (Pointer) constructor.newInstance((long) device);
|
||||
devPointers.put(device, p);
|
||||
return p;
|
||||
Pointer pointer = DL4JClassLoading.createNewInstance(
|
||||
"org.nd4j.jita.allocator.pointers.CudaPointer",
|
||||
Pointer.class,
|
||||
new Class[] { long.class },
|
||||
(long) device);
|
||||
|
||||
devPointers.put(device, pointer);
|
||||
return pointer;
|
||||
} catch (Throwable t) {
|
||||
devPointers.put(device, null); //Stops attempting the failure again later...
|
||||
return null;
|
||||
|
@ -711,9 +715,9 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
|
|||
ModelInfo modelInfo = getModelInfo(model);
|
||||
int examplesThisMinibatch = 0;
|
||||
if (model instanceof MultiLayerNetwork) {
|
||||
examplesThisMinibatch = ((MultiLayerNetwork) model).batchSize();
|
||||
examplesThisMinibatch = model.batchSize();
|
||||
} else if (model instanceof ComputationGraph) {
|
||||
examplesThisMinibatch = ((ComputationGraph) model).batchSize();
|
||||
examplesThisMinibatch = model.batchSize();
|
||||
} else if (model instanceof Layer) {
|
||||
examplesThisMinibatch = ((Layer) model).getInputMiniBatchSize();
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.ui.model.storage.mapdb;
|
|||
|
||||
import lombok.Data;
|
||||
import lombok.NonNull;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.core.storage.*;
|
||||
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
|
||||
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
||||
|
@ -318,26 +319,18 @@ public class MapDBStatsStorage extends BaseCollectionStatsStorage {
|
|||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public T deserialize(@NonNull DataInput2 input, int available) throws IOException {
|
||||
int classIdx = input.readInt();
|
||||
String className = getClassForInt(classIdx);
|
||||
Class<?> clazz;
|
||||
try {
|
||||
clazz = Class.forName(className);
|
||||
} catch (ClassNotFoundException e) {
|
||||
throw new RuntimeException(e); //Shouldn't normally happen...
|
||||
}
|
||||
Persistable p;
|
||||
try {
|
||||
p = (Persistable) clazz.newInstance();
|
||||
} catch (InstantiationException | IllegalAccessException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
int remainingLength = available - 4; //-4 for int class index
|
||||
|
||||
Persistable persistable = DL4JClassLoading.createNewInstance(className);
|
||||
|
||||
int remainingLength = available - 4; // -4 for int class index
|
||||
byte[] temp = new byte[remainingLength];
|
||||
input.readFully(temp);
|
||||
p.decode(temp);
|
||||
return (T) p;
|
||||
persistable.decode(temp);
|
||||
return (T) persistable;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -32,32 +32,42 @@ import lombok.Getter;
|
|||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FilenameUtils;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.common.config.DL4JSystemProperties;
|
||||
import org.deeplearning4j.common.util.DL4JFileUtils;
|
||||
import org.deeplearning4j.core.storage.StatsStorage;
|
||||
import org.deeplearning4j.core.storage.StatsStorageEvent;
|
||||
import org.deeplearning4j.core.storage.StatsStorageListener;
|
||||
import org.deeplearning4j.core.storage.StatsStorageRouter;
|
||||
import org.deeplearning4j.common.config.DL4JSystemProperties;
|
||||
import org.deeplearning4j.exception.DL4JException;
|
||||
import org.deeplearning4j.ui.api.Route;
|
||||
import org.deeplearning4j.ui.api.UIModule;
|
||||
import org.deeplearning4j.ui.api.UIServer;
|
||||
import org.deeplearning4j.ui.i18n.I18NProvider;
|
||||
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
|
||||
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
||||
import org.deeplearning4j.ui.model.storage.impl.QueueStatsStorageListener;
|
||||
import org.deeplearning4j.ui.module.SameDiffModule;
|
||||
import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule;
|
||||
import org.deeplearning4j.ui.module.defaultModule.DefaultModule;
|
||||
import org.deeplearning4j.ui.module.remote.RemoteReceiverModule;
|
||||
import org.deeplearning4j.ui.module.train.TrainModule;
|
||||
import org.deeplearning4j.ui.module.tsne.TsneModule;
|
||||
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
|
||||
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
||||
import org.deeplearning4j.ui.model.storage.impl.QueueStatsStorageListener;
|
||||
import org.deeplearning4j.common.util.DL4JFileUtils;
|
||||
import org.nd4j.common.function.Function;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.ServiceLoader;
|
||||
import java.util.concurrent.BlockingQueue;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
@Slf4j
|
||||
|
@ -402,8 +412,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
}
|
||||
|
||||
private void modulesViaServiceLoader(List<UIModule> uiModules) {
|
||||
|
||||
ServiceLoader<UIModule> sl = ServiceLoader.load(UIModule.class);
|
||||
ServiceLoader<UIModule> sl = DL4JClassLoading.loadService(UIModule.class);
|
||||
Iterator<UIModule> iter = sl.iterator();
|
||||
|
||||
if (!iter.hasNext()) {
|
||||
|
@ -411,19 +420,19 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
}
|
||||
|
||||
while (iter.hasNext()) {
|
||||
UIModule m = iter.next();
|
||||
Class<?> c = m.getClass();
|
||||
UIModule module = iter.next();
|
||||
Class<?> moduleClass = module.getClass();
|
||||
boolean foundExisting = false;
|
||||
for (UIModule mExisting : uiModules) {
|
||||
if (mExisting.getClass() == c) {
|
||||
if (mExisting.getClass() == moduleClass) {
|
||||
foundExisting = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!foundExisting) {
|
||||
log.debug("Loaded UI module via service loader: {}", m.getClass());
|
||||
uiModules.add(m);
|
||||
log.debug("Loaded UI module via service loader: {}", module.getClass());
|
||||
uiModules.add(module);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.ui.i18n;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.ui.api.I18N;
|
||||
import org.deeplearning4j.ui.api.UIModule;
|
||||
|
||||
|
@ -100,13 +101,13 @@ public class DefaultI18N implements I18N {
|
|||
}
|
||||
|
||||
private synchronized void loadLanguages(){
|
||||
ServiceLoader<UIModule> sl = ServiceLoader.load(UIModule.class);
|
||||
ServiceLoader<UIModule> loadedModules = DL4JClassLoading.loadService(UIModule.class);
|
||||
|
||||
for(UIModule m : sl){
|
||||
List<I18NResource> resources = m.getInternationalizationResources();
|
||||
for(I18NResource r : resources){
|
||||
for (UIModule module : loadedModules){
|
||||
List<I18NResource> resources = module.getInternationalizationResources();
|
||||
for(I18NResource resource : resources){
|
||||
try {
|
||||
String path = r.getResource();
|
||||
String path = resource.getResource();
|
||||
int idxLast = path.lastIndexOf('.');
|
||||
if (idxLast < 0) {
|
||||
log.warn("Skipping language resource file: cannot infer language: {}", path);
|
||||
|
@ -116,9 +117,9 @@ public class DefaultI18N implements I18N {
|
|||
String langCode = path.substring(idxLast + 1).toLowerCase();
|
||||
Map<String, String> map = messagesByLanguage.computeIfAbsent(langCode, k -> new HashMap<>());
|
||||
|
||||
parseFile(r, map);
|
||||
parseFile(resource, map);
|
||||
} catch (Throwable t){
|
||||
log.warn("Error parsing UI I18N content file; skipping: {}", r.getResource(), t);
|
||||
log.warn("Error parsing UI I18N content file; skipping: {}", resource.getResource(), t);
|
||||
languageLoadingException = t;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import io.netty.handler.codec.http.HttpResponseStatus;
|
|||
import io.vertx.core.json.JsonObject;
|
||||
import io.vertx.ext.web.RoutingContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.core.storage.*;
|
||||
import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter;
|
||||
import org.deeplearning4j.ui.api.HttpMethod;
|
||||
|
@ -154,9 +155,12 @@ public class RemoteReceiverModule implements UIModule {
|
|||
private StorageMetaData getMetaData(String dataClass, String content) {
|
||||
StorageMetaData meta;
|
||||
try {
|
||||
Class<?> c = Class.forName(dataClass);
|
||||
if (StorageMetaData.class.isAssignableFrom(c)) {
|
||||
meta = (StorageMetaData) c.newInstance();
|
||||
Class<?> clazz = DL4JClassLoading.loadClassByName(dataClass);
|
||||
if (StorageMetaData.class.isAssignableFrom(clazz)) {
|
||||
meta = clazz
|
||||
.asSubclass(StorageMetaData.class)
|
||||
.getDeclaredConstructor()
|
||||
.newInstance();
|
||||
} else {
|
||||
log.warn("Skipping invalid remote data: class {} in not an instance of {}", dataClass,
|
||||
StorageMetaData.class.getName());
|
||||
|
@ -179,11 +183,14 @@ public class RemoteReceiverModule implements UIModule {
|
|||
}
|
||||
|
||||
private Persistable getPersistable(String dataClass, String content) {
|
||||
Persistable p;
|
||||
Persistable persistable;
|
||||
try {
|
||||
Class<?> c = Class.forName(dataClass);
|
||||
if (Persistable.class.isAssignableFrom(c)) {
|
||||
p = (Persistable) c.newInstance();
|
||||
Class<?> clazz = DL4JClassLoading.loadClassByName(dataClass);
|
||||
if (Persistable.class.isAssignableFrom(clazz)) {
|
||||
persistable = clazz
|
||||
.asSubclass(Persistable.class)
|
||||
.getDeclaredConstructor()
|
||||
.newInstance();
|
||||
} else {
|
||||
log.warn("Skipping invalid remote data: class {} in not an instance of {}", dataClass,
|
||||
Persistable.class.getName());
|
||||
|
@ -196,12 +203,12 @@ public class RemoteReceiverModule implements UIModule {
|
|||
|
||||
try {
|
||||
byte[] bytes = DatatypeConverter.parseBase64Binary(content);
|
||||
p.decode(bytes);
|
||||
persistable.decode(bytes);
|
||||
} catch (Exception e) {
|
||||
log.warn("Skipping invalid remote data: exception encountered when deserializing data", e);
|
||||
return null;
|
||||
}
|
||||
|
||||
return p;
|
||||
return persistable;
|
||||
}
|
||||
}
|
|
@ -22,6 +22,7 @@ import lombok.NonNull;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
|
@ -127,7 +128,7 @@ public class IntegrationTestRunner {
|
|||
}
|
||||
|
||||
for (ClassPath.ClassInfo c : info) {
|
||||
Class<?> clazz = Class.forName(c.getName());
|
||||
Class<?> clazz = DL4JClassLoading.loadClassByName(c.getName());
|
||||
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface())
|
||||
continue;
|
||||
|
||||
|
|
|
@ -1,3 +1,19 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) Eclipse Deeplearning4j Contributors 2020
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.common.config;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -6,6 +22,22 @@ import java.util.ServiceLoader;
|
|||
|
||||
/**
|
||||
* Global context for class-loading in ND4J.
|
||||
* <p>Use {@code ND4JClassLoading} to define classloader for ND4J only! To define classloader used by
|
||||
* {@code Deeplearning4j} use class {@link org.deeplearning4j.common.config.DL4JClassLoading}.
|
||||
*
|
||||
* <p>Usage:
|
||||
* <pre>{@code
|
||||
* public class Application {
|
||||
* static {
|
||||
* ND4JClassLoading.setNd4jClassloaderFromClass(Application.class);
|
||||
* }
|
||||
*
|
||||
* public static void main(String[] args) {
|
||||
* }
|
||||
* }
|
||||
* }</code>
|
||||
*
|
||||
* @see org.deeplearning4j.common.config.DL4JClassLoading
|
||||
*
|
||||
* @author Alexei KLENIN
|
||||
*/
|
||||
|
|
Loading…
Reference in New Issue