FEATURE#8712: add possibility to specify classloader for DL4J (#9115)
Signed-off-by: hosuaby <alexei.klenin@gmail.com>master
parent
2e000c84ac
commit
a722bd5a5b
|
@ -33,6 +33,12 @@
|
||||||
<artifactId>nd4j-common</artifactId>
|
<artifactId>nd4j-common</artifactId>
|
||||||
<version>${nd4j.version}</version>
|
<version>${nd4j.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>junit</groupId>
|
||||||
|
<artifactId>junit</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
@ -43,5 +49,4 @@
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -0,0 +1,133 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) Eclipse Deeplearning4j Contributors 2020
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.common.config;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.common.config.ND4JClassLoading;
|
||||||
|
|
||||||
|
import java.lang.reflect.InvocationTargetException;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.ServiceLoader;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Global context for class-loading in DL4J.
|
||||||
|
* <p>Use {@code DL4JClassLoading} to define classloader for Deeplearning4j only! To define classloader used by
|
||||||
|
* {@code ND4J} use class {@link org.nd4j.common.config.ND4JClassLoading}.
|
||||||
|
*
|
||||||
|
* <p>Usage:
|
||||||
|
* <pre>{@code
|
||||||
|
* public class Application {
|
||||||
|
* static {
|
||||||
|
* DL4JClassLoading.setDl4jClassloaderFromClass(Application.class);
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* public static void main(String[] args) {
|
||||||
|
* }
|
||||||
|
* }
|
||||||
|
* }</code>
|
||||||
|
*
|
||||||
|
* @see org.nd4j.common.config.ND4JClassLoading
|
||||||
|
*
|
||||||
|
* @author Alexei KLENIN
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class DL4JClassLoading {
|
||||||
|
private static ClassLoader dl4jClassloader = ND4JClassLoading.getNd4jClassloader();
|
||||||
|
|
||||||
|
private DL4JClassLoading() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ClassLoader getDl4jClassloader() {
|
||||||
|
return DL4JClassLoading.dl4jClassloader;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void setDl4jClassloaderFromClass(Class<?> clazz) {
|
||||||
|
setDl4jClassloader(clazz.getClassLoader());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void setDl4jClassloader(ClassLoader dl4jClassloader) {
|
||||||
|
DL4JClassLoading.dl4jClassloader = dl4jClassloader;
|
||||||
|
log.debug("Global class-loader for DL4J was changed.");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean classPresentOnClasspath(String className) {
|
||||||
|
return classPresentOnClasspath(className, dl4jClassloader);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean classPresentOnClasspath(String className, ClassLoader classLoader) {
|
||||||
|
return loadClassByName(className, false, classLoader) != null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <T> Class<T> loadClassByName(String className) {
|
||||||
|
return loadClassByName(className, true, dl4jClassloader);
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public static <T> Class<T> loadClassByName(String className, boolean initialize, ClassLoader classLoader) {
|
||||||
|
try {
|
||||||
|
return (Class<T>) Class.forName(className, initialize, classLoader);
|
||||||
|
} catch (ClassNotFoundException classNotFoundException) {
|
||||||
|
log.error(String.format("Cannot find class [%s] of provided class-loader.", className));
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <T> T createNewInstance(String className, Object... args) {
|
||||||
|
return createNewInstance(className, Object.class, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <T> T createNewInstance(String className, Class<? super T> superclass) {
|
||||||
|
return createNewInstance(className, superclass, new Class<?>[]{}, new Object[]{});
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <T> T createNewInstance(String className, Class<? super T> superclass, Object... args) {
|
||||||
|
Class<?>[] parameterTypes = new Class<?>[args.length];
|
||||||
|
for (int i = 0; i < args.length; i++) {
|
||||||
|
Object arg = args[i];
|
||||||
|
Objects.requireNonNull(arg);
|
||||||
|
parameterTypes[i] = arg.getClass();
|
||||||
|
}
|
||||||
|
|
||||||
|
return createNewInstance(className, superclass, parameterTypes, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <T> T createNewInstance(
|
||||||
|
String className,
|
||||||
|
Class<? super T> superclass,
|
||||||
|
Class<?>[] parameterTypes,
|
||||||
|
Object... args) {
|
||||||
|
try {
|
||||||
|
return (T) DL4JClassLoading
|
||||||
|
.loadClassByName(className)
|
||||||
|
.asSubclass(superclass)
|
||||||
|
.getDeclaredConstructor(parameterTypes)
|
||||||
|
.newInstance(args);
|
||||||
|
} catch (InstantiationException | IllegalAccessException | InvocationTargetException
|
||||||
|
| NoSuchMethodException instantiationException) {
|
||||||
|
log.error(String.format("Cannot create instance of class '%s'.", className), instantiationException);
|
||||||
|
throw new RuntimeException(instantiationException);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <S> ServiceLoader<S> loadService(Class<S> serviceClass) {
|
||||||
|
return loadService(serviceClass, dl4jClassloader);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <S> ServiceLoader<S> loadService(Class<S> serviceClass, ClassLoader classLoader) {
|
||||||
|
return ServiceLoader.load(serviceClass, classLoader);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,67 @@
|
||||||
|
package org.deeplearning4j.common.config;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertNotNull;
|
||||||
|
|
||||||
|
import org.deeplearning4j.common.config.dummies.TestAbstract;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class DL4JClassLoadingTest {
|
||||||
|
private static final String PACKAGE_PREFIX = "org.deeplearning4j.common.config.dummies.";
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateNewInstance_constructorWithoutArguments() {
|
||||||
|
|
||||||
|
/* Given */
|
||||||
|
String className = PACKAGE_PREFIX + "TestDummy";
|
||||||
|
|
||||||
|
/* When */
|
||||||
|
Object instance = DL4JClassLoading.createNewInstance(className);
|
||||||
|
|
||||||
|
/* Then */
|
||||||
|
assertNotNull(instance);
|
||||||
|
assertEquals(className, instance.getClass().getName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateNewInstance_constructorWithArgument_implicitArgumentTypes() {
|
||||||
|
|
||||||
|
/* Given */
|
||||||
|
String className = PACKAGE_PREFIX + "TestColor";
|
||||||
|
|
||||||
|
/* When */
|
||||||
|
TestAbstract instance = DL4JClassLoading.createNewInstance(className, TestAbstract.class, "white");
|
||||||
|
|
||||||
|
/* Then */
|
||||||
|
assertNotNull(instance);
|
||||||
|
assertEquals(className, instance.getClass().getName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateNewInstance_constructorWithArgument_explicitArgumentTypes() {
|
||||||
|
|
||||||
|
/* Given */
|
||||||
|
String colorClassName = PACKAGE_PREFIX + "TestColor";
|
||||||
|
String rectangleClassName = PACKAGE_PREFIX + "TestRectangle";
|
||||||
|
|
||||||
|
/* When */
|
||||||
|
TestAbstract color = DL4JClassLoading.createNewInstance(
|
||||||
|
colorClassName,
|
||||||
|
Object.class,
|
||||||
|
new Class<?>[]{ int.class, int.class, int.class },
|
||||||
|
45, 175, 200);
|
||||||
|
|
||||||
|
TestAbstract rectangle = DL4JClassLoading.createNewInstance(
|
||||||
|
rectangleClassName,
|
||||||
|
Object.class,
|
||||||
|
new Class<?>[]{ int.class, int.class, TestAbstract.class },
|
||||||
|
10, 15, color);
|
||||||
|
|
||||||
|
/* Then */
|
||||||
|
assertNotNull(color);
|
||||||
|
assertEquals(colorClassName, color.getClass().getName());
|
||||||
|
|
||||||
|
assertNotNull(rectangle);
|
||||||
|
assertEquals(rectangleClassName, rectangle.getClass().getName());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,4 @@
|
||||||
|
package org.deeplearning4j.common.config.dummies;
|
||||||
|
|
||||||
|
public abstract class TestAbstract {
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
package org.deeplearning4j.common.config.dummies;
|
||||||
|
|
||||||
|
public class TestColor extends TestAbstract {
|
||||||
|
public TestColor(String color) {
|
||||||
|
}
|
||||||
|
|
||||||
|
public TestColor(int r, int g, int b) {
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,6 @@
|
||||||
|
package org.deeplearning4j.common.config.dummies;
|
||||||
|
|
||||||
|
public class TestDummy {
|
||||||
|
public TestDummy() {
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,6 @@
|
||||||
|
package org.deeplearning4j.common.config.dummies;
|
||||||
|
|
||||||
|
public class TestRectangle extends TestAbstract {
|
||||||
|
public TestRectangle(int width, int height, TestAbstract color) {
|
||||||
|
}
|
||||||
|
}
|
|
@ -102,7 +102,6 @@ public class AsyncIterator<T extends Object> implements Iterator<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private class ReaderThread<T> extends Thread implements Runnable {
|
private class ReaderThread<T> extends Thread implements Runnable {
|
||||||
private BlockingQueue<T> buffer;
|
private BlockingQueue<T> buffer;
|
||||||
private Iterator<T> iterator;
|
private Iterator<T> iterator;
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j;
|
||||||
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
|
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
|
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
|
||||||
|
@ -66,11 +67,11 @@ public class LayerHelperValidationUtil {
|
||||||
|
|
||||||
public static void disableCppHelpers(){
|
public static void disableCppHelpers(){
|
||||||
try {
|
try {
|
||||||
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
Class<?> clazz = DL4JClassLoading.loadClassByName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||||
Method m = c.getMethod("getInstance");
|
Method getInstance = clazz.getMethod("getInstance");
|
||||||
Object instance = m.invoke(null);
|
Object instance = getInstance.invoke(null);
|
||||||
Method m2 = c.getMethod("allowHelpers", boolean.class);
|
Method allowHelpers = clazz.getMethod("allowHelpers", boolean.class);
|
||||||
m2.invoke(instance, false);
|
allowHelpers.invoke(instance, false);
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
throw new RuntimeException(t);
|
throw new RuntimeException(t);
|
||||||
}
|
}
|
||||||
|
@ -78,11 +79,11 @@ public class LayerHelperValidationUtil {
|
||||||
|
|
||||||
public static void enableCppHelpers(){
|
public static void enableCppHelpers(){
|
||||||
try {
|
try {
|
||||||
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
Class<?> clazz = DL4JClassLoading.loadClassByName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||||
Method m = c.getMethod("getInstance");
|
Method getInstance = clazz.getMethod("getInstance");
|
||||||
Object instance = m.invoke(null);
|
Object instance = getInstance.invoke(null);
|
||||||
Method m2 = c.getMethod("allowHelpers", boolean.class);
|
Method allowHelpers = clazz.getMethod("allowHelpers", boolean.class);
|
||||||
m2.invoke(instance, true);
|
allowHelpers.invoke(instance, true);
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
throw new RuntimeException(t);
|
throw new RuntimeException(t);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,28 +16,96 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.dtypes;
|
package org.deeplearning4j.nn.dtypes;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
|
import static org.junit.Assert.assertEquals;
|
||||||
import org.deeplearning4j.nn.conf.preprocessor.*;
|
import static org.junit.Assert.assertTrue;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
|
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
|
|
||||||
import org.nd4j.linalg.profiler.ProfilerConfig;
|
|
||||||
import org.nd4j.shade.guava.collect.ImmutableSet;
|
|
||||||
import org.nd4j.shade.guava.reflect.ClassPath;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
import org.deeplearning4j.nn.conf.*;
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
|
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||||
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
||||||
import org.deeplearning4j.nn.conf.dropout.AlphaDropout;
|
import org.deeplearning4j.nn.conf.dropout.AlphaDropout;
|
||||||
import org.deeplearning4j.nn.conf.dropout.GaussianDropout;
|
import org.deeplearning4j.nn.conf.dropout.GaussianDropout;
|
||||||
import org.deeplearning4j.nn.conf.dropout.GaussianNoise;
|
import org.deeplearning4j.nn.conf.dropout.GaussianNoise;
|
||||||
import org.deeplearning4j.nn.conf.dropout.SpatialDropout;
|
import org.deeplearning4j.nn.conf.dropout.SpatialDropout;
|
||||||
import org.deeplearning4j.nn.conf.graph.*;
|
import org.deeplearning4j.nn.conf.graph.AttentionVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.FrozenVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.L2NormalizeVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.L2Vertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.PoolHelperVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.ReshapeVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.ScaleVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.ShiftVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.StackVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.SubsetVertex;
|
||||||
|
import org.deeplearning4j.nn.conf.graph.UnstackVertex;
|
||||||
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex;
|
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex;
|
||||||
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
|
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
|
||||||
import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex;
|
import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.AutoEncoder;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.CapsuleLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.CapsuleStrengthLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Cnn3DLossLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Convolution2D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Convolution3D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Deconvolution3D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.LearnedSelfAttentionLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.LocallyConnected1D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.PReLULayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Pooling1D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Pooling2D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.PoolingType;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.RecurrentAttentionLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.SelfAttentionLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.SpaceToBatchLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.SpaceToDepthLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Subsampling1DLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Upsampling1D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Upsampling3D;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.ZeroPadding1DLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
|
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
|
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
|
||||||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D;
|
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D;
|
||||||
|
@ -49,16 +117,24 @@ import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
|
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
|
||||||
import org.deeplearning4j.nn.conf.layers.util.MaskLayer;
|
import org.deeplearning4j.nn.conf.layers.util.MaskLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
||||||
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
|
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
|
||||||
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
|
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.layers.util.IdentityLayer;
|
import org.deeplearning4j.nn.layers.util.IdentityLayer;
|
||||||
|
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||||
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.nn.weights.WeightInitDistribution;
|
import org.deeplearning4j.nn.weights.WeightInitDistribution;
|
||||||
|
@ -77,12 +153,17 @@ import org.nd4j.linalg.learning.config.Nesterovs;
|
||||||
import org.nd4j.linalg.learning.config.NoOp;
|
import org.nd4j.linalg.learning.config.NoOp;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
|
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
|
||||||
|
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||||
|
import org.nd4j.shade.guava.collect.ImmutableSet;
|
||||||
|
import org.nd4j.shade.guava.reflect.ClassPath;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.lang.reflect.Modifier;
|
import java.lang.reflect.Modifier;
|
||||||
import java.util.*;
|
import java.util.Arrays;
|
||||||
|
import java.util.HashSet;
|
||||||
import static org.junit.Assert.*;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class DTypeTests extends BaseDL4JTest {
|
public class DTypeTests extends BaseDL4JTest {
|
||||||
|
@ -120,20 +201,17 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
Set<Class<?>> preprocClasses = new HashSet<>();
|
Set<Class<?>> preprocClasses = new HashSet<>();
|
||||||
Set<Class<?>> vertexClasses = new HashSet<>();
|
Set<Class<?>> vertexClasses = new HashSet<>();
|
||||||
for (ClassPath.ClassInfo ci : info) {
|
for (ClassPath.ClassInfo ci : info) {
|
||||||
Class<?> clazz;
|
Class<?> clazz = DL4JClassLoading.loadClassByName(ci.getName());
|
||||||
try {
|
|
||||||
clazz = Class.forName(ci.getName());
|
|
||||||
} catch (ClassNotFoundException e) {
|
|
||||||
//Should never happen as this was found on the classpath
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) { //Skip TFOpLayer here - dtype depends on imported model dtype
|
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) {
|
||||||
|
// Skip TFOpLayer here - dtype depends on imported model dtype
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (clazz.getName().toLowerCase().contains("custom") || clazz.getName().contains("samediff.testlayers")
|
if (clazz.getName().toLowerCase().contains("custom")
|
||||||
|| clazz.getName().toLowerCase().contains("test") || ignoreClasses.contains(clazz)) {
|
|| clazz.getName().contains("samediff.testlayers")
|
||||||
|
|| clazz.getName().toLowerCase().contains("test")
|
||||||
|
|| ignoreClasses.contains(clazz)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.recurrent;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -167,7 +168,7 @@ public class GravesLSTMTest extends BaseDL4JTest {
|
||||||
actHelper.setAccessible(true);
|
actHelper.setAccessible(true);
|
||||||
|
|
||||||
//Call activateHelper with both forBackprop == true, and forBackprop == false and compare
|
//Call activateHelper with both forBackprop == true, and forBackprop == false and compare
|
||||||
Class<?> innerClass = Class.forName("org.deeplearning4j.nn.layers.recurrent.FwdPassReturn");
|
Class<?> innerClass = DL4JClassLoading.loadClassByName("org.deeplearning4j.nn.layers.recurrent.FwdPassReturn");
|
||||||
|
|
||||||
Object oFalse = actHelper.invoke(lstm, false, null, null, false, LayerWorkspaceMgr.noWorkspacesImmutable()); //GravesLSTM.FwdPassReturn object; want fwdPassOutput INDArray
|
Object oFalse = actHelper.invoke(lstm, false, null, null, false, LayerWorkspaceMgr.noWorkspacesImmutable()); //GravesLSTM.FwdPassReturn object; want fwdPassOutput INDArray
|
||||||
Object oTrue = actHelper.invoke(lstm, false, null, null, true, LayerWorkspaceMgr.noWorkspacesImmutable()); //want fwdPassOutputAsArrays object
|
Object oTrue = actHelper.invoke(lstm, false, null, null, true, LayerWorkspaceMgr.noWorkspacesImmutable()); //want fwdPassOutputAsArrays object
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.layers.AbstractLayer;
|
import org.deeplearning4j.nn.layers.AbstractLayer;
|
||||||
|
@ -110,7 +111,7 @@ public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
|
||||||
org.nd4j.shade.protobuf.ByteString serialized = graphDef.toByteString();
|
org.nd4j.shade.protobuf.ByteString serialized = graphDef.toByteString();
|
||||||
byte[] graphBytes = serialized.toByteArray();
|
byte[] graphBytes = serialized.toByteArray();
|
||||||
|
|
||||||
ServiceLoader<TFGraphRunnerService> sl = ServiceLoader.load(TFGraphRunnerService.class);
|
ServiceLoader<TFGraphRunnerService> sl = DL4JClassLoading.loadService(TFGraphRunnerService.class);
|
||||||
Iterator<TFGraphRunnerService> iter = sl.iterator();
|
Iterator<TFGraphRunnerService> iter = sl.iterator();
|
||||||
if (!iter.hasNext()){
|
if (!iter.hasNext()){
|
||||||
throw new RuntimeException("The model contains a Tensorflow Op, which requires the nd4j-tensorflow dependency to execute.");
|
throw new RuntimeException("The model contains a Tensorflow Op, which requires the nd4j-tensorflow dependency to execute.");
|
||||||
|
|
|
@ -61,14 +61,12 @@ public abstract class Model {
|
||||||
/**
|
/**
|
||||||
* 模型读取
|
* 模型读取
|
||||||
*
|
*
|
||||||
* @param path
|
|
||||||
* @return
|
|
||||||
* @return
|
|
||||||
* @throws Exception
|
|
||||||
*/
|
*/
|
||||||
public static Model load(Class<? extends Model> c, InputStream is) throws Exception {
|
public static Model load(Class<? extends Model> modelClass, InputStream inputStream) throws Exception {
|
||||||
Model model = c.newInstance();
|
return modelClass
|
||||||
return model.loadModel(is);
|
.getDeclaredConstructor()
|
||||||
|
.newInstance()
|
||||||
|
.loadModel(inputStream);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -5,6 +5,7 @@ import org.ansj.dic.impl.Jar2Stream;
|
||||||
import org.ansj.dic.impl.Jdbc2Stream;
|
import org.ansj.dic.impl.Jdbc2Stream;
|
||||||
import org.ansj.dic.impl.Url2Stream;
|
import org.ansj.dic.impl.Url2Stream;
|
||||||
import org.ansj.exception.LibraryException;
|
import org.ansj.exception.LibraryException;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
|
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
|
||||||
|
@ -25,7 +26,8 @@ public abstract class PathToStream {
|
||||||
} else if (path.startsWith("jar://")) {
|
} else if (path.startsWith("jar://")) {
|
||||||
return new Jar2Stream().toStream(path);
|
return new Jar2Stream().toStream(path);
|
||||||
} else if (path.startsWith("class://")) {
|
} else if (path.startsWith("class://")) {
|
||||||
((PathToStream) Class.forName(path.substring(8).split("\\|")[0]).newInstance()).toStream(path);
|
// Probably unused
|
||||||
|
return loadClass(path);
|
||||||
} else if (path.startsWith("http://") || path.startsWith("https://")) {
|
} else if (path.startsWith("http://") || path.startsWith("https://")) {
|
||||||
return new Url2Stream().toStream(path);
|
return new Url2Stream().toStream(path);
|
||||||
} else {
|
} else {
|
||||||
|
@ -34,9 +36,17 @@ public abstract class PathToStream {
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new LibraryException(e);
|
throw new LibraryException(e);
|
||||||
}
|
}
|
||||||
throw new LibraryException("not find method type in path " + path);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract InputStream toStream(String path);
|
public abstract InputStream toStream(String path);
|
||||||
|
|
||||||
|
static InputStream loadClass(String path) {
|
||||||
|
String className = path
|
||||||
|
.substring("class://".length())
|
||||||
|
.split("\\|")[0];
|
||||||
|
|
||||||
|
return DL4JClassLoading
|
||||||
|
.createNewInstance(className, PathToStream.class)
|
||||||
|
.toStream(path);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package org.ansj.dic.impl;
|
||||||
import org.ansj.dic.DicReader;
|
import org.ansj.dic.DicReader;
|
||||||
import org.ansj.dic.PathToStream;
|
import org.ansj.dic.PathToStream;
|
||||||
import org.ansj.exception.LibraryException;
|
import org.ansj.exception.LibraryException;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
|
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
|
||||||
|
@ -17,12 +18,16 @@ public class Jar2Stream extends PathToStream {
|
||||||
@Override
|
@Override
|
||||||
public InputStream toStream(String path) {
|
public InputStream toStream(String path) {
|
||||||
if (path.contains("|")) {
|
if (path.contains("|")) {
|
||||||
String[] split = path.split("\\|");
|
String[] tokens = path.split("\\|");
|
||||||
try {
|
String className = tokens[0].substring(6);
|
||||||
return Class.forName(split[0].substring(6)).getResourceAsStream(split[1].trim());
|
String resourceName = tokens[1].trim();
|
||||||
} catch (ClassNotFoundException e) {
|
|
||||||
throw new LibraryException(e);
|
Class<Object> resourceClass = DL4JClassLoading.loadClassByName(className);
|
||||||
|
if (resourceClass == null) {
|
||||||
|
throw new LibraryException(String.format("Class '%s' was not found.", className));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return resourceClass.getResourceAsStream(resourceName);
|
||||||
} else {
|
} else {
|
||||||
return DicReader.getInputStream(path.substring(6));
|
return DicReader.getInputStream(path.substring(6));
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package org.ansj.dic.impl;
|
||||||
|
|
||||||
import org.ansj.dic.PathToStream;
|
import org.ansj.dic.PathToStream;
|
||||||
import org.ansj.exception.LibraryException;
|
import org.ansj.exception.LibraryException;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
|
|
||||||
import java.io.ByteArrayInputStream;
|
import java.io.ByteArrayInputStream;
|
||||||
import java.io.ByteArrayOutputStream;
|
import java.io.ByteArrayOutputStream;
|
||||||
|
@ -22,20 +23,26 @@ public class Jdbc2Stream extends PathToStream {
|
||||||
|
|
||||||
private static final byte[] LINE = "\n".getBytes();
|
private static final byte[] LINE = "\n".getBytes();
|
||||||
|
|
||||||
|
private static final String[] JDBC_DRIVERS = {
|
||||||
|
"org.h2.Driver",
|
||||||
|
"com.ibm.db2.jcc.DB2Driver",
|
||||||
|
"org.hsqldb.jdbcDriver",
|
||||||
|
"org.gjt.mm.mysql.Driver",
|
||||||
|
"oracle.jdbc.OracleDriver",
|
||||||
|
"org.postgresql.Driver",
|
||||||
|
"net.sourceforge.jtds.jdbc.Driver",
|
||||||
|
"com.microsoft.sqlserver.jdbc.SQLServerDriver",
|
||||||
|
"org.sqlite.JDBC",
|
||||||
|
"com.mysql.jdbc.Driver"
|
||||||
|
};
|
||||||
|
|
||||||
static {
|
static {
|
||||||
String[] drivers = {"org.h2.Driver", "com.ibm.db2.jcc.DB2Driver", "org.hsqldb.jdbcDriver",
|
loadJdbcDrivers();
|
||||||
"org.gjt.mm.mysql.Driver", "oracle.jdbc.OracleDriver", "org.postgresql.Driver",
|
|
||||||
"net.sourceforge.jtds.jdbc.Driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver",
|
|
||||||
"org.sqlite.JDBC", "com.mysql.jdbc.Driver"};
|
|
||||||
for (String driverClassName : drivers) {
|
|
||||||
try {
|
|
||||||
try {
|
|
||||||
Thread.currentThread().getContextClassLoader().loadClass(driverClassName);
|
|
||||||
} catch (ClassNotFoundException e) {
|
|
||||||
Class.forName(driverClassName);
|
|
||||||
}
|
|
||||||
} catch (Throwable e) {
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void loadJdbcDrivers() {
|
||||||
|
for (String driverClassName : JDBC_DRIVERS) {
|
||||||
|
DL4JClassLoading.loadClassByName(driverClassName);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,44 +17,20 @@
|
||||||
|
|
||||||
package org.deeplearning4j.models.embeddings.loader;
|
package org.deeplearning4j.models.embeddings.loader;
|
||||||
|
|
||||||
import java.io.BufferedInputStream;
|
import lombok.AllArgsConstructor;
|
||||||
import java.io.BufferedOutputStream;
|
import lombok.Data;
|
||||||
import java.io.BufferedReader;
|
import lombok.NoArgsConstructor;
|
||||||
import java.io.BufferedWriter;
|
import lombok.NonNull;
|
||||||
import java.io.ByteArrayInputStream;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import java.io.DataInputStream;
|
import lombok.val;
|
||||||
import java.io.DataOutputStream;
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.io.FileNotFoundException;
|
|
||||||
import java.io.FileOutputStream;
|
|
||||||
import java.io.FileReader;
|
|
||||||
import java.io.FileWriter;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.io.InputStreamReader;
|
|
||||||
import java.io.ObjectInputStream;
|
|
||||||
import java.io.ObjectOutputStream;
|
|
||||||
import java.io.OutputStream;
|
|
||||||
import java.io.OutputStreamWriter;
|
|
||||||
import java.io.PrintWriter;
|
|
||||||
import java.io.UnsupportedEncodingException;
|
|
||||||
import java.nio.charset.StandardCharsets;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
import java.util.zip.GZIPInputStream;
|
|
||||||
import java.util.zip.ZipEntry;
|
|
||||||
import java.util.zip.ZipFile;
|
|
||||||
import java.util.zip.ZipInputStream;
|
|
||||||
import java.util.zip.ZipOutputStream;
|
|
||||||
|
|
||||||
import org.apache.commons.codec.binary.Base64;
|
import org.apache.commons.codec.binary.Base64;
|
||||||
import org.apache.commons.compress.compressors.gzip.GzipUtils;
|
import org.apache.commons.compress.compressors.gzip.GzipUtils;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.apache.commons.io.LineIterator;
|
import org.apache.commons.io.LineIterator;
|
||||||
import org.apache.commons.io.output.CloseShieldOutputStream;
|
import org.apache.commons.io.output.CloseShieldOutputStream;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.common.util.DL4JFileUtils;
|
import org.deeplearning4j.common.util.DL4JFileUtils;
|
||||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||||
|
@ -94,12 +70,37 @@ import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
||||||
import org.nd4j.storage.CompressedRamStorage;
|
import org.nd4j.storage.CompressedRamStorage;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import java.io.BufferedInputStream;
|
||||||
import lombok.Data;
|
import java.io.BufferedOutputStream;
|
||||||
import lombok.NoArgsConstructor;
|
import java.io.BufferedReader;
|
||||||
import lombok.NonNull;
|
import java.io.BufferedWriter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import java.io.ByteArrayInputStream;
|
||||||
import lombok.val;
|
import java.io.DataInputStream;
|
||||||
|
import java.io.DataOutputStream;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.FileReader;
|
||||||
|
import java.io.FileWriter;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.InputStreamReader;
|
||||||
|
import java.io.ObjectInputStream;
|
||||||
|
import java.io.ObjectOutputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.io.OutputStreamWriter;
|
||||||
|
import java.io.PrintWriter;
|
||||||
|
import java.io.UnsupportedEncodingException;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
import java.util.zip.GZIPInputStream;
|
||||||
|
import java.util.zip.ZipEntry;
|
||||||
|
import java.util.zip.ZipFile;
|
||||||
|
import java.util.zip.ZipInputStream;
|
||||||
|
import java.util.zip.ZipOutputStream;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This is utility class, providing various methods for WordVectors serialization
|
* This is utility class, providing various methods for WordVectors serialization
|
||||||
|
@ -2676,26 +2677,23 @@ public class WordVectorSerializer {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected static TokenizerFactory getTokenizerFactory(VectorsConfiguration configuration) {
|
protected static TokenizerFactory getTokenizerFactory(VectorsConfiguration configuration) {
|
||||||
if (configuration == null)
|
if (configuration == null) {
|
||||||
return null;
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
if (configuration.getTokenizerFactory() != null && !configuration.getTokenizerFactory().isEmpty()) {
|
String tokenizerFactoryClassName = configuration.getTokenizerFactory();
|
||||||
try {
|
if (StringUtils.isNotEmpty(tokenizerFactoryClassName)) {
|
||||||
TokenizerFactory factory =
|
TokenizerFactory factory = DL4JClassLoading.createNewInstance(tokenizerFactoryClassName);
|
||||||
(TokenizerFactory) Class.forName(configuration.getTokenizerFactory()).newInstance();
|
|
||||||
|
|
||||||
if (configuration.getTokenPreProcessor() != null && !configuration.getTokenPreProcessor().isEmpty()) {
|
String tokenPreProcessorClassName = configuration.getTokenPreProcessor();
|
||||||
TokenPreProcess preProcessor =
|
if (StringUtils.isNotEmpty(tokenPreProcessorClassName)) {
|
||||||
(TokenPreProcess) Class.forName(configuration.getTokenPreProcessor()).newInstance();
|
TokenPreProcess preProcessor = DL4JClassLoading.createNewInstance(tokenizerFactoryClassName);
|
||||||
factory.setTokenPreProcessor(preProcessor);
|
factory.setTokenPreProcessor(preProcessor);
|
||||||
}
|
}
|
||||||
|
|
||||||
return factory;
|
return factory;
|
||||||
|
}
|
||||||
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("Can't instantiate saved TokenizerFactory: {}", configuration.getTokenizerFactory());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.models.sequencevectors;
|
package org.deeplearning4j.models.sequencevectors;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.nd4j.shade.guava.primitives.Ints;
|
import org.nd4j.shade.guava.primitives.Ints;
|
||||||
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
|
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
@ -494,15 +496,17 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
|
||||||
this.useHierarchicSoftmax = configuration.isUseHierarchicSoftmax();
|
this.useHierarchicSoftmax = configuration.isUseHierarchicSoftmax();
|
||||||
this.preciseMode = configuration.isPreciseMode();
|
this.preciseMode = configuration.isPreciseMode();
|
||||||
|
|
||||||
if (configuration.getModelUtils() != null && !configuration.getModelUtils().isEmpty()) {
|
String modelUtilsClassName = configuration.getModelUtils();
|
||||||
|
if (StringUtils.isNotEmpty(modelUtilsClassName)) {
|
||||||
try {
|
try {
|
||||||
this.modelUtils = (ModelUtils<T>) Class.forName(configuration.getModelUtils()).newInstance();
|
this.modelUtils = DL4JClassLoading.createNewInstance(modelUtilsClassName);
|
||||||
} catch (Exception e) {
|
} catch (Exception instantiationException) {
|
||||||
log.error("Got {} trying to instantiate ModelUtils, falling back to BasicModelUtils instead");
|
log.error(
|
||||||
|
"Got '{}' trying to instantiate ModelUtils, falling back to BasicModelUtils instead",
|
||||||
|
instantiationException.getMessage(),
|
||||||
|
instantiationException);
|
||||||
this.modelUtils = new BasicModelUtils<>();
|
this.modelUtils = new BasicModelUtils<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (configuration.getElementsLearningAlgorithm() != null
|
if (configuration.getElementsLearningAlgorithm() != null
|
||||||
|
@ -551,12 +555,7 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Builder<T> sequenceLearningAlgorithm(@NonNull String algoName) {
|
public Builder<T> sequenceLearningAlgorithm(@NonNull String algoName) {
|
||||||
try {
|
this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(algoName);
|
||||||
Class clazz = Class.forName(algoName);
|
|
||||||
sequenceLearningAlgorithm = (SequenceLearningAlgorithm<T>) clazz.newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -578,13 +577,9 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Builder<T> elementsLearningAlgorithm(@NonNull String algoName) {
|
public Builder<T> elementsLearningAlgorithm(@NonNull String algoName) {
|
||||||
try {
|
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(algoName);
|
||||||
Class clazz = Class.forName(algoName);
|
|
||||||
elementsLearningAlgorithm = (ElementsLearningAlgorithm<T>) clazz.newInstance();
|
|
||||||
this.configuration.setElementsLearningAlgorithm(elementsLearningAlgorithm.getClass().getCanonicalName());
|
this.configuration.setElementsLearningAlgorithm(elementsLearningAlgorithm.getClass().getCanonicalName());
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -943,31 +938,23 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
|
||||||
.lr(learningRate).seed(seed).build();
|
.lr(learningRate).seed(seed).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.configuration.getElementsLearningAlgorithm() != null) {
|
String elementsLearningAlgorithm = this.configuration.getElementsLearningAlgorithm();
|
||||||
try {
|
if (StringUtils.isNotEmpty(elementsLearningAlgorithm)) {
|
||||||
elementsLearningAlgorithm = (ElementsLearningAlgorithm<T>) Class
|
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
|
||||||
.forName(this.configuration.getElementsLearningAlgorithm()).newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.configuration.getSequenceLearningAlgorithm() != null) {
|
String sequenceLearningAlgorithm = this.configuration.getSequenceLearningAlgorithm();
|
||||||
try {
|
if (StringUtils.isNotEmpty(sequenceLearningAlgorithm)) {
|
||||||
sequenceLearningAlgorithm = (SequenceLearningAlgorithm<T>) Class
|
this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(sequenceLearningAlgorithm);
|
||||||
.forName(this.configuration.getSequenceLearningAlgorithm()).newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (trainElementsVectors && elementsLearningAlgorithm == null) {
|
if (trainElementsVectors && this.elementsLearningAlgorithm == null) {
|
||||||
// create default implementation of ElementsLearningAlgorithm
|
// create default implementation of ElementsLearningAlgorithm
|
||||||
elementsLearningAlgorithm = new SkipGram<>();
|
this.elementsLearningAlgorithm = new SkipGram<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (trainSequenceVectors && sequenceLearningAlgorithm == null) {
|
if (trainSequenceVectors && this.sequenceLearningAlgorithm == null) {
|
||||||
sequenceLearningAlgorithm = new DBOW<>();
|
this.sequenceLearningAlgorithm = new DBOW<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
this.modelUtils.init(lookupTable);
|
this.modelUtils.init(lookupTable);
|
||||||
|
|
|
@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
@ -132,25 +133,20 @@ public class Dropout implements IDropout {
|
||||||
*/
|
*/
|
||||||
protected void initializeHelper(DataType dataType){
|
protected void initializeHelper(DataType dataType){
|
||||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
|
||||||
if("CUDA".equalsIgnoreCase(backend)) {
|
if("CUDA".equalsIgnoreCase(backend)) {
|
||||||
try {
|
helper = DL4JClassLoading.createNewInstance(
|
||||||
helper = Class.forName("org.deeplearning4j.cuda.dropout.CudnnDropoutHelper")
|
"org.deeplearning4j.cuda.dropout.CudnnDropoutHelper",
|
||||||
.asSubclass(DropoutHelper.class).getConstructor(DataType.class).newInstance(dataType);
|
DropoutHelper.class,
|
||||||
|
dataType);
|
||||||
log.debug("CudnnDropoutHelper successfully initialized");
|
log.debug("CudnnDropoutHelper successfully initialized");
|
||||||
if (!helper.checkSupported()) {
|
if (!helper.checkSupported()) {
|
||||||
helper = null;
|
helper = null;
|
||||||
}
|
}
|
||||||
} catch (Throwable t) {
|
|
||||||
if (!(t instanceof ClassNotFoundException)) {
|
|
||||||
log.warn("Could not initialize CudnnDropoutHelper", t);
|
|
||||||
}
|
|
||||||
//Unlike other layers, don't warn here about CuDNN not found - if the user has any other layers that can
|
|
||||||
// benefit from them cudnn, they will get a warning from those
|
|
||||||
}
|
|
||||||
}
|
|
||||||
initializedHelper = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
initializedHelper = true;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray applyDropout(INDArray inputActivations, INDArray output, int iteration, int epoch, LayerWorkspaceMgr workspaceMgr) {
|
public INDArray applyDropout(INDArray inputActivations, INDArray output, int iteration, int epoch, LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
|
|
@ -43,6 +43,7 @@ import org.nd4j.shade.jackson.databind.deser.std.StdDeserializer;
|
||||||
import org.nd4j.shade.jackson.databind.node.ObjectNode;
|
import org.nd4j.shade.jackson.databind.node.ObjectNode;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.lang.reflect.InvocationTargetException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
@ -268,14 +269,17 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
|
||||||
|
|
||||||
//Changed after 0.7.1 from "activationFunction" : "softmax" to "activationFn" : <object>
|
//Changed after 0.7.1 from "activationFunction" : "softmax" to "activationFn" : <object>
|
||||||
protected void handleActivationBackwardCompatibility(BaseLayer baseLayer, ObjectNode on){
|
protected void handleActivationBackwardCompatibility(BaseLayer baseLayer, ObjectNode on){
|
||||||
|
|
||||||
if(baseLayer.getActivationFn() == null && on.has("activationFunction")){
|
if(baseLayer.getActivationFn() == null && on.has("activationFunction")){
|
||||||
String afn = on.get("activationFunction").asText();
|
String afn = on.get("activationFunction").asText();
|
||||||
IActivation a = null;
|
IActivation a = null;
|
||||||
try {
|
try {
|
||||||
a = getMap().get(afn.toLowerCase()).newInstance();
|
a = getMap()
|
||||||
} catch (InstantiationException | IllegalAccessException e){
|
.get(afn.toLowerCase())
|
||||||
//Ignore
|
.getDeclaredConstructor()
|
||||||
|
.newInstance();
|
||||||
|
} catch (InstantiationException | IllegalAccessException | NoSuchMethodException
|
||||||
|
| InvocationTargetException instantiationException){
|
||||||
|
log.error(instantiationException.getMessage());
|
||||||
}
|
}
|
||||||
baseLayer.setActivationFn(a);
|
baseLayer.setActivationFn(a);
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.convolution;
|
||||||
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||||
import org.deeplearning4j.nn.api.MaskState;
|
import org.deeplearning4j.nn.api.MaskState;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
|
@ -73,26 +74,19 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
void initializeHelper() {
|
void initializeHelper() {
|
||||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
if("CUDA".equalsIgnoreCase(backend)) {
|
if("CUDA".equalsIgnoreCase(backend)) {
|
||||||
try {
|
helper = DL4JClassLoading.createNewInstance(
|
||||||
helper = Class.forName("org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper")
|
"org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper",
|
||||||
.asSubclass(ConvolutionHelper.class).getConstructor(DataType.class).newInstance(dataType);
|
ConvolutionHelper.class,
|
||||||
|
dataType);
|
||||||
log.debug("CudnnConvolutionHelper successfully initialized");
|
log.debug("CudnnConvolutionHelper successfully initialized");
|
||||||
if (!helper.checkSupported()) {
|
if (!helper.checkSupported()) {
|
||||||
helper = null;
|
helper = null;
|
||||||
}
|
}
|
||||||
} catch (Throwable t) {
|
|
||||||
if (!(t instanceof ClassNotFoundException)) {
|
|
||||||
log.warn("Could not initialize CudnnConvolutionHelper", t);
|
|
||||||
} else {
|
|
||||||
OneTimeLogger.info(log, "cuDNN not found: "
|
|
||||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
|
||||||
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if("CPU".equalsIgnoreCase(backend)){
|
} else if("CPU".equalsIgnoreCase(backend)){
|
||||||
helper = new MKLDNNConvHelper(dataType);
|
helper = new MKLDNNConvHelper(dataType);
|
||||||
log.trace("Created MKLDNNConvHelper, layer {}", layerConf().getLayerName());
|
log.trace("Created MKLDNNConvHelper, layer {}", layerConf().getLayerName());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (helper != null && !helper.checkSupported()) {
|
if (helper != null && !helper.checkSupported()) {
|
||||||
log.debug("Removed helper {} as not supported", helper.getClass());
|
log.debug("Removed helper {} as not supported", helper.getClass());
|
||||||
helper = null;
|
helper = null;
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.nn.layers.convolution.subsampling;
|
package org.deeplearning4j.nn.layers.convolution.subsampling;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||||
import org.deeplearning4j.nn.api.MaskState;
|
import org.deeplearning4j.nn.api.MaskState;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
|
@ -30,17 +31,15 @@ import org.deeplearning4j.nn.layers.mkldnn.MKLDNNSubsamplingHelper;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.ConvolutionUtils;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.exception.ND4JOpProfilerException;
|
import org.nd4j.linalg.exception.ND4JOpProfilerException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
import org.nd4j.common.util.OneTimeLogger;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subsampling layer.
|
* Subsampling layer.
|
||||||
*
|
*
|
||||||
|
@ -64,27 +63,21 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
|
||||||
|
|
||||||
void initializeHelper() {
|
void initializeHelper() {
|
||||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
|
||||||
if("CUDA".equalsIgnoreCase(backend)) {
|
if("CUDA".equalsIgnoreCase(backend)) {
|
||||||
try {
|
helper = DL4JClassLoading.createNewInstance(
|
||||||
helper = Class.forName("org.deeplearning4j.cuda.convolution.subsampling.CudnnSubsamplingHelper")
|
"org.deeplearning4j.cuda.convolution.subsampling.CudnnSubsamplingHelper",
|
||||||
.asSubclass(SubsamplingHelper.class).getConstructor(DataType.class).newInstance(dataType);
|
SubsamplingHelper.class,
|
||||||
|
dataType);
|
||||||
log.debug("CudnnSubsamplingHelper successfully initialized");
|
log.debug("CudnnSubsamplingHelper successfully initialized");
|
||||||
if (!helper.checkSupported()) {
|
if (!helper.checkSupported()) {
|
||||||
helper = null;
|
helper = null;
|
||||||
}
|
}
|
||||||
} catch (Throwable t) {
|
|
||||||
if (!(t instanceof ClassNotFoundException)) {
|
|
||||||
log.warn("Could not initialize CudnnSubsamplingHelper", t);
|
|
||||||
} else {
|
|
||||||
OneTimeLogger.info(log, "cuDNN not found: "
|
|
||||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
|
||||||
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if("CPU".equalsIgnoreCase(backend) ){
|
} else if("CPU".equalsIgnoreCase(backend) ){
|
||||||
helper = new MKLDNNSubsamplingHelper(dataType);
|
helper = new MKLDNNSubsamplingHelper(dataType);
|
||||||
log.trace("Created MKL-DNN helper: MKLDNNSubsamplingHelper, layer {}", layerConf().getLayerName());
|
log.trace("Created MKL-DNN helper: MKLDNNSubsamplingHelper, layer {}", layerConf().getLayerName());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (helper != null && !helper.checkSupported()) {
|
if (helper != null && !helper.checkSupported()) {
|
||||||
log.debug("Removed helper {} as not supported", helper.getClass());
|
log.debug("Removed helper {} as not supported", helper.getClass());
|
||||||
helper = null;
|
helper = null;
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.layers.mkldnn;
|
package org.deeplearning4j.nn.layers.mkldnn;
|
||||||
|
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.lang.reflect.Method;
|
import java.lang.reflect.Method;
|
||||||
|
@ -47,12 +48,11 @@ public class BaseMKLDNNHelper {
|
||||||
}
|
}
|
||||||
|
|
||||||
try{
|
try{
|
||||||
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
Class<?> clazz = DL4JClassLoading.loadClassByName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||||
Method m = c.getMethod("getInstance");
|
Method getInstance = clazz.getMethod("getInstance");
|
||||||
Object instance = m.invoke(null);
|
Object instance = getInstance.invoke(null);
|
||||||
Method m2 = c.getMethod("isUseMKLDNN");
|
Method isUseMKLDNNMethod = clazz.getMethod("isUseMKLDNN");
|
||||||
boolean b = (Boolean)m2.invoke(instance);
|
return (boolean) isUseMKLDNNMethod.invoke(instance);
|
||||||
return b;
|
|
||||||
} catch (Throwable t ){
|
} catch (Throwable t ){
|
||||||
FAILED_CHECK = new AtomicBoolean(true);
|
FAILED_CHECK = new AtomicBoolean(true);
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.normalization;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||||
|
@ -75,24 +76,18 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
|
||||||
|
|
||||||
void initializeHelper() {
|
void initializeHelper() {
|
||||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
|
||||||
if ("CUDA".equalsIgnoreCase(backend)) {
|
if ("CUDA".equalsIgnoreCase(backend)) {
|
||||||
try {
|
helper = DL4JClassLoading.createNewInstance(
|
||||||
helper = Class.forName("org.deeplearning4j.cuda.normalization.CudnnBatchNormalizationHelper")
|
"org.deeplearning4j.cuda.normalization.CudnnBatchNormalizationHelper",
|
||||||
.asSubclass(BatchNormalizationHelper.class).getConstructor(DataType.class).newInstance(dataType);
|
BatchNormalizationHelper.class,
|
||||||
|
dataType);
|
||||||
log.debug("CudnnBatchNormalizationHelper successfully initialized");
|
log.debug("CudnnBatchNormalizationHelper successfully initialized");
|
||||||
} catch (Throwable t) {
|
|
||||||
if (!(t instanceof ClassNotFoundException)) {
|
|
||||||
log.warn("Could not initialize CudnnBatchNormalizationHelper", t);
|
|
||||||
} else {
|
|
||||||
OneTimeLogger.info(log, "cuDNN not found: "
|
|
||||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
|
||||||
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if ("CPU".equalsIgnoreCase(backend)){
|
} else if ("CPU".equalsIgnoreCase(backend)){
|
||||||
helper = new MKLDNNBatchNormHelper(dataType);
|
helper = new MKLDNNBatchNormHelper(dataType);
|
||||||
log.trace("Created MKLDNNBatchNormHelper, layer {}", layerConf().getLayerName());
|
log.trace("Created MKLDNNBatchNormHelper, layer {}", layerConf().getLayerName());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (helper != null && !helper.checkSupported(layerConf().getEps(), layerConf().isLockGammaBeta())) {
|
if (helper != null && !helper.checkSupported(layerConf().getEps(), layerConf().isLockGammaBeta())) {
|
||||||
log.debug("Removed helper {} as not supported with epsilon {}, lockGammaBeta={}", helper.getClass(), layerConf().getEps(), layerConf().isLockGammaBeta());
|
log.debug("Removed helper {} as not supported with epsilon {}, lockGammaBeta={}", helper.getClass(), layerConf().getEps(), layerConf().isLockGammaBeta());
|
||||||
helper = null;
|
helper = null;
|
||||||
|
|
|
@ -16,7 +16,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.layers.normalization;
|
package org.deeplearning4j.nn.layers.normalization;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -36,9 +38,6 @@ import org.nd4j.common.primitives.Pair;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
import org.nd4j.common.primitives.Triple;
|
import org.nd4j.common.primitives.Triple;
|
||||||
import org.nd4j.common.util.OneTimeLogger;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
||||||
|
|
||||||
|
@ -65,10 +64,9 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
||||||
* <p>
|
* <p>
|
||||||
* Created by nyghtowl on 10/29/15.
|
* Created by nyghtowl on 10/29/15.
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class LocalResponseNormalization
|
public class LocalResponseNormalization
|
||||||
extends AbstractLayer<org.deeplearning4j.nn.conf.layers.LocalResponseNormalization> {
|
extends AbstractLayer<org.deeplearning4j.nn.conf.layers.LocalResponseNormalization> {
|
||||||
protected static final Logger log =
|
|
||||||
LoggerFactory.getLogger(org.deeplearning4j.nn.conf.layers.LocalResponseNormalization.class);
|
|
||||||
|
|
||||||
protected LocalResponseNormalizationHelper helper = null;
|
protected LocalResponseNormalizationHelper helper = null;
|
||||||
protected int helperCountFail = 0;
|
protected int helperCountFail = 0;
|
||||||
|
@ -86,19 +84,11 @@ public class LocalResponseNormalization
|
||||||
void initializeHelper() {
|
void initializeHelper() {
|
||||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
if("CUDA".equalsIgnoreCase(backend)) {
|
if("CUDA".equalsIgnoreCase(backend)) {
|
||||||
try {
|
helper = DL4JClassLoading.createNewInstance(
|
||||||
helper = Class.forName("org.deeplearning4j.cuda.normalization.CudnnLocalResponseNormalizationHelper")
|
"org.deeplearning4j.cuda.normalization.CudnnLocalResponseNormalizationHelper",
|
||||||
.asSubclass(LocalResponseNormalizationHelper.class).getConstructor(DataType.class).newInstance(dataType);
|
LocalResponseNormalizationHelper.class,
|
||||||
|
dataType);
|
||||||
log.debug("CudnnLocalResponseNormalizationHelper successfully initialized");
|
log.debug("CudnnLocalResponseNormalizationHelper successfully initialized");
|
||||||
} catch (Throwable t) {
|
|
||||||
if (!(t instanceof ClassNotFoundException)) {
|
|
||||||
log.warn("Could not initialize CudnnLocalResponseNormalizationHelper", t);
|
|
||||||
} else {
|
|
||||||
OneTimeLogger.info(log, "cuDNN not found: "
|
|
||||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
|
||||||
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
//2019-03-09 AB - MKL-DNN helper disabled: https://github.com/deeplearning4j/deeplearning4j/issues/7272
|
//2019-03-09 AB - MKL-DNN helper disabled: https://github.com/deeplearning4j/deeplearning4j/issues/7272
|
||||||
// else if("CPU".equalsIgnoreCase(backend)){
|
// else if("CPU".equalsIgnoreCase(backend)){
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.nn.layers.recurrent;
|
package org.deeplearning4j.nn.layers.recurrent;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.nn.api.MaskState;
|
import org.deeplearning4j.nn.api.MaskState;
|
||||||
import org.deeplearning4j.nn.conf.CacheMode;
|
import org.deeplearning4j.nn.conf.CacheMode;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -58,22 +59,14 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
|
||||||
void initializeHelper() {
|
void initializeHelper() {
|
||||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
if("CUDA".equalsIgnoreCase(backend)) {
|
if("CUDA".equalsIgnoreCase(backend)) {
|
||||||
try {
|
helper = DL4JClassLoading.createNewInstance(
|
||||||
helper = Class.forName("org.deeplearning4j.cuda.recurrent.CudnnLSTMHelper")
|
"org.deeplearning4j.cuda.recurrent.CudnnLSTMHelper",
|
||||||
.asSubclass(LSTMHelper.class).getConstructor(DataType.class).newInstance(dataType);
|
LSTMHelper.class,
|
||||||
|
dataType);
|
||||||
log.debug("CudnnLSTMHelper successfully initialized");
|
log.debug("CudnnLSTMHelper successfully initialized");
|
||||||
if (!helper.checkSupported(layerConf().getGateActivationFn(), layerConf().getActivationFn(), false)) {
|
if (!helper.checkSupported(layerConf().getGateActivationFn(), layerConf().getActivationFn(), false)) {
|
||||||
helper = null;
|
helper = null;
|
||||||
}
|
}
|
||||||
} catch (Throwable t) {
|
|
||||||
if (!(t instanceof ClassNotFoundException)) {
|
|
||||||
log.warn("Could not initialize CudnnLSTMHelper", t);
|
|
||||||
} else {
|
|
||||||
OneTimeLogger.info(log, "cuDNN not found: "
|
|
||||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
|
||||||
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
/*
|
/*
|
||||||
//Disabled pending: https://github.com/eclipse/deeplearning4j/issues/8331
|
//Disabled pending: https://github.com/eclipse/deeplearning4j/issues/8331
|
||||||
|
|
|
@ -21,6 +21,7 @@ import com.beust.jcommander.Parameter;
|
||||||
import com.beust.jcommander.ParameterException;
|
import com.beust.jcommander.ParameterException;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.core.storage.StatsStorageRouter;
|
import org.deeplearning4j.core.storage.StatsStorageRouter;
|
||||||
import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter;
|
import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter;
|
||||||
import org.deeplearning4j.nn.api.Model;
|
import org.deeplearning4j.nn.api.Model;
|
||||||
|
@ -126,48 +127,44 @@ public class ParallelWrapperMain {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
if (dataSetIteratorFactoryClazz != null) {
|
if (dataSetIteratorFactoryClazz != null) {
|
||||||
DataSetIteratorProviderFactory dataSetIteratorProviderFactory =
|
DataSetIteratorProviderFactory dataSetIteratorProviderFactory = DL4JClassLoading
|
||||||
(DataSetIteratorProviderFactory) Class.forName(dataSetIteratorFactoryClazz).newInstance();
|
.createNewInstance(dataSetIteratorFactoryClazz);
|
||||||
|
|
||||||
DataSetIterator dataSetIterator = dataSetIteratorProviderFactory.create();
|
DataSetIterator dataSetIterator = dataSetIteratorProviderFactory.create();
|
||||||
if (uiUrl != null) {
|
if (uiUrl != null) {
|
||||||
// it's important that the UI can report results from parallel training
|
// it's important that the UI can report results from parallel training
|
||||||
// there's potential for StatsListener to fail if certain properties aren't set in the model
|
// there's potential for StatsListener to fail if certain properties aren't set in the model
|
||||||
StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
|
StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
|
||||||
TrainingListener l;
|
TrainingListener trainingListener = DL4JClassLoading.createNewInstance(
|
||||||
try {
|
"org.deeplearning4j.ui.model.stats.StatsListener",
|
||||||
l = (TrainingListener) Class.forName("org.deeplearning4j.ui.model.stats.StatsListener").getConstructor(StatsStorageRouter.class)
|
StatsStorageRouter.class,
|
||||||
.newInstance(new Object[]{null});
|
new Class[] { StatsStorageRouter.class },
|
||||||
} catch (ClassNotFoundException e){
|
new Object[] { null });
|
||||||
throw new IllegalStateException("deeplearning4j-ui module must be on the classpath to use ParallelWrapperMain with the UI", e);
|
wrapper.setListeners(remoteUIRouter, trainingListener);
|
||||||
}
|
|
||||||
wrapper.setListeners(remoteUIRouter, l);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
wrapper.fit(dataSetIterator);
|
wrapper.fit(dataSetIterator);
|
||||||
ModelSerializer.writeModel(model, new File(modelOutputPath), true);
|
ModelSerializer.writeModel(model, new File(modelOutputPath), true);
|
||||||
|
|
||||||
|
|
||||||
} else if (multiDataSetIteratorFactoryClazz != null) {
|
} else if (multiDataSetIteratorFactoryClazz != null) {
|
||||||
MultiDataSetProviderFactory multiDataSetProviderFactory =
|
MultiDataSetProviderFactory multiDataSetProviderFactory = DL4JClassLoading
|
||||||
(MultiDataSetProviderFactory) Class.forName(multiDataSetIteratorFactoryClazz).newInstance();
|
.createNewInstance(multiDataSetIteratorFactoryClazz);
|
||||||
|
|
||||||
MultiDataSetIterator iterator = multiDataSetProviderFactory.create();
|
MultiDataSetIterator iterator = multiDataSetProviderFactory.create();
|
||||||
if (uiUrl != null) {
|
if (uiUrl != null) {
|
||||||
// it's important that the UI can report results from parallel training
|
// it's important that the UI can report results from parallel training
|
||||||
// there's potential for StatsListener to fail if certain properties aren't set in the model
|
// there's potential for StatsListener to fail if certain properties aren't set in the model
|
||||||
remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
|
remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
|
||||||
TrainingListener l;
|
TrainingListener trainingListener = DL4JClassLoading
|
||||||
try {
|
.createNewInstance(
|
||||||
l = (TrainingListener) Class.forName("org.deeplearning4j.ui.model.stats.StatsListener").getConstructor(StatsStorageRouter.class)
|
"org.deeplearning4j.ui.model.stats.StatsListener",
|
||||||
.newInstance(new Object[]{null});
|
TrainingListener.class,
|
||||||
} catch (ClassNotFoundException e){
|
new Class[]{ StatsStorageRouter.class },
|
||||||
throw new IllegalStateException("deeplearning4j-ui module must be on the classpath to use ParallelWrapperMain with the UI", e);
|
new Object[]{ null });
|
||||||
}
|
wrapper.setListeners(remoteUIRouter, trainingListener);
|
||||||
wrapper.setListeners(remoteUIRouter, l);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
wrapper.fit(iterator);
|
wrapper.fit(iterator);
|
||||||
ModelSerializer.writeModel(model, new File(modelOutputPath), true);
|
ModelSerializer.writeModel(model, new File(modelOutputPath), true);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
throw new IllegalStateException("Please provide a datasetiteraator or multi datasetiterator class");
|
throw new IllegalStateException("Please provide a datasetiteraator or multi datasetiterator class");
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.broadcast.Broadcast;
|
import org.apache.spark.broadcast.Broadcast;
|
||||||
import org.apache.spark.storage.StorageLevel;
|
import org.apache.spark.storage.StorageLevel;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
||||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||||
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
||||||
|
@ -161,14 +162,9 @@ public class SparkSequenceVectors<T extends SequenceElement> extends SequenceVec
|
||||||
validateConfiguration();
|
validateConfiguration();
|
||||||
|
|
||||||
if (ela == null) {
|
if (ela == null) {
|
||||||
try {
|
String className = configuration.getElementsLearningAlgorithm();
|
||||||
ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm())
|
ela = DL4JClassLoading.createNewInstance(className);
|
||||||
.newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
if (workers > 1) {
|
if (workers > 1) {
|
||||||
log.info("Repartitioning corpus to {} parts...", workers);
|
log.info("Repartitioning corpus to {} parts...", workers);
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.spark.models.sequencevectors.functions;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.apache.spark.broadcast.Broadcast;
|
import org.apache.spark.broadcast.Broadcast;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
|
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||||
|
@ -42,24 +43,17 @@ public abstract class BaseTokenizerFunction implements Serializable {
|
||||||
String tpClassName = this.configurationBroadcast.getValue().getTokenPreProcessor();
|
String tpClassName = this.configurationBroadcast.getValue().getTokenPreProcessor();
|
||||||
|
|
||||||
if (tfClassName != null && !tfClassName.isEmpty()) {
|
if (tfClassName != null && !tfClassName.isEmpty()) {
|
||||||
try {
|
tokenizerFactory = DL4JClassLoading.createNewInstance(tfClassName);
|
||||||
tokenizerFactory = (TokenizerFactory) Class.forName(tfClassName).newInstance();
|
|
||||||
|
|
||||||
if (tpClassName != null && !tpClassName.isEmpty()) {
|
if (tpClassName != null && !tpClassName.isEmpty()) {
|
||||||
try {
|
tokenPreprocessor = DL4JClassLoading.createNewInstance(tpClassName);
|
||||||
tokenPreprocessor = (TokenPreProcess) Class.forName(tpClassName).newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException("Unable to instantiate TokenPreProcessor.", e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tokenPreprocessor != null) {
|
if (tokenPreprocessor != null) {
|
||||||
tokenizerFactory.setTokenPreProcessor(tokenPreprocessor);
|
tokenizerFactory.setTokenPreProcessor(tokenPreprocessor);
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} else {
|
||||||
throw new RuntimeException("Unable to instantiate TokenizerFactory.", e);
|
|
||||||
}
|
|
||||||
} else
|
|
||||||
throw new RuntimeException("TokenizerFactory wasn't defined.");
|
throw new RuntimeException("TokenizerFactory wasn't defined.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.spark.Accumulator;
|
import org.apache.spark.Accumulator;
|
||||||
import org.apache.spark.api.java.function.Function;
|
import org.apache.spark.api.java.function.Function;
|
||||||
import org.apache.spark.broadcast.Broadcast;
|
import org.apache.spark.broadcast.Broadcast;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||||
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
||||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||||
|
@ -65,13 +66,8 @@ public class CountFunction<T extends SequenceElement> implements Function<Sequen
|
||||||
long seqLen = 0;
|
long seqLen = 0;
|
||||||
|
|
||||||
if (ela == null) {
|
if (ela == null) {
|
||||||
try {
|
String elementsLearningAlgorithm = vectorsConfigurationBroadcast.getValue().getElementsLearningAlgorithm();
|
||||||
ela = (SparkElementsLearningAlgorithm) Class
|
ela = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
|
||||||
.forName(vectorsConfigurationBroadcast.getValue().getElementsLearningAlgorithm())
|
|
||||||
.newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
driver = ela.getTrainingDriver();
|
driver = ela.getTrainingDriver();
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.spark.models.sequencevectors.functions;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.apache.spark.api.java.function.VoidFunction;
|
import org.apache.spark.api.java.function.VoidFunction;
|
||||||
import org.apache.spark.broadcast.Broadcast;
|
import org.apache.spark.broadcast.Broadcast;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||||
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
||||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||||
|
@ -74,19 +75,15 @@ public class PartitionTrainingFunction<T extends SequenceElement> implements Voi
|
||||||
if (vectorsConfiguration == null)
|
if (vectorsConfiguration == null)
|
||||||
vectorsConfiguration = configurationBroadcast.getValue();
|
vectorsConfiguration = configurationBroadcast.getValue();
|
||||||
|
|
||||||
|
String elementsLearningAlgorithm = vectorsConfiguration.getElementsLearningAlgorithm();
|
||||||
if (paramServer == null) {
|
if (paramServer == null) {
|
||||||
paramServer = VoidParameterServer.getInstance();
|
paramServer = VoidParameterServer.getInstance();
|
||||||
|
|
||||||
if (elementsLearningAlgorithm == null) {
|
if (this.elementsLearningAlgorithm == null) {
|
||||||
try {
|
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
|
||||||
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
|
|
||||||
.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
driver = elementsLearningAlgorithm.getTrainingDriver();
|
driver = this.elementsLearningAlgorithm.getTrainingDriver();
|
||||||
|
|
||||||
// FIXME: init line should probably be removed, basically init happens in VocabRddFunction
|
// FIXME: init line should probably be removed, basically init happens in VocabRddFunction
|
||||||
paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
|
paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
|
||||||
|
@ -95,33 +92,24 @@ public class PartitionTrainingFunction<T extends SequenceElement> implements Voi
|
||||||
if (shallowVocabCache == null)
|
if (shallowVocabCache == null)
|
||||||
shallowVocabCache = vocabCacheBroadcast.getValue();
|
shallowVocabCache = vocabCacheBroadcast.getValue();
|
||||||
|
|
||||||
if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
|
if (this.elementsLearningAlgorithm == null && elementsLearningAlgorithm != null) {
|
||||||
// TODO: do ELA initialization
|
// TODO: do ELA initialization
|
||||||
try {
|
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
|
||||||
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
|
|
||||||
.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (elementsLearningAlgorithm != null)
|
if (this.elementsLearningAlgorithm != null)
|
||||||
elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
this.elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||||
|
|
||||||
if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
|
String sequenceLearningAlgorithm = vectorsConfiguration.getSequenceLearningAlgorithm();
|
||||||
|
if (this.sequenceLearningAlgorithm == null && sequenceLearningAlgorithm != null) {
|
||||||
// TODO: do SLA initialization
|
// TODO: do SLA initialization
|
||||||
try {
|
this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(sequenceLearningAlgorithm);
|
||||||
sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
|
this.sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||||
.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
|
|
||||||
sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
}
|
||||||
}
|
if (this.sequenceLearningAlgorithm != null)
|
||||||
if (sequenceLearningAlgorithm != null)
|
this.sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||||
sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
|
||||||
|
|
||||||
if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
|
if (this.elementsLearningAlgorithm == null && this.sequenceLearningAlgorithm == null) {
|
||||||
throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
|
throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,7 +130,7 @@ public class PartitionTrainingFunction<T extends SequenceElement> implements Voi
|
||||||
}
|
}
|
||||||
|
|
||||||
// do the same with labels, transfer them, if any
|
// do the same with labels, transfer them, if any
|
||||||
if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
|
if (this.sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
|
||||||
for (T label : sequence.getSequenceLabels()) {
|
for (T label : sequence.getSequenceLabels()) {
|
||||||
ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
|
ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.spark.api.java.function.VoidFunction;
|
import org.apache.spark.api.java.function.VoidFunction;
|
||||||
import org.apache.spark.broadcast.Broadcast;
|
import org.apache.spark.broadcast.Broadcast;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||||
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
||||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||||
|
@ -73,19 +74,15 @@ public class TrainingFunction<T extends SequenceElement> implements VoidFunction
|
||||||
if (vectorsConfiguration == null)
|
if (vectorsConfiguration == null)
|
||||||
vectorsConfiguration = configurationBroadcast.getValue();
|
vectorsConfiguration = configurationBroadcast.getValue();
|
||||||
|
|
||||||
|
String elementsLearningAlgorithm = vectorsConfiguration.getElementsLearningAlgorithm();
|
||||||
if (paramServer == null) {
|
if (paramServer == null) {
|
||||||
paramServer = VoidParameterServer.getInstance();
|
paramServer = VoidParameterServer.getInstance();
|
||||||
|
|
||||||
if (elementsLearningAlgorithm == null) {
|
if (this.elementsLearningAlgorithm == null) {
|
||||||
try {
|
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
|
||||||
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
|
|
||||||
.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
driver = elementsLearningAlgorithm.getTrainingDriver();
|
driver = this.elementsLearningAlgorithm.getTrainingDriver();
|
||||||
|
|
||||||
// FIXME: init line should probably be removed, basically init happens in VocabRddFunction
|
// FIXME: init line should probably be removed, basically init happens in VocabRddFunction
|
||||||
paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
|
paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
|
||||||
|
@ -98,33 +95,23 @@ public class TrainingFunction<T extends SequenceElement> implements VoidFunction
|
||||||
shallowVocabCache = vocabCacheBroadcast.getValue();
|
shallowVocabCache = vocabCacheBroadcast.getValue();
|
||||||
|
|
||||||
|
|
||||||
if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
|
if (this.elementsLearningAlgorithm == null && elementsLearningAlgorithm != null) {
|
||||||
// TODO: do ELA initialization
|
// TODO: do ELA initialization
|
||||||
try {
|
this.elementsLearningAlgorithm = DL4JClassLoading.createNewInstance(elementsLearningAlgorithm);
|
||||||
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
|
this.elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||||
.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
|
|
||||||
elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
|
String sequenceLearningAlgorithm = vectorsConfiguration.getSequenceLearningAlgorithm();
|
||||||
|
if (this.sequenceLearningAlgorithm == null && sequenceLearningAlgorithm != null) {
|
||||||
// TODO: do SLA initialization
|
// TODO: do SLA initialization
|
||||||
try {
|
this.sequenceLearningAlgorithm = DL4JClassLoading.createNewInstance(sequenceLearningAlgorithm);
|
||||||
sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
|
this.sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
||||||
.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
|
|
||||||
sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
|
if (this.elementsLearningAlgorithm == null && this.sequenceLearningAlgorithm == null) {
|
||||||
throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
|
throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
at this moment we should have everything ready for actual initialization
|
at this moment we should have everything ready for actual initialization
|
||||||
the only limitation we have - our sequence is detached from actual vocabulary, so we need to merge it back virtually
|
the only limitation we have - our sequence is detached from actual vocabulary, so we need to merge it back virtually
|
||||||
|
@ -139,7 +126,7 @@ public class TrainingFunction<T extends SequenceElement> implements VoidFunction
|
||||||
}
|
}
|
||||||
|
|
||||||
// do the same with labels, transfer them, if any
|
// do the same with labels, transfer them, if any
|
||||||
if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
|
if (this.sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
|
||||||
for (T label : sequence.getSequenceLabels()) {
|
for (T label : sequence.getSequenceLabels()) {
|
||||||
ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
|
ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
|
||||||
|
|
||||||
|
@ -157,7 +144,7 @@ public class TrainingFunction<T extends SequenceElement> implements VoidFunction
|
||||||
// FIXME: temporary hook
|
// FIXME: temporary hook
|
||||||
if (sequence.size() > 0)
|
if (sequence.size() > 0)
|
||||||
paramServer.execDistributed(
|
paramServer.execDistributed(
|
||||||
elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
|
this.elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
|
||||||
else
|
else
|
||||||
log.warn("Skipping empty sequence...");
|
log.warn("Skipping empty sequence...");
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.spark.models.sequencevectors.functions;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.apache.spark.api.java.function.FlatMapFunction;
|
import org.apache.spark.api.java.function.FlatMapFunction;
|
||||||
import org.apache.spark.broadcast.Broadcast;
|
import org.apache.spark.broadcast.Broadcast;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||||
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
||||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||||
|
@ -56,12 +57,8 @@ public class VocabRddFunctionFlat<T extends SequenceElement> implements FlatMapF
|
||||||
configuration = vectorsConfigurationBroadcast.getValue();
|
configuration = vectorsConfigurationBroadcast.getValue();
|
||||||
|
|
||||||
if (ela == null) {
|
if (ela == null) {
|
||||||
try {
|
String className = configuration.getElementsLearningAlgorithm();
|
||||||
ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm())
|
ela = DL4JClassLoading.createNewInstance(className);
|
||||||
.newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
driver = ela.getTrainingDriver();
|
driver = ela.getTrainingDriver();
|
||||||
|
|
||||||
|
|
|
@ -17,12 +17,14 @@
|
||||||
package org.deeplearning4j.spark.text.functions;
|
package org.deeplearning4j.spark.text.functions;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.spark.api.java.function.Function;
|
import org.apache.spark.api.java.function.Function;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
|
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.NGramTokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.NGramTokenizerFactory;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -44,35 +46,33 @@ public class TokenizerFunction implements Function<String, List<String>> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<String> call(String v1) throws Exception {
|
public List<String> call(String str) {
|
||||||
if (tokenizerFactory == null)
|
if (tokenizerFactory == null) {
|
||||||
tokenizerFactory = getTokenizerFactory();
|
tokenizerFactory = getTokenizerFactory();
|
||||||
if (v1.isEmpty())
|
}
|
||||||
return Arrays.asList("");
|
|
||||||
return tokenizerFactory.create(v1).getTokens();
|
if (str.isEmpty()) {
|
||||||
|
return Collections.singletonList("");
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenizerFactory.create(str).getTokens();
|
||||||
}
|
}
|
||||||
|
|
||||||
private TokenizerFactory getTokenizerFactory() {
|
private TokenizerFactory getTokenizerFactory() {
|
||||||
try {
|
|
||||||
TokenPreProcess tokenPreProcessInst = null;
|
TokenPreProcess tokenPreProcessInst = null;
|
||||||
// token preprocess CAN be undefined
|
|
||||||
if (tokenizerPreprocessorClazz != null && !tokenizerPreprocessorClazz.isEmpty()) {
|
if (StringUtils.isNotEmpty(tokenizerPreprocessorClazz)) {
|
||||||
Class<? extends TokenPreProcess> clazz =
|
tokenPreProcessInst = DL4JClassLoading.createNewInstance(tokenizerPreprocessorClazz);
|
||||||
(Class<? extends TokenPreProcess>) Class.forName(tokenizerPreprocessorClazz);
|
|
||||||
tokenPreProcessInst = clazz.newInstance();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Class<? extends TokenizerFactory> clazz2 =
|
tokenizerFactory = DL4JClassLoading.createNewInstance(tokenizerFactoryClazz);
|
||||||
(Class<? extends TokenizerFactory>) Class.forName(tokenizerFactoryClazz);
|
|
||||||
tokenizerFactory = clazz2.newInstance();
|
|
||||||
if (tokenPreProcessInst != null)
|
if (tokenPreProcessInst != null)
|
||||||
tokenizerFactory.setTokenPreProcessor(tokenPreProcessInst);
|
tokenizerFactory.setTokenPreProcessor(tokenPreProcessInst);
|
||||||
if (nGrams > 1) {
|
if (nGrams > 1) {
|
||||||
tokenizerFactory = new NGramTokenizerFactory(tokenizerFactory, nGrams, nGrams);
|
tokenizerFactory = new NGramTokenizerFactory(tokenizerFactory, nGrams, nGrams);
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("",e);
|
|
||||||
}
|
|
||||||
return tokenizerFactory;
|
return tokenizerFactory;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.time;
|
package org.deeplearning4j.spark.time;
|
||||||
|
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.common.config.DL4JSystemProperties;
|
import org.deeplearning4j.common.config.DL4JSystemProperties;
|
||||||
|
|
||||||
import java.lang.reflect.Method;
|
import java.lang.reflect.Method;
|
||||||
|
@ -62,9 +63,9 @@ public class TimeSourceProvider {
|
||||||
*/
|
*/
|
||||||
public static TimeSource getInstance(String className) {
|
public static TimeSource getInstance(String className) {
|
||||||
try {
|
try {
|
||||||
Class<?> c = Class.forName(className);
|
Class<?> clazz = DL4JClassLoading.loadClassByName(className);
|
||||||
Method m = c.getMethod("getInstance");
|
Method getInstance = clazz.getMethod("getInstance");
|
||||||
return (TimeSource) m.invoke(null);
|
return (TimeSource) getInstance.invoke(null);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException("Error getting TimeSource instance for class \"" + className + "\"", e);
|
throw new RuntimeException("Error getting TimeSource instance for class \"" + className + "\"", e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.ui.model.stats;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.core.storage.StatsStorageRouter;
|
import org.deeplearning4j.core.storage.StatsStorageRouter;
|
||||||
import org.deeplearning4j.core.storage.StorageMetaData;
|
import org.deeplearning4j.core.storage.StorageMetaData;
|
||||||
import org.deeplearning4j.core.storage.listener.RoutingIterationListener;
|
import org.deeplearning4j.core.storage.listener.RoutingIterationListener;
|
||||||
|
@ -696,11 +697,14 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
|
||||||
return devPointers.get(device);
|
return devPointers.get(device);
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
Class<?> c = Class.forName("org.nd4j.jita.allocator.pointers.CudaPointer");
|
Pointer pointer = DL4JClassLoading.createNewInstance(
|
||||||
Constructor<?> constructor = c.getConstructor(long.class);
|
"org.nd4j.jita.allocator.pointers.CudaPointer",
|
||||||
Pointer p = (Pointer) constructor.newInstance((long) device);
|
Pointer.class,
|
||||||
devPointers.put(device, p);
|
new Class[] { long.class },
|
||||||
return p;
|
(long) device);
|
||||||
|
|
||||||
|
devPointers.put(device, pointer);
|
||||||
|
return pointer;
|
||||||
} catch (Throwable t) {
|
} catch (Throwable t) {
|
||||||
devPointers.put(device, null); //Stops attempting the failure again later...
|
devPointers.put(device, null); //Stops attempting the failure again later...
|
||||||
return null;
|
return null;
|
||||||
|
@ -711,9 +715,9 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
|
||||||
ModelInfo modelInfo = getModelInfo(model);
|
ModelInfo modelInfo = getModelInfo(model);
|
||||||
int examplesThisMinibatch = 0;
|
int examplesThisMinibatch = 0;
|
||||||
if (model instanceof MultiLayerNetwork) {
|
if (model instanceof MultiLayerNetwork) {
|
||||||
examplesThisMinibatch = ((MultiLayerNetwork) model).batchSize();
|
examplesThisMinibatch = model.batchSize();
|
||||||
} else if (model instanceof ComputationGraph) {
|
} else if (model instanceof ComputationGraph) {
|
||||||
examplesThisMinibatch = ((ComputationGraph) model).batchSize();
|
examplesThisMinibatch = model.batchSize();
|
||||||
} else if (model instanceof Layer) {
|
} else if (model instanceof Layer) {
|
||||||
examplesThisMinibatch = ((Layer) model).getInputMiniBatchSize();
|
examplesThisMinibatch = ((Layer) model).getInputMiniBatchSize();
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.ui.model.storage.mapdb;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.core.storage.*;
|
import org.deeplearning4j.core.storage.*;
|
||||||
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
|
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
|
||||||
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
||||||
|
@ -318,26 +319,18 @@ public class MapDBStatsStorage extends BaseCollectionStatsStorage {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
public T deserialize(@NonNull DataInput2 input, int available) throws IOException {
|
public T deserialize(@NonNull DataInput2 input, int available) throws IOException {
|
||||||
int classIdx = input.readInt();
|
int classIdx = input.readInt();
|
||||||
String className = getClassForInt(classIdx);
|
String className = getClassForInt(classIdx);
|
||||||
Class<?> clazz;
|
|
||||||
try {
|
Persistable persistable = DL4JClassLoading.createNewInstance(className);
|
||||||
clazz = Class.forName(className);
|
|
||||||
} catch (ClassNotFoundException e) {
|
|
||||||
throw new RuntimeException(e); //Shouldn't normally happen...
|
|
||||||
}
|
|
||||||
Persistable p;
|
|
||||||
try {
|
|
||||||
p = (Persistable) clazz.newInstance();
|
|
||||||
} catch (InstantiationException | IllegalAccessException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
int remainingLength = available - 4; // -4 for int class index
|
int remainingLength = available - 4; // -4 for int class index
|
||||||
byte[] temp = new byte[remainingLength];
|
byte[] temp = new byte[remainingLength];
|
||||||
input.readFully(temp);
|
input.readFully(temp);
|
||||||
p.decode(temp);
|
persistable.decode(temp);
|
||||||
return (T) p;
|
return (T) persistable;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -32,32 +32,42 @@ import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
|
import org.deeplearning4j.common.config.DL4JSystemProperties;
|
||||||
|
import org.deeplearning4j.common.util.DL4JFileUtils;
|
||||||
import org.deeplearning4j.core.storage.StatsStorage;
|
import org.deeplearning4j.core.storage.StatsStorage;
|
||||||
import org.deeplearning4j.core.storage.StatsStorageEvent;
|
import org.deeplearning4j.core.storage.StatsStorageEvent;
|
||||||
import org.deeplearning4j.core.storage.StatsStorageListener;
|
import org.deeplearning4j.core.storage.StatsStorageListener;
|
||||||
import org.deeplearning4j.core.storage.StatsStorageRouter;
|
import org.deeplearning4j.core.storage.StatsStorageRouter;
|
||||||
import org.deeplearning4j.common.config.DL4JSystemProperties;
|
|
||||||
import org.deeplearning4j.exception.DL4JException;
|
import org.deeplearning4j.exception.DL4JException;
|
||||||
import org.deeplearning4j.ui.api.Route;
|
import org.deeplearning4j.ui.api.Route;
|
||||||
import org.deeplearning4j.ui.api.UIModule;
|
import org.deeplearning4j.ui.api.UIModule;
|
||||||
import org.deeplearning4j.ui.api.UIServer;
|
import org.deeplearning4j.ui.api.UIServer;
|
||||||
import org.deeplearning4j.ui.i18n.I18NProvider;
|
import org.deeplearning4j.ui.i18n.I18NProvider;
|
||||||
|
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
|
||||||
|
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
||||||
|
import org.deeplearning4j.ui.model.storage.impl.QueueStatsStorageListener;
|
||||||
import org.deeplearning4j.ui.module.SameDiffModule;
|
import org.deeplearning4j.ui.module.SameDiffModule;
|
||||||
import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule;
|
import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule;
|
||||||
import org.deeplearning4j.ui.module.defaultModule.DefaultModule;
|
import org.deeplearning4j.ui.module.defaultModule.DefaultModule;
|
||||||
import org.deeplearning4j.ui.module.remote.RemoteReceiverModule;
|
import org.deeplearning4j.ui.module.remote.RemoteReceiverModule;
|
||||||
import org.deeplearning4j.ui.module.train.TrainModule;
|
import org.deeplearning4j.ui.module.train.TrainModule;
|
||||||
import org.deeplearning4j.ui.module.tsne.TsneModule;
|
import org.deeplearning4j.ui.module.tsne.TsneModule;
|
||||||
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
|
|
||||||
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
|
||||||
import org.deeplearning4j.ui.model.storage.impl.QueueStatsStorageListener;
|
|
||||||
import org.deeplearning4j.common.util.DL4JFileUtils;
|
|
||||||
import org.nd4j.common.function.Function;
|
import org.nd4j.common.function.Function;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.*;
|
import java.util.ArrayList;
|
||||||
import java.util.concurrent.*;
|
import java.util.Collections;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.ServiceLoader;
|
||||||
|
import java.util.concurrent.BlockingQueue;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import java.util.concurrent.LinkedBlockingQueue;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@ -402,8 +412,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
private void modulesViaServiceLoader(List<UIModule> uiModules) {
|
private void modulesViaServiceLoader(List<UIModule> uiModules) {
|
||||||
|
ServiceLoader<UIModule> sl = DL4JClassLoading.loadService(UIModule.class);
|
||||||
ServiceLoader<UIModule> sl = ServiceLoader.load(UIModule.class);
|
|
||||||
Iterator<UIModule> iter = sl.iterator();
|
Iterator<UIModule> iter = sl.iterator();
|
||||||
|
|
||||||
if (!iter.hasNext()) {
|
if (!iter.hasNext()) {
|
||||||
|
@ -411,19 +420,19 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
while (iter.hasNext()) {
|
while (iter.hasNext()) {
|
||||||
UIModule m = iter.next();
|
UIModule module = iter.next();
|
||||||
Class<?> c = m.getClass();
|
Class<?> moduleClass = module.getClass();
|
||||||
boolean foundExisting = false;
|
boolean foundExisting = false;
|
||||||
for (UIModule mExisting : uiModules) {
|
for (UIModule mExisting : uiModules) {
|
||||||
if (mExisting.getClass() == c) {
|
if (mExisting.getClass() == moduleClass) {
|
||||||
foundExisting = true;
|
foundExisting = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!foundExisting) {
|
if (!foundExisting) {
|
||||||
log.debug("Loaded UI module via service loader: {}", m.getClass());
|
log.debug("Loaded UI module via service loader: {}", module.getClass());
|
||||||
uiModules.add(m);
|
uiModules.add(module);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.ui.i18n;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.ui.api.I18N;
|
import org.deeplearning4j.ui.api.I18N;
|
||||||
import org.deeplearning4j.ui.api.UIModule;
|
import org.deeplearning4j.ui.api.UIModule;
|
||||||
|
|
||||||
|
@ -100,13 +101,13 @@ public class DefaultI18N implements I18N {
|
||||||
}
|
}
|
||||||
|
|
||||||
private synchronized void loadLanguages(){
|
private synchronized void loadLanguages(){
|
||||||
ServiceLoader<UIModule> sl = ServiceLoader.load(UIModule.class);
|
ServiceLoader<UIModule> loadedModules = DL4JClassLoading.loadService(UIModule.class);
|
||||||
|
|
||||||
for(UIModule m : sl){
|
for (UIModule module : loadedModules){
|
||||||
List<I18NResource> resources = m.getInternationalizationResources();
|
List<I18NResource> resources = module.getInternationalizationResources();
|
||||||
for(I18NResource r : resources){
|
for(I18NResource resource : resources){
|
||||||
try {
|
try {
|
||||||
String path = r.getResource();
|
String path = resource.getResource();
|
||||||
int idxLast = path.lastIndexOf('.');
|
int idxLast = path.lastIndexOf('.');
|
||||||
if (idxLast < 0) {
|
if (idxLast < 0) {
|
||||||
log.warn("Skipping language resource file: cannot infer language: {}", path);
|
log.warn("Skipping language resource file: cannot infer language: {}", path);
|
||||||
|
@ -116,9 +117,9 @@ public class DefaultI18N implements I18N {
|
||||||
String langCode = path.substring(idxLast + 1).toLowerCase();
|
String langCode = path.substring(idxLast + 1).toLowerCase();
|
||||||
Map<String, String> map = messagesByLanguage.computeIfAbsent(langCode, k -> new HashMap<>());
|
Map<String, String> map = messagesByLanguage.computeIfAbsent(langCode, k -> new HashMap<>());
|
||||||
|
|
||||||
parseFile(r, map);
|
parseFile(resource, map);
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
log.warn("Error parsing UI I18N content file; skipping: {}", r.getResource(), t);
|
log.warn("Error parsing UI I18N content file; skipping: {}", resource.getResource(), t);
|
||||||
languageLoadingException = t;
|
languageLoadingException = t;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@ import io.netty.handler.codec.http.HttpResponseStatus;
|
||||||
import io.vertx.core.json.JsonObject;
|
import io.vertx.core.json.JsonObject;
|
||||||
import io.vertx.ext.web.RoutingContext;
|
import io.vertx.ext.web.RoutingContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.core.storage.*;
|
import org.deeplearning4j.core.storage.*;
|
||||||
import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter;
|
import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter;
|
||||||
import org.deeplearning4j.ui.api.HttpMethod;
|
import org.deeplearning4j.ui.api.HttpMethod;
|
||||||
|
@ -154,9 +155,12 @@ public class RemoteReceiverModule implements UIModule {
|
||||||
private StorageMetaData getMetaData(String dataClass, String content) {
|
private StorageMetaData getMetaData(String dataClass, String content) {
|
||||||
StorageMetaData meta;
|
StorageMetaData meta;
|
||||||
try {
|
try {
|
||||||
Class<?> c = Class.forName(dataClass);
|
Class<?> clazz = DL4JClassLoading.loadClassByName(dataClass);
|
||||||
if (StorageMetaData.class.isAssignableFrom(c)) {
|
if (StorageMetaData.class.isAssignableFrom(clazz)) {
|
||||||
meta = (StorageMetaData) c.newInstance();
|
meta = clazz
|
||||||
|
.asSubclass(StorageMetaData.class)
|
||||||
|
.getDeclaredConstructor()
|
||||||
|
.newInstance();
|
||||||
} else {
|
} else {
|
||||||
log.warn("Skipping invalid remote data: class {} in not an instance of {}", dataClass,
|
log.warn("Skipping invalid remote data: class {} in not an instance of {}", dataClass,
|
||||||
StorageMetaData.class.getName());
|
StorageMetaData.class.getName());
|
||||||
|
@ -179,11 +183,14 @@ public class RemoteReceiverModule implements UIModule {
|
||||||
}
|
}
|
||||||
|
|
||||||
private Persistable getPersistable(String dataClass, String content) {
|
private Persistable getPersistable(String dataClass, String content) {
|
||||||
Persistable p;
|
Persistable persistable;
|
||||||
try {
|
try {
|
||||||
Class<?> c = Class.forName(dataClass);
|
Class<?> clazz = DL4JClassLoading.loadClassByName(dataClass);
|
||||||
if (Persistable.class.isAssignableFrom(c)) {
|
if (Persistable.class.isAssignableFrom(clazz)) {
|
||||||
p = (Persistable) c.newInstance();
|
persistable = clazz
|
||||||
|
.asSubclass(Persistable.class)
|
||||||
|
.getDeclaredConstructor()
|
||||||
|
.newInstance();
|
||||||
} else {
|
} else {
|
||||||
log.warn("Skipping invalid remote data: class {} in not an instance of {}", dataClass,
|
log.warn("Skipping invalid remote data: class {} in not an instance of {}", dataClass,
|
||||||
Persistable.class.getName());
|
Persistable.class.getName());
|
||||||
|
@ -196,12 +203,12 @@ public class RemoteReceiverModule implements UIModule {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
byte[] bytes = DatatypeConverter.parseBase64Binary(content);
|
byte[] bytes = DatatypeConverter.parseBase64Binary(content);
|
||||||
p.decode(bytes);
|
persistable.decode(bytes);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.warn("Skipping invalid remote data: exception encountered when deserializing data", e);
|
log.warn("Skipping invalid remote data: exception encountered when deserializing data", e);
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
return p;
|
return persistable;
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -22,6 +22,7 @@ import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||||
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
|
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
|
||||||
import org.deeplearning4j.nn.api.Model;
|
import org.deeplearning4j.nn.api.Model;
|
||||||
|
@ -127,7 +128,7 @@ public class IntegrationTestRunner {
|
||||||
}
|
}
|
||||||
|
|
||||||
for (ClassPath.ClassInfo c : info) {
|
for (ClassPath.ClassInfo c : info) {
|
||||||
Class<?> clazz = Class.forName(c.getName());
|
Class<?> clazz = DL4JClassLoading.loadClassByName(c.getName());
|
||||||
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface())
|
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,19 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) Eclipse Deeplearning4j Contributors 2020
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
package org.nd4j.common.config;
|
package org.nd4j.common.config;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
@ -6,6 +22,22 @@ import java.util.ServiceLoader;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Global context for class-loading in ND4J.
|
* Global context for class-loading in ND4J.
|
||||||
|
* <p>Use {@code ND4JClassLoading} to define classloader for ND4J only! To define classloader used by
|
||||||
|
* {@code Deeplearning4j} use class {@link org.deeplearning4j.common.config.DL4JClassLoading}.
|
||||||
|
*
|
||||||
|
* <p>Usage:
|
||||||
|
* <pre>{@code
|
||||||
|
* public class Application {
|
||||||
|
* static {
|
||||||
|
* ND4JClassLoading.setNd4jClassloaderFromClass(Application.class);
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* public static void main(String[] args) {
|
||||||
|
* }
|
||||||
|
* }
|
||||||
|
* }</code>
|
||||||
|
*
|
||||||
|
* @see org.deeplearning4j.common.config.DL4JClassLoading
|
||||||
*
|
*
|
||||||
* @author Alexei KLENIN
|
* @author Alexei KLENIN
|
||||||
*/
|
*/
|
||||||
|
|
Loading…
Reference in New Issue