commit
da4bf0209b
|
@ -9,7 +9,7 @@ Deeplearning4j's [open issues are here](https://github.com/eclipse/deeplearning4
|
|||
|
||||
Note that you will need to [build dl4j from source](https://deeplearning4j.org/docs/latest/deeplearning4j-build-from-source)
|
||||
|
||||
For some tips on contributing to open source, this [post is helpful](http://blog.smartbear.com/programming/14-ways-to-contribute-to-open-source-without-being-a-programming-genius-or-a-rock-star/).
|
||||
For some tips on contributing to open source, this [post is helpful](https://smartbear.com/blog/test-and-monitor/14-ways-to-contribute-to-open-source-without-being/).
|
||||
|
||||
## Contributions
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@
|
|||
<outputDirectory>examples</outputDirectory>
|
||||
<!--
|
||||
<lineEnding>unix</lineEnding>
|
||||
http://stackoverflow.com/questions/2958282/stranges-files-in-my-assembly-since-switching-to-lineendingunix-lineending
|
||||
https://stackoverflow.com/questions/2958282/stranges-files-in-my-assembly-since-switching-to-lineendingunix-lineending
|
||||
-->
|
||||
</fileSet>
|
||||
|
||||
|
|
|
@ -52,11 +52,6 @@
|
|||
<artifactId>joda-time</artifactId>
|
||||
<version>${jodatime.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.yaml</groupId>
|
||||
<artifactId>snakeyaml</artifactId>
|
||||
<version>${snakeyaml.version}</version>
|
||||
</dependency>
|
||||
<!-- ND4J Shaded Jackson Dependencies -->
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
|
|
@ -29,21 +29,11 @@
|
|||
<name>datavec-arrow</name>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-arrow</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.carrotsearch</groupId>
|
||||
<artifactId>hppc</artifactId>
|
||||
<version>${hppc.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>arrow-vector</artifactId>
|
||||
|
|
|
@ -44,26 +44,6 @@
|
|||
<artifactId>datavec-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-logging</groupId>
|
||||
<artifactId>commons-logging</artifactId>
|
||||
<version>${commons-logging.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-core</artifactId>
|
||||
<version>${spring.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-context</artifactId>
|
||||
<version>${spring.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-beans</artifactId>
|
||||
<version>${spring.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.cleartk</groupId>
|
||||
<artifactId>cleartk-snowball</artifactId>
|
||||
|
|
|
@ -31,36 +31,6 @@
|
|||
<artifactId>datavec-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-core</artifactId>
|
||||
<version>${geo.jackson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>${geo.jackson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-annotations</artifactId>
|
||||
<version>${geo.jackson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.dataformat</groupId>
|
||||
<artifactId>jackson-dataformat-yaml</artifactId>
|
||||
<version>${geo.jackson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.dataformat</groupId>
|
||||
<artifactId>jackson-dataformat-xml</artifactId>
|
||||
<version>${geo.jackson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.datatype</groupId>
|
||||
<artifactId>jackson-datatype-joda</artifactId>
|
||||
<version>${geo.jackson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.maxmind.geoip2</groupId>
|
||||
<artifactId>geoip2</artifactId>
|
||||
|
|
|
@ -35,41 +35,11 @@
|
|||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.sun.xml.bind</groupId>
|
||||
<artifactId>jaxb-core</artifactId>
|
||||
<version>${jaxb.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.sun.xml.bind</groupId>
|
||||
<artifactId>jaxb-impl</artifactId>
|
||||
<version>${jaxb.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>io.netty</groupId>
|
||||
<artifactId>netty</artifactId>
|
||||
<version>${netty.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-compress</artifactId>
|
||||
<version>${commons-compress.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.zookeeper</groupId>
|
||||
<artifactId>zookeeper</artifactId>
|
||||
<version>${zookeeper.version}</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>log4j</groupId>
|
||||
<artifactId>log4j</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-log4j12</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-common</artifactId>
|
||||
|
|
|
@ -73,42 +73,7 @@
|
|||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-core</artifactId>
|
||||
<version>${geo.jackson.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>${geo.jackson.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-annotations</artifactId>
|
||||
<version>${geo.jackson.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.dataformat</groupId>
|
||||
<artifactId>jackson-dataformat-yaml</artifactId>
|
||||
<version>${geo.jackson.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.dataformat</groupId>
|
||||
<artifactId>jackson-dataformat-xml</artifactId>
|
||||
<version>${geo.jackson.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.datatype</groupId>
|
||||
<artifactId>jackson-datatype-joda</artifactId>
|
||||
<version>${geo.jackson.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-python</artifactId>
|
||||
|
|
|
@ -41,11 +41,6 @@
|
|||
<artifactId>slf4j-api</artifactId>
|
||||
<version>${slf4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.github.oshi</groupId>
|
||||
<artifactId>oshi-core</artifactId>
|
||||
<version>${oshi.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-data-image</artifactId>
|
||||
|
|
|
@ -41,26 +41,6 @@
|
|||
<version>1.0.0-SNAPSHOT</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-codec</groupId>
|
||||
<artifactId>commons-codec</artifactId>
|
||||
<version>${commons-codec.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.httpcomponents</groupId>
|
||||
<artifactId>httpclient</artifactId>
|
||||
<version>${httpclient.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.httpcomponents</groupId>
|
||||
<artifactId>httpcore</artifactId>
|
||||
<version>${httpcore.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.httpcomponents</groupId>
|
||||
<artifactId>httpmime</artifactId>
|
||||
<version>${httpmime.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.mashape.unirest</groupId>
|
||||
<artifactId>unirest-java</artifactId>
|
||||
|
|
|
@ -94,12 +94,6 @@
|
|||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.yaml</groupId>
|
||||
<artifactId>snakeyaml</artifactId>
|
||||
<version>${snakeyaml.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.typesafe.play</groupId>
|
||||
<artifactId>play-java_2.11</artifactId>
|
||||
|
|
|
@ -39,11 +39,6 @@
|
|||
<artifactId>scala-library</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-reflect</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
|
|
|
@ -64,7 +64,7 @@ public class DL4JResources {
|
|||
/**
|
||||
* Set the base download URL for (most) DL4J datasets and models.<br>
|
||||
* This usually doesn't need to be set manually unless there is some issue with the default location
|
||||
* @param baseDownloadURL Base download URL to set. For example, http://blob.deeplearning4j.org/
|
||||
* @param baseDownloadURL Base download URL to set. For example, https://dl4jdata.blob.core.windows.net/
|
||||
*/
|
||||
public static void setBaseDownloadURL(@NonNull String baseDownloadURL){
|
||||
baseURL = baseDownloadURL;
|
||||
|
@ -79,8 +79,8 @@ public class DL4JResources {
|
|||
|
||||
/**
|
||||
* Get the URL relative to the base URL.<br>
|
||||
* For example, if baseURL is "http://blob.deeplearning4j.org/", and relativeToBase is "/datasets/iris.dat"
|
||||
* this simply returns "http://blob.deeplearning4j.org/datasets/iris.dat"
|
||||
* For example, if baseURL is "https://dl4jdata.blob.core.windows.net/", and relativeToBase is "/datasets/iris.dat"
|
||||
* this simply returns "https://dl4jdata.blob.core.windows.net/datasets/iris.dat"
|
||||
*
|
||||
* @param relativeToBase Relative URL
|
||||
* @return URL
|
||||
|
@ -92,8 +92,8 @@ public class DL4JResources {
|
|||
|
||||
/**
|
||||
* Get the URL relative to the base URL as a String.<br>
|
||||
* For example, if baseURL is "http://blob.deeplearning4j.org/", and relativeToBase is "/datasets/iris.dat"
|
||||
* this simply returns "http://blob.deeplearning4j.org/datasets/iris.dat"
|
||||
* For example, if baseURL is "https://dl4jdata.blob.core.windows.net/", and relativeToBase is "/datasets/iris.dat"
|
||||
* this simply returns "https://dl4jdata.blob.core.windows.net/datasets/iris.dat"
|
||||
*
|
||||
* @param relativeToBase Relative URL
|
||||
* @return URL
|
||||
|
|
|
@ -35,6 +35,7 @@ import org.nd4j.linalg.indexing.conditions.Conditions;
|
|||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
@ -63,6 +64,30 @@ public class LayerHelperValidationUtil {
|
|||
private DataSetIterator data;
|
||||
}
|
||||
|
||||
public static void disableCppHelpers(){
|
||||
try {
|
||||
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||
Method m = c.getMethod("getInstance");
|
||||
Object instance = m.invoke(null);
|
||||
Method m2 = c.getMethod("allowHelpers", boolean.class);
|
||||
m2.invoke(instance, false);
|
||||
} catch (Throwable t){
|
||||
throw new RuntimeException(t);
|
||||
}
|
||||
}
|
||||
|
||||
public static void enableCppHelpers(){
|
||||
try{
|
||||
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||
Method m = c.getMethod("getInstance");
|
||||
Object instance = m.invoke(null);
|
||||
Method m2 = c.getMethod("allowHelpers", boolean.class);
|
||||
m2.invoke(instance, true);
|
||||
} catch (Throwable t){
|
||||
throw new RuntimeException(t);
|
||||
}
|
||||
}
|
||||
|
||||
public static void validateMLN(MultiLayerNetwork netOrig, TestCase t){
|
||||
assertNotNull(t.getAllowHelpersForClasses());
|
||||
assertFalse(t.getAllowHelpersForClasses().isEmpty());
|
||||
|
@ -95,7 +120,13 @@ public class LayerHelperValidationUtil {
|
|||
for (boolean train : new boolean[]{false, true}) {
|
||||
assertEquals(net1NoHelper.params(), net2With.params());
|
||||
String s = "Feed forward test - " + t.getTestName() + " - " + (train ? "Train: " : "Test: ");
|
||||
List<INDArray> ff1 = net1NoHelper.feedForward(t.getFeatures(), train);
|
||||
List<INDArray> ff1;
|
||||
try {
|
||||
disableCppHelpers();
|
||||
ff1 = net1NoHelper.feedForward(t.getFeatures(), train);
|
||||
} finally {
|
||||
enableCppHelpers();
|
||||
}
|
||||
List<INDArray> ff2 = net2With.feedForward(t.getFeatures(), train);
|
||||
List<String> paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet());
|
||||
Collections.sort(paramKeys);
|
||||
|
@ -131,7 +162,13 @@ public class LayerHelperValidationUtil {
|
|||
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
|
||||
}
|
||||
|
||||
INDArray out1 = net1NoHelper.output(t.getFeatures(), train);
|
||||
INDArray out1;
|
||||
try {
|
||||
disableCppHelpers();
|
||||
out1 = net1NoHelper.output(t.getFeatures(), train);
|
||||
} finally {
|
||||
enableCppHelpers();
|
||||
}
|
||||
INDArray out2 = net2With.output(t.getFeatures(), train);
|
||||
INDArray relError = relError(out1, out2, t.getMinAbsError());
|
||||
double maxRE = relError.maxNumber().doubleValue();
|
||||
|
@ -148,7 +185,13 @@ public class LayerHelperValidationUtil {
|
|||
Preconditions.checkNotNull(t.getLabels(), "Labels are not set (null)");
|
||||
|
||||
log.info("Validation - checking scores");
|
||||
double s1 = net1NoHelper.score(new DataSet(t.getFeatures(), t.getLabels()));
|
||||
double s1;
|
||||
try {
|
||||
disableCppHelpers();
|
||||
s1 = net1NoHelper.score(new DataSet(t.getFeatures(), t.getLabels()));
|
||||
} finally {
|
||||
enableCppHelpers();
|
||||
}
|
||||
double s2 = net2With.score(new DataSet(t.getFeatures(), t.getLabels()));
|
||||
|
||||
double re = relError(s1, s2);
|
||||
|
@ -168,7 +211,12 @@ public class LayerHelperValidationUtil {
|
|||
net2With.setInput(t.getFeatures());
|
||||
net2With.setLabels(t.getLabels());
|
||||
|
||||
try {
|
||||
disableCppHelpers();
|
||||
net1NoHelper.computeGradientAndScore();
|
||||
} finally {
|
||||
enableCppHelpers();
|
||||
}
|
||||
net2With.computeGradientAndScore();
|
||||
|
||||
List<String> paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet());
|
||||
|
|
|
@ -1,107 +0,0 @@
|
|||
package org.deeplearning4j;
|
||||
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
|
||||
import static junit.framework.TestCase.*;
|
||||
|
||||
public class TestBatchNormBp {
|
||||
|
||||
@Test
|
||||
public void test(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
// INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 4, 4);
|
||||
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
|
||||
INDArray mean = in.mean(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
|
||||
INDArray var = in.var(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
|
||||
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
|
||||
// INDArray gamma = Nd4j.ones(DataType.FLOAT, 3);
|
||||
// INDArray beta = Nd4j.zeros(DataType.FLOAT, 3);
|
||||
INDArray gamma = Nd4j.rand(DataType.FLOAT, 3);
|
||||
INDArray beta = Nd4j.rand(DataType.FLOAT, 3);
|
||||
double e = 1e-5;
|
||||
|
||||
INDArray dLdIn = in.ulike();
|
||||
INDArray dLdm = mean.ulike();
|
||||
INDArray dLdv = var.ulike();
|
||||
INDArray dLdg = gamma.ulike();
|
||||
INDArray dLdb = beta.ulike();
|
||||
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
|
||||
.addInputs(in, mean, var, eps, gamma, beta)
|
||||
.addIntegerArguments(
|
||||
1, //Apply scale
|
||||
1, //Apply beta
|
||||
1) //Axis (NCHW)
|
||||
.addFloatingPointArguments(e)
|
||||
.addOutputs(dLdIn, dLdm, dLdv, dLdg, dLdb)
|
||||
.build();
|
||||
|
||||
Nd4j.exec(op);
|
||||
System.out.println(dLdIn);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void compareImpls() throws Exception {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
|
||||
INDArray mean = in.mean(0, 2, 3).reshape(1,3);
|
||||
INDArray var = in.var(0, 2, 3).reshape(1,3);
|
||||
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
|
||||
INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
|
||||
INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
|
||||
double e = 1e-3;
|
||||
|
||||
INDArray dLdIn = in.ulike();
|
||||
INDArray dLdm = mean.ulike();
|
||||
INDArray dLdv = var.ulike();
|
||||
INDArray dLdg = gamma.ulike();
|
||||
INDArray dLdb = beta.ulike();
|
||||
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.inferenceWorkspaceMode(WorkspaceMode.NONE)
|
||||
.trainingWorkspaceMode(WorkspaceMode.NONE)
|
||||
.list()
|
||||
.layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
|
||||
.build();
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
|
||||
assertNotNull(bn.getHelper());
|
||||
Field f = bn.getClass().getDeclaredField("helper");
|
||||
f.setAccessible(true);
|
||||
f.set(bn, null);
|
||||
assertNull(bn.getHelper());
|
||||
|
||||
|
||||
MKLDNNBatchNormHelper h = new MKLDNNBatchNormHelper(DataType.FLOAT);
|
||||
|
||||
net.output(in, true);
|
||||
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
|
||||
Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
|
||||
|
||||
h.preOutput(in, true, new long[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces());
|
||||
Pair<Gradient,INDArray> pmkl = h.backpropGradient(in, eps, new long[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces());
|
||||
|
||||
INDArray dldin_dl4j = p.getSecond();
|
||||
|
||||
System.out.println("dl4j == mkldnn: " + p.getSecond().equals(pmkl.getSecond()));
|
||||
}
|
||||
|
||||
}
|
|
@ -2143,4 +2143,23 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
|||
INDArray in = Nd4j.create(DataType.FLOAT, 1, 3, 16, 16, 16);
|
||||
INDArray out = cg.outputSingle(in);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDualEmbedding(){
|
||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.graphBuilder()
|
||||
.addInputs("in")
|
||||
.addLayer("e1", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "in")
|
||||
.addLayer("e2", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "in")
|
||||
.addLayer("out", new OutputLayer.Builder().nIn(10).nOut(2).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "e1", "e2")
|
||||
.setOutputs("out")
|
||||
.build();
|
||||
|
||||
ComputationGraph cg = new ComputationGraph(conf);
|
||||
cg.init();
|
||||
|
||||
INDArray in = Nd4j.createFromArray(3).reshape(1, 1);
|
||||
INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2);
|
||||
cg.fit(new DataSet(in, label));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -70,9 +70,20 @@ public class MinimalSameDiffDense extends SameDiffLayer {
|
|||
|
||||
@Override
|
||||
public void initializeParameters(Map<String, INDArray> params) {
|
||||
String b = DefaultParamInitializer.BIAS_KEY;
|
||||
if(paramWeightInit != null && paramWeightInit.containsKey(b)){
|
||||
paramWeightInit.get(b).init(nIn, nOut, params.get(b).shape(), 'c', params.get(b));
|
||||
} else {
|
||||
params.get(DefaultParamInitializer.BIAS_KEY).assign(0);
|
||||
}
|
||||
|
||||
String w = DefaultParamInitializer.WEIGHT_KEY;
|
||||
if(paramWeightInit != null && paramWeightInit.containsKey(w)){
|
||||
paramWeightInit.get(w).init(nIn, nOut, params.get(w).shape(), 'c', params.get(w));
|
||||
} else {
|
||||
initWeights(nIn, nOut, weightInit, params.get(DefaultParamInitializer.WEIGHT_KEY));
|
||||
}
|
||||
}
|
||||
|
||||
//OPTIONAL methods:
|
||||
// public void setNIn(InputType inputType, boolean override)
|
||||
|
|
|
@ -109,17 +109,21 @@ public class SameDiffConv extends SameDiffLayer {
|
|||
@Override
|
||||
public void initializeParameters(Map<String, INDArray> params) {
|
||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
double fanIn = nIn * kernel[0] * kernel[1];
|
||||
double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);
|
||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||
if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){
|
||||
paramWeightInit.get(e.getKey()).init(fanIn, fanOut, e.getValue().shape(), 'c', e.getValue());
|
||||
} else {
|
||||
if (ConvolutionParamInitializer.BIAS_KEY.equals(e.getKey())) {
|
||||
e.getValue().assign(0);
|
||||
} else {
|
||||
double fanIn = nIn * kernel[0] * kernel[1];
|
||||
double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);
|
||||
WeightInitUtil.initWeights(fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
|
||||
|
|
|
@ -88,6 +88,9 @@ public class SameDiffDense extends SameDiffLayer {
|
|||
@Override
|
||||
public void initializeParameters(Map<String,INDArray> params){
|
||||
for(Map.Entry<String,INDArray> e : params.entrySet()){
|
||||
if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){
|
||||
paramWeightInit.get(e.getKey()).init(nIn, nOut, e.getValue().shape(), 'c', e.getValue());
|
||||
} else {
|
||||
if(DefaultParamInitializer.BIAS_KEY.equals(e.getKey())){
|
||||
e.getValue().assign(0.0);
|
||||
} else {
|
||||
|
@ -96,6 +99,7 @@ public class SameDiffDense extends SameDiffLayer {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SDVariable defineLayer(SameDiff sd, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
|
||||
|
|
|
@ -50,6 +50,7 @@ import static org.junit.Assume.assumeTrue;
|
|||
|
||||
public class ValidateMKLDNN extends BaseDL4JTest {
|
||||
|
||||
|
||||
@Test
|
||||
public void validateConvSubsampling() throws Exception {
|
||||
//Only run test if using nd4j-native backend
|
||||
|
@ -138,12 +139,14 @@ public class ValidateMKLDNN extends BaseDL4JTest {
|
|||
ConvolutionMode cm = ConvolutionMode.Truncate;
|
||||
|
||||
for (int minibatch : new int[]{1, 3}) {
|
||||
for (boolean b : new boolean[]{true, false}) {
|
||||
|
||||
inputSize[0] = minibatch;
|
||||
INDArray f = Nd4j.rand(Nd4j.defaultFloatingPointType(), inputSize);
|
||||
INDArray l = TestUtils.randomOneHot(minibatch, 10);
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.FLOAT)
|
||||
.updater(new Adam(0.01))
|
||||
.convolutionMode(cm)
|
||||
.seed(12345)
|
||||
|
@ -154,7 +157,7 @@ public class ValidateMKLDNN extends BaseDL4JTest {
|
|||
.padding(0, 0)
|
||||
.nOut(3)
|
||||
.build())
|
||||
.layer(new BatchNormalization.Builder().helperAllowFallback(false)/*.eps(0)*/.build())
|
||||
.layer(new BatchNormalization.Builder().useLogStd(b).helperAllowFallback(false)/*.eps(0)*/.build())
|
||||
.layer(new ConvolutionLayer.Builder().activation(Activation.TANH)
|
||||
.kernelSize(kernel)
|
||||
.stride(stride)
|
||||
|
@ -186,6 +189,7 @@ public class ValidateMKLDNN extends BaseDL4JTest {
|
|||
LayerHelperValidationUtil.validateMLN(netWith, tc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test @Ignore //https://github.com/deeplearning4j/deeplearning4j/issues/7272
|
||||
public void validateLRN() {
|
||||
|
@ -265,6 +269,7 @@ public class ValidateMKLDNN extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void compareBatchNormBackward() throws Exception {
|
||||
assumeTrue(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native"));
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
|
||||
|
|
|
@ -339,7 +339,13 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
|||
|
||||
INDArray outAct = net.output(in);
|
||||
|
||||
//19 layers - CPU vs. GPU difference accumulates notably, but appears to be correct
|
||||
if(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native")){
|
||||
assertEquals(outExp, outAct);
|
||||
} else {
|
||||
boolean eq = outExp.equalsWithEps(outAct, 0.1);
|
||||
assertTrue(eq);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -24,101 +24,11 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A wrapper for a dataset to sample from.
|
||||
* This will randomly sample from the given dataset.
|
||||
* @author Adam GIbson
|
||||
*/
|
||||
public class SamplingDataSetIterator implements DataSetIterator {
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private static final long serialVersionUID = -2700563801361726914L;
|
||||
private DataSet sampleFrom;
|
||||
private int batchSize;
|
||||
private int totalNumberSamples;
|
||||
private int numTimesSampled;
|
||||
@Getter
|
||||
private DataSetPreProcessor preProcessor;
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sampleFrom the dataset to sample from
|
||||
* @param batchSize the batch size to sample
|
||||
* @param totalNumberSamples the sample size
|
||||
* @deprecated Use {@link org.nd4j.linalg.dataset.api.iterator.SamplingDataSetIterator}
|
||||
*/
|
||||
@Deprecated
|
||||
public class SamplingDataSetIterator extends org.nd4j.linalg.dataset.api.iterator.SamplingDataSetIterator {
|
||||
public SamplingDataSetIterator(DataSet sampleFrom, int batchSize, int totalNumberSamples) {
|
||||
super();
|
||||
this.sampleFrom = sampleFrom;
|
||||
this.batchSize = batchSize;
|
||||
this.totalNumberSamples = totalNumberSamples;
|
||||
super(sampleFrom, batchSize, totalNumberSamples);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return numTimesSampled < totalNumberSamples;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataSet next() {
|
||||
DataSet ret = sampleFrom.sample(batchSize);
|
||||
numTimesSampled += batchSize;
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void remove() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int inputColumns() {
|
||||
return sampleFrom.numInputs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int totalOutcomes() {
|
||||
return sampleFrom.numOutcomes();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean resetSupported() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean asyncSupported() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
numTimesSampled = 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int batch() {
|
||||
return batchSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setPreProcessor(DataSetPreProcessor preProcessor) {
|
||||
this.preProcessor = preProcessor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getLabels() {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public DataSet next(int num) {
|
||||
DataSet ret = sampleFrom.sample(num);
|
||||
numTimesSampled++;
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ import java.util.concurrent.atomic.AtomicLong;
|
|||
|
||||
/**Implementation of the DeepWalk graph vectorization model, based on the paper
|
||||
* <i>DeepWalk: Online Learning of Social Representations</i> by Perozzi, Al-Rfou & Skiena (2014),
|
||||
* <a href="http://arxiv.org/abs/1403.6652">http://arxiv.org/abs/1403.6652</a><br>
|
||||
* <a href="https://arxiv.org/abs/1403.6652">https://arxiv.org/abs/1403.6652</a><br>
|
||||
* Similar to word2vec in nature, DeepWalk is an unsupervised learning algorithm that learns a vector representation
|
||||
* of each vertex in a graph. Vector representations are learned using walks (usually random walks) on the vertices in
|
||||
* the graph.<br>
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.bytedeco.hdf5.*;
|
||||
import org.bytedeco.javacpp.BytePointer;
|
||||
import org.bytedeco.javacpp.FloatPointer;
|
||||
import org.bytedeco.javacpp.Loader;
|
||||
|
@ -32,7 +33,6 @@ import java.lang.Exception;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.bytedeco.hdf5.*;
|
||||
import static org.bytedeco.hdf5.global.hdf5.*;
|
||||
|
||||
/**
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||
import org.deeplearning4j.nn.conf.BackpropType;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
|
|
|
@ -40,6 +40,6 @@ public class InvalidKerasConfigurationException extends Exception {
|
|||
}
|
||||
|
||||
private static String appendDocumentationURL(String message) {
|
||||
return message + ". For more information, see http://deeplearning4j.org/model-import-keras.";
|
||||
return message + ". For more information, see http://deeplearning4j.org/docs/latest/keras-import-overview";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.exceptions;
|
|||
* is not currently supported.
|
||||
*
|
||||
* See <a href="https://deeplearning4j.org/docs/latest/keras-import-overview">https://deeplearning4j.org/docs/latest/keras-import-overview</a>
|
||||
* for more information and file an issue at <a href="http://github.com/deeplearning4j/deeplearning4j/issues">http://github.com/deeplearning4j/deeplearning4j/issues</a>.
|
||||
* for more information and file an issue at <a href="https://github.com/eclipse/deeplearning4j/issues">https://github.com/eclipse/deeplearning4j/issues</a>.
|
||||
*
|
||||
* @author dave@skymind.io
|
||||
*/
|
||||
|
@ -41,6 +41,6 @@ public class UnsupportedKerasConfigurationException extends Exception {
|
|||
}
|
||||
|
||||
private static String appendDocumentationURL(String message) {
|
||||
return message + ". Please file an issue at http://github.com/deeplearning4j/deeplearning4j/issues.";
|
||||
return message + ". Please file an issue at https://github.com/eclipse/deeplearning4j/issues.";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.PReLULayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
|
@ -27,9 +26,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||
import org.deeplearning4j.nn.params.PReLUParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
@ -79,14 +77,12 @@ public class KerasPReLU extends KerasLayer {
|
|||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, ALPHA_CONSTRAINT, conf, kerasMajorVersion);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, ALPHA_INIT,
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, ALPHA_INIT,
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
long[] axes = getSharedAxes(layerConfig);
|
||||
|
||||
PReLULayer.Builder builder = new PReLULayer.Builder().sharedAxes(axes)
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution)).name(layerName);
|
||||
.weightInit(init).name(layerName);
|
||||
if (weightConstraint != null){
|
||||
builder.constrainWeights(weightConstraint);
|
||||
}
|
||||
|
|
|
@ -17,14 +17,12 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
|
||||
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -83,15 +81,13 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
|
|||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.dilation(getDilationRate(layerConfig, 1, conf, true)[0])
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
|
|
|
@ -17,14 +17,12 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
|
||||
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -84,14 +82,13 @@ public class KerasAtrousConvolution2D extends KerasConvolution {
|
|||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
|
||||
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction())
|
||||
.weightInit(init)
|
||||
.dilation(getDilationRate(layerConfig, 2, conf, true))
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||
|
@ -30,7 +29,6 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.removeDefaultWeights;
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||
|
@ -30,10 +29,9 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -94,15 +92,13 @@ public class KerasConvolution1D extends KerasConvolution {
|
|||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
|
||||
|
|
|
@ -21,14 +21,12 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -87,10 +85,8 @@ public class KerasConvolution2D extends KerasConvolution {
|
|||
numTrainableParams = hasBias ? 2 : 1;
|
||||
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
@ -100,7 +96,7 @@ public class KerasConvolution2D extends KerasConvolution {
|
|||
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
|
||||
|
|
|
@ -21,15 +21,13 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.Convolution3D;
|
||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -88,10 +86,8 @@ public class KerasConvolution3D extends KerasConvolution {
|
|||
numTrainableParams = hasBias ? 2 : 1;
|
||||
int[] dilationRate = getDilationRate(layerConfig, 3, conf, false);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
@ -101,7 +97,7 @@ public class KerasConvolution3D extends KerasConvolution {
|
|||
Convolution3D.Builder builder = new Convolution3D.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 3, conf, kerasMajorVersion))
|
||||
|
|
|
@ -20,14 +20,12 @@ import lombok.Data;
|
|||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -86,10 +84,8 @@ public class KerasDeconvolution2D extends KerasConvolution {
|
|||
numTrainableParams = hasBias ? 2 : 1;
|
||||
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
@ -99,7 +95,7 @@ public class KerasDeconvolution2D extends KerasConvolution {
|
|||
Deconvolution2D.Builder builder = new Deconvolution2D.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
|
@ -30,9 +29,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils;
|
||||
import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
|
@ -126,10 +124,8 @@ public class KerasDepthwiseConvolution2D extends KerasConvolution {
|
|||
numTrainableParams = hasBias ? 2 : 1;
|
||||
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
|
||||
|
||||
Pair<WeightInit, Distribution> depthWiseInit = getWeightInitFromConfig(layerConfig,
|
||||
IWeightInit depthWiseInit = getWeightInitFromConfig(layerConfig,
|
||||
conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit depthWeightInit = depthWiseInit.getFirst();
|
||||
Distribution depthDistribution = depthWiseInit.getSecond();
|
||||
|
||||
val nIn = getNInFromConfig(previousLayers);
|
||||
|
||||
|
@ -152,7 +148,7 @@ public class KerasDepthwiseConvolution2D extends KerasConvolution {
|
|||
.nIn(nIn)
|
||||
.nOut(nIn * depthMultiplier)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(depthWeightInit.getWeightInitFunction(depthDistribution))
|
||||
.weightInit(depthWiseInit)
|
||||
.depthMultiplier(depthMultiplier)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
|
|
|
@ -20,7 +20,6 @@ import lombok.Data;
|
|||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
|
@ -28,9 +27,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils;
|
||||
import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -93,17 +91,13 @@ public class KerasSeparableConvolution2D extends KerasConvolution {
|
|||
|
||||
int depthMultiplier = getDepthMultiplier(layerConfig, conf);
|
||||
|
||||
Pair<WeightInit, Distribution> depthWiseInit = getWeightInitFromConfig(layerConfig,
|
||||
IWeightInit depthWiseInit = getWeightInitFromConfig(layerConfig,
|
||||
conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit depthWeightInit = depthWiseInit.getFirst();
|
||||
Distribution depthDistribution = depthWiseInit.getSecond();
|
||||
|
||||
Pair<WeightInit, Distribution> pointWiseInit = getWeightInitFromConfig(layerConfig,
|
||||
IWeightInit pointWiseInit = getWeightInitFromConfig(layerConfig,
|
||||
conf.getLAYER_FIELD_POINT_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit pointWeightInit = pointWiseInit.getFirst();
|
||||
Distribution pointDistribution = pointWiseInit.getSecond();
|
||||
|
||||
if (depthWeightInit != pointWeightInit || depthDistribution != pointDistribution)
|
||||
if ( !depthWiseInit.getClass().equals(pointWiseInit.getClass()) )
|
||||
if (enforceTrainingConfig)
|
||||
throw new UnsupportedKerasConfigurationException(
|
||||
"Specifying different initialization for depth- and point-wise weights not supported.");
|
||||
|
@ -126,7 +120,7 @@ public class KerasSeparableConvolution2D extends KerasConvolution {
|
|||
SeparableConvolution2D.Builder builder = new SeparableConvolution2D.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(depthWeightInit.getWeightInitFunction(depthDistribution))
|
||||
.weightInit(depthWiseInit)
|
||||
.depthMultiplier(depthMultiplier)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
|
||||
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling3D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
|
@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.params.DefaultParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -95,15 +93,13 @@ public class KerasDense extends KerasLayer {
|
|||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
DenseLayer.Builder builder = new DenseLayer.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf))
|
||||
.dropOut(this.dropout).activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.biasInit(0.0)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.hasBias(hasBias);
|
||||
|
|
|
@ -22,7 +22,6 @@ import org.deeplearning4j.nn.conf.InputPreProcessor;
|
|||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
|
|
|
@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
|
|
|
@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
|||
|
||||
|
||||
import lombok.val;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
|
@ -26,7 +25,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
|
@ -30,11 +29,10 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||
import org.deeplearning4j.nn.params.DefaultParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -104,12 +102,10 @@ public class KerasEmbedding extends KerasLayer {
|
|||
"on Embedding layers. Zero Masking for the Embedding layer only works with unidirectional LSTM for now."
|
||||
+ " If you want to have this behaviour for your imported model " +
|
||||
"in DL4J, apply masking as a pre-processing step to your input." +
|
||||
"See https://deeplearning4j.org/usingrnns#masking for more on this.");
|
||||
"See http://deeplearning4j.org/docs/latest/deeplearning4j-nn-recurrent#masking for more on this.");
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
LayerConstraint embeddingConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_EMBEDDINGS_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
@ -121,7 +117,7 @@ public class KerasEmbedding extends KerasLayer {
|
|||
.inferInputLength(inferInputLength)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf))
|
||||
.dropOut(this.dropout).activation(Activation.IDENTITY)
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.biasInit(0.0)
|
||||
.l1(this.weightL1Regularization)
|
||||
.l2(this.weightL2Regularization)
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.LocallyConnected1D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
|
@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -90,11 +88,8 @@ public class KerasLocallyConnected1D extends KerasConvolution {
|
|||
numTrainableParams = hasBias ? 2 : 1;
|
||||
int[] dilationRate = getDilationRate(layerConfig, 1, conf, false);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
// TODO: take care of distribution and bias init
|
||||
//Distribution distribution = init.getSecond();
|
||||
|
||||
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
@ -104,7 +99,7 @@ public class KerasLocallyConnected1D extends KerasConvolution {
|
|||
LocallyConnected1D.Builder builder = new LocallyConnected1D.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit)
|
||||
.weightInit(conf.getKERAS_PARAM_NAME_W(), init)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
|
@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -39,9 +37,7 @@ import java.util.Map;
|
|||
import static org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolutionUtils.*;
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasActivationUtils.getActivationFromConfig;
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils.getWeightInitFromConfig;
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getHasBiasFromConfig;
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getNOutFromConfig;
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.removeDefaultWeights;
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.*;
|
||||
|
||||
|
||||
/**
|
||||
|
@ -92,11 +88,9 @@ public class KerasLocallyConnected2D extends KerasConvolution {
|
|||
numTrainableParams = hasBias ? 2 : 1;
|
||||
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
// TODO: take care of distribution and bias init
|
||||
//Distribution distribution = init.getSecond();
|
||||
// TODO: take care of bias init
|
||||
|
||||
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
@ -106,7 +100,7 @@ public class KerasLocallyConnected2D extends KerasConvolution {
|
|||
LocallyConnected2D.Builder builder = new LocallyConnected2D.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit)
|
||||
.weightInit(conf.getKERAS_PARAM_NAME_W(), init)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
|
||||
|
|
|
@ -31,7 +31,6 @@ import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
|
|
@ -22,7 +22,6 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import lombok.val;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||
|
@ -35,7 +34,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||
import org.deeplearning4j.nn.params.LSTMParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -151,15 +150,11 @@ public class KerasLSTM extends KerasLayer {
|
|||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
super(layerConfig, enforceTrainingConfig);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
Pair<WeightInit, Distribution> recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
|
||||
IWeightInit recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit recurrentWeightInit = recurrentInit.getFirst();
|
||||
Distribution recurrentDistribution = recurrentInit.getSecond();
|
||||
|
||||
boolean hasBias = getHasBiasFromConfig(layerConfig, conf);
|
||||
|
||||
|
@ -186,8 +181,8 @@ public class KerasLSTM extends KerasLayer {
|
|||
.nOut(getNOutFromConfig(layerConfig, conf))
|
||||
.dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInitRecurrent(recurrentWeightInit.getWeightInitFunction(recurrentDistribution))
|
||||
.weightInit(init)
|
||||
.weightInitRecurrent(recurrentInit)
|
||||
.biasInit(0.0) // TODO: this is incorrect
|
||||
.l1(this.weightL1Regularization)
|
||||
.l2(this.weightL2Regularization);
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||
|
@ -34,7 +33,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
|
@ -124,15 +123,11 @@ public class KerasSimpleRnn extends KerasLayer {
|
|||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
super(layerConfig, enforceTrainingConfig);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
Pair<WeightInit, Distribution> recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
|
||||
IWeightInit recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit recurrentWeightInit = recurrentInit.getFirst();
|
||||
Distribution recurrentDistribution = recurrentInit.getSecond();
|
||||
|
||||
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
||||
this.returnSequences = (Boolean) innerConfig.get(conf.getLAYER_FIELD_RETURN_SEQUENCES());
|
||||
|
@ -154,8 +149,8 @@ public class KerasSimpleRnn extends KerasLayer {
|
|||
.nOut(getNOutFromConfig(layerConfig, conf))
|
||||
.dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInitRecurrent(recurrentWeightInit.getWeightInitFunction(recurrentDistribution))
|
||||
.weightInit(init)
|
||||
.weightInitRecurrent(recurrentInit)
|
||||
.biasInit(0.0)
|
||||
.l1(this.weightL1Regularization)
|
||||
.l2(this.weightL2Regularization);
|
||||
|
|
|
@ -20,9 +20,7 @@ import com.google.gson.Gson;
|
|||
import com.google.gson.reflect.TypeToken;
|
||||
import lombok.Data;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
|
@ -31,7 +29,6 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
|
|
@ -22,9 +22,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
|
|||
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
/**
|
||||
|
|
|
@ -19,17 +19,15 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessors;
|
|||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import lombok.val;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
|
|
|
@ -20,9 +20,9 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import lombok.val;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
|
|
|
@ -1,28 +1,15 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.utils;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.nd4j.validation.Nd4jCommonValidator;
|
||||
import org.nd4j.validation.ValidationResult;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipFile;
|
||||
|
||||
/**
|
||||
* A utility for validating serialized Keras sequential and functional models for import into DL4J
|
||||
|
|
|
@ -21,7 +21,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.activations.impl.*;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
|
|
@ -21,8 +21,7 @@ import org.deeplearning4j.nn.conf.distribution.*;
|
|||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.weights.*;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -42,7 +41,7 @@ public class KerasInitilizationUtils {
|
|||
* @return DL4J weight initialization enum
|
||||
* @see WeightInit
|
||||
*/
|
||||
public static Pair<WeightInit, Distribution> mapWeightInitialization(String kerasInit,
|
||||
public static IWeightInit mapWeightInitialization(String kerasInit,
|
||||
KerasLayerConfiguration conf,
|
||||
Map<String, Object> initConfig,
|
||||
int kerasMajorVersion)
|
||||
|
@ -50,68 +49,63 @@ public class KerasInitilizationUtils {
|
|||
|
||||
|
||||
// TODO: Identity and VarianceScaling need "scale" factor
|
||||
WeightInit init = null;
|
||||
Distribution dist = null;
|
||||
if (kerasInit != null) {
|
||||
if (kerasInit.equals(conf.getINIT_GLOROT_NORMAL()) ||
|
||||
kerasInit.equals(conf.getINIT_GLOROT_NORMAL_ALIAS())) {
|
||||
init = WeightInit.XAVIER;
|
||||
return WeightInit.XAVIER.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_GLOROT_UNIFORM()) ||
|
||||
kerasInit.equals(conf.getINIT_GLOROT_UNIFORM_ALIAS())) {
|
||||
init = WeightInit.XAVIER_UNIFORM;
|
||||
return WeightInit.XAVIER_UNIFORM.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_LECUN_NORMAL()) ||
|
||||
kerasInit.equals(conf.getINIT_LECUN_NORMAL_ALIAS())) {
|
||||
init = WeightInit.LECUN_NORMAL;
|
||||
return WeightInit.LECUN_NORMAL.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_LECUN_UNIFORM()) ||
|
||||
kerasInit.equals(conf.getINIT_LECUN_UNIFORM_ALIAS())) {
|
||||
init = WeightInit.LECUN_UNIFORM;
|
||||
return WeightInit.LECUN_UNIFORM.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_HE_NORMAL()) ||
|
||||
kerasInit.equals(conf.getINIT_HE_NORMAL_ALIAS())) {
|
||||
init = WeightInit.RELU;
|
||||
return WeightInit.RELU.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_HE_UNIFORM()) ||
|
||||
kerasInit.equals(conf.getINIT_HE_UNIFORM_ALIAS())) {
|
||||
init = WeightInit.RELU_UNIFORM;
|
||||
return WeightInit.RELU_UNIFORM.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_ONE()) ||
|
||||
kerasInit.equals(conf.getINIT_ONES()) ||
|
||||
kerasInit.equals(conf.getINIT_ONES_ALIAS())) {
|
||||
init = WeightInit.ONES;
|
||||
return WeightInit.ONES.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_ZERO()) ||
|
||||
kerasInit.equals(conf.getINIT_ZEROS()) ||
|
||||
kerasInit.equals(conf.getINIT_ZEROS_ALIAS())) {
|
||||
init = WeightInit.ZERO;
|
||||
return WeightInit.ZERO.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_UNIFORM()) ||
|
||||
kerasInit.equals(conf.getINIT_RANDOM_UNIFORM()) ||
|
||||
kerasInit.equals(conf.getINIT_RANDOM_UNIFORM_ALIAS())) {
|
||||
if (kerasMajorVersion == 2) {
|
||||
double minVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MINVAL());
|
||||
double maxVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MAXVAL());
|
||||
dist = new UniformDistribution(minVal, maxVal);
|
||||
return new WeightInitDistribution(new UniformDistribution(minVal, maxVal));
|
||||
} else {
|
||||
double scale = 0.05;
|
||||
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
|
||||
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
|
||||
dist = new UniformDistribution(-scale, scale);
|
||||
return new WeightInitDistribution(new UniformDistribution(-scale, scale));
|
||||
}
|
||||
init = WeightInit.DISTRIBUTION;
|
||||
} else if (kerasInit.equals(conf.getINIT_NORMAL()) ||
|
||||
kerasInit.equals(conf.getINIT_RANDOM_NORMAL()) ||
|
||||
kerasInit.equals(conf.getINIT_RANDOM_NORMAL_ALIAS())) {
|
||||
if (kerasMajorVersion == 2) {
|
||||
double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
|
||||
double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
|
||||
dist = new NormalDistribution(mean, stdDev);
|
||||
return new WeightInitDistribution(new NormalDistribution(mean, stdDev));
|
||||
} else {
|
||||
double scale = 0.05;
|
||||
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
|
||||
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
|
||||
dist = new NormalDistribution(0, scale);
|
||||
return new WeightInitDistribution(new NormalDistribution(0, scale));
|
||||
}
|
||||
init = WeightInit.DISTRIBUTION;
|
||||
} else if (kerasInit.equals(conf.getINIT_CONSTANT()) ||
|
||||
kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) {
|
||||
double value = (double) initConfig.get(conf.getLAYER_FIELD_INIT_VALUE());
|
||||
dist = new ConstantDistribution(value);
|
||||
init = WeightInit.DISTRIBUTION;
|
||||
return new WeightInitDistribution(new ConstantDistribution(value));
|
||||
} else if (kerasInit.equals(conf.getINIT_ORTHOGONAL()) ||
|
||||
kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) {
|
||||
if (kerasMajorVersion == 2) {
|
||||
|
@ -121,34 +115,38 @@ public class KerasInitilizationUtils {
|
|||
} catch (Exception e) {
|
||||
gain = (int) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
|
||||
}
|
||||
dist = new OrthogonalDistribution(gain);
|
||||
return new WeightInitDistribution(new OrthogonalDistribution(gain));
|
||||
} else {
|
||||
double scale = 1.1;
|
||||
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
|
||||
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
|
||||
dist = new OrthogonalDistribution(scale);
|
||||
return new WeightInitDistribution(new OrthogonalDistribution(scale));
|
||||
}
|
||||
init = WeightInit.DISTRIBUTION;
|
||||
} else if (kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL()) ||
|
||||
kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL_ALIAS())) {
|
||||
double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
|
||||
double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
|
||||
dist = new TruncatedNormalDistribution(mean, stdDev);
|
||||
init = WeightInit.DISTRIBUTION;
|
||||
return new WeightInitDistribution(new TruncatedNormalDistribution(mean, stdDev));
|
||||
} else if (kerasInit.equals(conf.getINIT_IDENTITY()) ||
|
||||
kerasInit.equals(conf.getINIT_IDENTITY_ALIAS())) {
|
||||
if (kerasMajorVersion == 2) {
|
||||
double gain = (double) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
|
||||
if (gain != 1.)
|
||||
log.warn("Scaled identity weight init not supported, setting gain=1");
|
||||
if (gain != 1.0)
|
||||
if (gain != 1.0) {
|
||||
return new WeightInitIdentity(gain);
|
||||
} else {
|
||||
return new WeightInitIdentity();
|
||||
}
|
||||
} else {
|
||||
double scale = 1.;
|
||||
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
|
||||
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
|
||||
if (scale != 1.)
|
||||
log.warn("Scaled identity weight init not supported, setting scale=1");
|
||||
if (scale != 1.0) {
|
||||
return new WeightInitIdentity(scale);
|
||||
} else {
|
||||
return new WeightInitIdentity();
|
||||
}
|
||||
}
|
||||
init = WeightInit.IDENTITY;
|
||||
} else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) {
|
||||
double scale;
|
||||
try {
|
||||
|
@ -156,32 +154,27 @@ public class KerasInitilizationUtils {
|
|||
} catch (Exception e) {
|
||||
scale = (int) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
|
||||
}
|
||||
if (scale != 1.)
|
||||
log.warn("Scaled identity weight init not supported, setting scale=1");
|
||||
String mode = (String) initConfig.get(conf.getLAYER_FIELD_INIT_MODE());
|
||||
String distribution = (String) initConfig.get(conf.getLAYER_FIELD_INIT_DISTRIBUTION());
|
||||
switch (mode) {
|
||||
case "fan_in":
|
||||
if (distribution.equals("normal")) {
|
||||
init = WeightInit.VAR_SCALING_NORMAL_FAN_IN;
|
||||
return new WeightInitVarScalingNormalFanIn(scale);
|
||||
} else {
|
||||
init = WeightInit.VAR_SCALING_UNIFORM_FAN_IN;
|
||||
return new WeightInitVarScalingUniformFanIn(scale);
|
||||
}
|
||||
break;
|
||||
case "fan_out":
|
||||
if (distribution.equals("normal")) {
|
||||
init = WeightInit.VAR_SCALING_NORMAL_FAN_OUT;
|
||||
return new WeightInitVarScalingNormalFanOut(scale);
|
||||
} else {
|
||||
init = WeightInit.VAR_SCALING_UNIFORM_FAN_OUT;
|
||||
return new WeightInitVarScalingUniformFanOut(scale);
|
||||
}
|
||||
break;
|
||||
case "fan_avg":
|
||||
if (distribution.equals("normal")) {
|
||||
init = WeightInit.VAR_SCALING_NORMAL_FAN_AVG;
|
||||
return new WeightInitVarScalingNormalFanAvg(scale);
|
||||
} else {
|
||||
init = WeightInit.VAR_SCALING_UNIFORM_FAN_AVG;
|
||||
return new WeightInitVarScalingUniformFanAvg(scale);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new InvalidKerasConfigurationException("Initialization argument 'mode' has to be either " +
|
||||
"fan_in, fan_out or fan_avg");
|
||||
|
@ -190,7 +183,7 @@ public class KerasInitilizationUtils {
|
|||
throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit);
|
||||
}
|
||||
}
|
||||
return new Pair<>(init, dist);
|
||||
throw new IllegalStateException("Error getting Keras weight initialization");
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -202,7 +195,7 @@ public class KerasInitilizationUtils {
|
|||
* @throws InvalidKerasConfigurationException Invalid Keras config
|
||||
* @throws UnsupportedKerasConfigurationException Unsupported Keras config
|
||||
*/
|
||||
public static Pair<WeightInit, Distribution> getWeightInitFromConfig(Map<String, Object> layerConfig, String initField,
|
||||
public static IWeightInit getWeightInitFromConfig(Map<String, Object> layerConfig, String initField,
|
||||
boolean enforceTrainingConfig,
|
||||
KerasLayerConfiguration conf,
|
||||
int kerasMajorVersion)
|
||||
|
@ -225,14 +218,14 @@ public class KerasInitilizationUtils {
|
|||
throw new UnsupportedKerasConfigurationException("Incomplete initialization class");
|
||||
}
|
||||
}
|
||||
Pair<WeightInit, Distribution> init;
|
||||
IWeightInit init;
|
||||
try {
|
||||
init = mapWeightInitialization(kerasInit, conf, initMap, kerasMajorVersion);
|
||||
} catch (UnsupportedKerasConfigurationException e) {
|
||||
if (enforceTrainingConfig)
|
||||
throw e;
|
||||
else {
|
||||
init = new Pair<>(WeightInit.XAVIER, null);
|
||||
init = new WeightInitXavier();
|
||||
log.warn("Unknown weight initializer " + kerasInit + " (Using XAVIER instead).");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
package org.deeplearning4j.nn.modelimport.keras;
|
||||
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.conf.layers.BaseLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
|
@ -25,7 +24,6 @@ import org.nd4j.linalg.learning.regularization.Regularization;
|
|||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
public class KerasTestUtils {
|
||||
|
|
|
@ -22,8 +22,6 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.linalg.util.Nd4jValidator;
|
||||
import org.nd4j.resources.Resources;
|
||||
import org.nd4j.validation.ValidationResult;
|
||||
|
||||
|
|
|
@ -21,7 +21,6 @@ import org.datavec.api.records.reader.SequenceRecordReader;
|
|||
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
||||
import org.datavec.api.split.NumberedFileInputSplit;
|
||||
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
|
||||
|
||||
import org.deeplearning4j.nn.layers.recurrent.LSTM;
|
||||
import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
|
@ -30,7 +29,6 @@ import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
|
|
|
@ -24,7 +24,6 @@ import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
|||
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.InputStream;
|
||||
|
|
|
@ -30,11 +30,9 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Arrays;
|
||||
|
|
|
@ -25,6 +25,8 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
|||
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDense;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInitIdentity;
|
||||
import org.deeplearning4j.nn.weights.WeightInitVarScalingNormalFanIn;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
@ -94,11 +96,11 @@ public class KerasInitilizationTest extends BaseDL4JTest {
|
|||
WeightInit.RELU_UNIFORM.getWeightInitFunction(),
|
||||
WeightInit.ONES.getWeightInitFunction(),
|
||||
WeightInit.ZERO.getWeightInitFunction(),
|
||||
WeightInit.IDENTITY.getWeightInitFunction(),
|
||||
new WeightInitIdentity(0.2),
|
||||
WeightInit.DISTRIBUTION.getWeightInitFunction(new NormalDistribution(mean, stdDev)),
|
||||
WeightInit.DISTRIBUTION.getWeightInitFunction(new OrthogonalDistribution(gain)),
|
||||
WeightInit.DISTRIBUTION.getWeightInitFunction(new ConstantDistribution(value)),
|
||||
WeightInit.VAR_SCALING_NORMAL_FAN_IN.getWeightInitFunction()};
|
||||
new WeightInitVarScalingNormalFanIn(0.2)};
|
||||
}
|
||||
|
||||
private Distribution[] dl4jDistributions() {
|
||||
|
|
|
@ -17,22 +17,16 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.configurations;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
|
|
|
@ -31,7 +31,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
|||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
|
|
|
@ -24,22 +24,19 @@ import org.deeplearning4j.eval.ROCMultiClass;
|
|||
import org.deeplearning4j.gradientcheck.GradientCheckUtil;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.layers.recurrent.LSTM;
|
||||
import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
|
||||
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.*;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
|
||||
import org.deeplearning4j.nn.transferlearning.TransferLearning;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
|
@ -47,27 +44,25 @@ import org.junit.rules.TemporaryFolder;
|
|||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.activations.impl.*;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.linalg.learning.config.NoOp;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.URL;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.*;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Unit tests for end-to-end Keras model import.
|
||||
|
|
|
@ -21,7 +21,6 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
|
|||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth;
|
||||
import org.deeplearning4j.nn.transferlearning.TransferLearning;
|
||||
|
@ -31,11 +30,8 @@ import org.junit.Test;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
|
||||
/**
|
||||
* Import previously stored YOLO9000 Keras net from https://github.com/allanzelener/YAD2K.
|
||||
|
|
|
@ -26,7 +26,6 @@ import org.junit.Ignore;
|
|||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
|
|
|
@ -27,16 +27,11 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasAtrousC
|
|||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
* @author Max Pumperla
|
||||
|
|
|
@ -28,9 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolu
|
|||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
|
@ -39,7 +36,6 @@ import java.util.Map;
|
|||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
* @author Max Pumperla
|
||||
|
|
|
@ -24,7 +24,6 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
|||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping1D;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
|
|
@ -16,13 +16,11 @@
|
|||
|
||||
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||
|
||||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
|
||||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping3D;
|
||||
import org.junit.Test;
|
||||
|
||||
|
|
|
@ -30,15 +30,11 @@ import org.deeplearning4j.nn.weights.IWeightInit;
|
|||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
* @author Max Pumperla
|
||||
|
|
|
@ -17,18 +17,14 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling1D;
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling1D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
|
|
@ -17,13 +17,11 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPadding1DLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding1D;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
|
|
@ -17,12 +17,10 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding3D;
|
||||
import org.junit.Test;
|
||||
|
||||
|
|
|
@ -26,16 +26,11 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
|||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
* @author Max Pumperla
|
||||
|
|
|
@ -24,10 +24,12 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
|
|
|
@ -24,11 +24,11 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
|
|
|
@ -26,11 +26,7 @@ import org.junit.Test;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ import org.deeplearning4j.nn.conf.ConvolutionMode;
|
|||
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.LocallyConnected1D;
|
||||
import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||
|
@ -31,10 +30,8 @@ import org.junit.Test;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
/**
|
||||
|
|
|
@ -27,15 +27,14 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
|||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
* @author Max Pumperla
|
||||
|
|
|
@ -19,7 +19,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.pooling;
|
|||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.layers.PoolingType;
|
||||
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
||||
|
|
|
@ -33,14 +33,13 @@ import org.deeplearning4j.nn.weights.IWeightInit;
|
|||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
* @author Max Pumperla
|
||||
|
|
|
@ -16,15 +16,12 @@
|
|||
|
||||
package org.deeplearning4j.nn.modelimport.keras.optimizers;
|
||||
|
||||
import org.deeplearning4j.config.DL4JSystemProperties;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.e2e.KerasModelEndToEndTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
|
||||
import org.deeplearning4j.util.DL4JFileUtils;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -32,8 +29,6 @@ import java.io.InputStream;
|
|||
import java.nio.file.Files;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
|
||||
import static java.io.File.createTempFile;
|
||||
|
||||
public class OptimizerImport extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -18,9 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.sequence;
|
|||
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.IOException;
|
||||
|
|
|
@ -19,15 +19,11 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.text;
|
|||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Import Keras Tokenizer
|
||||
|
|
|
@ -20,7 +20,6 @@ import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
|||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
|
|
@ -29,7 +29,6 @@ import org.junit.Test;
|
|||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
|
|
|
@ -77,71 +77,6 @@
|
|||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
<artifactId>protobuf-java</artifactId>
|
||||
<version>${google.protobuf.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>joda-time</groupId>
|
||||
<artifactId>joda-time</artifactId>
|
||||
<version>${jodatime.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>${commons-lang3.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.hibernate</groupId>
|
||||
<artifactId>hibernate-validator</artifactId>
|
||||
<version>${hibernate.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-library</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-reflect</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.yaml</groupId>
|
||||
<artifactId>snakeyaml</artifactId>
|
||||
<version>${snakeyaml.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-core</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-annotations</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.datatype</groupId>
|
||||
<artifactId>jackson-datatype-jdk8</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.datatype</groupId>
|
||||
<artifactId>jackson-datatype-jsr310</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.typesafe</groupId>
|
||||
<artifactId>config</artifactId>
|
||||
<version>${typesafe.config.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.typesafe.play</groupId>
|
||||
<artifactId>play-java_2.11</artifactId>
|
||||
|
|
|
@ -31,21 +31,6 @@
|
|||
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.apache.httpcomponents</groupId>
|
||||
<artifactId>httpclient</artifactId>
|
||||
<version>${httpclient.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.httpcomponents</groupId>
|
||||
<artifactId>httpcore</artifactId>
|
||||
<version>${httpcore.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.httpcomponents</groupId>
|
||||
<artifactId>httpmime</artifactId>
|
||||
<version>${httpmime.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.mashape.unirest</groupId>
|
||||
<artifactId>unirest-java</artifactId>
|
||||
|
|
|
@ -29,7 +29,7 @@ import static java.lang.Math.max;
|
|||
* QuadTree: <a href="http://en.wikipedia.org/wiki/Quadtree">http://en.wikipedia.org/wiki/Quadtree</a>
|
||||
*
|
||||
* Reference impl based on the paper by:
|
||||
* <a href="http://arxiv.org/pdf/1301.3342v2.pdf">http://arxiv.org/pdf/1301.3342v2.pdf</a>
|
||||
* <a href="https://arxiv.org/pdf/1301.3342v2.pdf">https://arxiv.org/pdf/1301.3342v2.pdf</a>
|
||||
*
|
||||
* Primarily focused on 2 dimensions, may expand later if there's a reason.
|
||||
*
|
||||
|
|
|
@ -86,7 +86,7 @@ public class MathUtils {
|
|||
|
||||
|
||||
/**
|
||||
* See: http://stackoverflow.com/questions/466204/rounding-off-to-nearest-power-of-2
|
||||
* See: https://stackoverflow.com/questions/466204/rounding-off-to-nearest-power-of-2
|
||||
* @param v the number to getFromOrigin the next power of 2 for
|
||||
* @return the next power of 2 for the passed in value
|
||||
*/
|
||||
|
|
|
@ -52,12 +52,6 @@
|
|||
<artifactId>deeplearning4j-nlp</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nutz</groupId>
|
||||
<artifactId>nutz</artifactId>
|
||||
<version>1.r.58</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nlpcn</groupId>
|
||||
<artifactId>nlp-lang</artifactId>
|
||||
|
|
|
@ -33,26 +33,6 @@
|
|||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>commons-logging</groupId>
|
||||
<artifactId>commons-logging</artifactId>
|
||||
<version>${commons-logging.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-core</artifactId>
|
||||
<version>${spring.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-context</artifactId>
|
||||
<version>${spring.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-beans</artifactId>
|
||||
<version>${spring.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.cleartk</groupId>
|
||||
<artifactId>cleartk-snowball</artifactId>
|
||||
|
|
|
@ -54,11 +54,6 @@
|
|||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.objenesis</groupId>
|
||||
<artifactId>objenesis</artifactId>
|
||||
<version>${objenesis.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
|
@ -66,16 +61,6 @@
|
|||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<!-- TSNE -->
|
||||
<!-- (Previously: dropwizard deps) -->
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-jackson</artifactId>
|
||||
<version>${nd4j.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>ch.qos.logback</groupId>
|
||||
<artifactId>logback-classic</artifactId>
|
||||
|
|
|
@ -42,7 +42,7 @@ import java.util.*;
|
|||
* Instead of rand walks, this walker produces walks based on number of edges coming into each node.
|
||||
* This allows you to build walks filtering too rare nodes, or too popular nodes, depending on your demands.
|
||||
*
|
||||
* Original DeepWalk paper: <a href="http://arxiv.org/pdf/1403.6652v2">http://arxiv.org/pdf/1403.6652v2</a>
|
||||
* Original DeepWalk paper: <a href="https://arxiv.org/pdf/1403.6652v2">https://arxiv.org/pdf/1403.6652v2</a>
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class PopularityWalker<T extends SequenceElement> extends RandomWalker<T> implements GraphWalker<T> {
|
||||
|
|
|
@ -37,7 +37,7 @@ import java.util.concurrent.atomic.AtomicInteger;
|
|||
/**
|
||||
* This is Random-based walker for SequenceVectors-based DeepWalk implementation
|
||||
*
|
||||
* Original DeepWalk paper: <a href="http://arxiv.org/pdf/1403.6652v2">http://arxiv.org/pdf/1403.6652v2</a>
|
||||
* Original DeepWalk paper: <a href="https://arxiv.org/pdf/1403.6652v2">https://arxiv.org/pdf/1403.6652v2</a>
|
||||
*
|
||||
* @author AlexDBlack
|
||||
* @author raver119@gmail.com
|
||||
|
|
|
@ -52,7 +52,7 @@ package org.deeplearning4j.nn.conf;
|
|||
* </ul>
|
||||
* Thus, the l2 norm of the scaled gradients will not exceed the specified threshold, though may be smaller than it<br>
|
||||
* See: Pascanu, Mikolov, Bengio (2012), <i>On the difficulty of training Recurrent Neural Networks</i>,
|
||||
* <a href="http://arxiv.org/abs/1211.5063">http://arxiv.org/abs/1211.5063</a><br>
|
||||
* <a href="https://arxiv.org/abs/1211.5063">https://arxiv.org/abs/1211.5063</a><br>
|
||||
* Threshold for clipping can be set in Layer configuration, using gradientNormalizationThreshold(double threshold)
|
||||
* </p>
|
||||
*
|
||||
|
|
|
@ -23,7 +23,7 @@ import org.nd4j.shade.jackson.annotation.JsonProperty;
|
|||
|
||||
/**
|
||||
* Orthogonal distribution, with gain parameter.<br>
|
||||
* See <a href="http://arxiv.org/abs/1312.6120">http://arxiv.org/abs/1312.6120</a> for details
|
||||
* See <a href="https://arxiv.org/abs/1312.6120">https://arxiv.org/abs/1312.6120</a> for details
|
||||
*
|
||||
*/
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue