Fixes and pre-release QA (#51)

* #8395 Keras import - support scaled identity weight init

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More Keras scaled weight init fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8352 Deprecate duplicate SamplingDataSetIterator class

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Remove /O2 optimization flag to speed up the CUDA build

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tweak regression test precision for CUDA

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix edge cases for buffer creation

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Update MKLDNN validation tests to new helper enable/disable settings

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Delete debugging class

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* MKLDNN test - add proper skip for CUDA backend

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Align WeightInitUtil with weight init classes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for SameDiff test layers weight init when using IWeightInit classes

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
Alex Black 2019-11-16 17:04:29 +11:00, committed by GitHub
Parent: 1780dcc883
Commit: 09a827fb6d
85 changed files with 378 additions and 574 deletions
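For context on the headline fix (#8395): Keras's identity initializer carries a gain/scale factor, which the import path previously could not represent (see the TODO retained in KerasInitilizationUtils at the end of this diff). A minimal, hedged sketch of the import call that exercises this path; the file name is illustrative, while `importKerasSequentialModelAndWeights` is the real entry point:

```java
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public class ImportScaledIdentity {
    public static void main(String[] args) throws Exception {
        // A Keras model saved with, e.g., Identity(gain=0.7) on a layer;
        // after #8395 the scale survives the mapping to DL4J's identity weight init.
        MultiLayerNetwork net = KerasModelImport.importKerasSequentialModelAndWeights("model.h5");
        System.out.println(net.summary());
    }
}
```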


@ -35,6 +35,7 @@ import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;
import static org.junit.Assert.*;
@ -63,6 +64,30 @@ public class LayerHelperValidationUtil {
private DataSetIterator data;
}
public static void disableCppHelpers(){
try {
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
Method m = c.getMethod("getInstance");
Object instance = m.invoke(null);
Method m2 = c.getMethod("allowHelpers", boolean.class);
m2.invoke(instance, false);
} catch (Throwable t){
throw new RuntimeException(t);
}
}
public static void enableCppHelpers(){
try{
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
Method m = c.getMethod("getInstance");
Object instance = m.invoke(null);
Method m2 = c.getMethod("allowHelpers", boolean.class);
m2.invoke(instance, true);
} catch (Throwable t){
throw new RuntimeException(t);
}
}
public static void validateMLN(MultiLayerNetwork netOrig, TestCase t){
assertNotNull(t.getAllowHelpersForClasses());
assertFalse(t.getAllowHelpersForClasses().isEmpty());
@ -95,7 +120,13 @@ public class LayerHelperValidationUtil {
for (boolean train : new boolean[]{false, true}) {
assertEquals(net1NoHelper.params(), net2With.params());
String s = "Feed forward test - " + t.getTestName() + " - " + (train ? "Train: " : "Test: ");
-List<INDArray> ff1 = net1NoHelper.feedForward(t.getFeatures(), train);
+List<INDArray> ff1;
+try {
+    disableCppHelpers();
+    ff1 = net1NoHelper.feedForward(t.getFeatures(), train);
+} finally {
+    enableCppHelpers();
+}
List<INDArray> ff2 = net2With.feedForward(t.getFeatures(), train);
List<String> paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet());
Collections.sort(paramKeys);
@ -131,7 +162,13 @@ public class LayerHelperValidationUtil {
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
}
-INDArray out1 = net1NoHelper.output(t.getFeatures(), train);
+INDArray out1;
+try {
+    disableCppHelpers();
+    out1 = net1NoHelper.output(t.getFeatures(), train);
+} finally {
+    enableCppHelpers();
+}
INDArray out2 = net2With.output(t.getFeatures(), train);
INDArray relError = relError(out1, out2, t.getMinAbsError());
double maxRE = relError.maxNumber().doubleValue();
@ -148,7 +185,13 @@ public class LayerHelperValidationUtil {
Preconditions.checkNotNull(t.getLabels(), "Labels are not set (null)");
log.info("Validation - checking scores");
-double s1 = net1NoHelper.score(new DataSet(t.getFeatures(), t.getLabels()));
+double s1;
+try {
+    disableCppHelpers();
+    s1 = net1NoHelper.score(new DataSet(t.getFeatures(), t.getLabels()));
+} finally {
+    enableCppHelpers();
+}
double s2 = net2With.score(new DataSet(t.getFeatures(), t.getLabels()));
double re = relError(s1, s2);
@ -168,7 +211,12 @@ public class LayerHelperValidationUtil {
net2With.setInput(t.getFeatures());
net2With.setLabels(t.getLabels());
+try {
+    disableCppHelpers();
net1NoHelper.computeGradientAndScore();
+} finally {
+    enableCppHelpers();
+}
net2With.computeGradientAndScore();
List<String> paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet());
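A note on the pattern above: every no-helper pass is bracketed by `disableCppHelpers()`/`enableCppHelpers()` in a try/finally, so an assertion failure or exception cannot leave the native helpers globally disabled for later tests. A minimal sketch of the idiom, reusing the reflection-based toggles added in this file (the wrapper method itself is hypothetical):

```java
import java.util.List;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

public class NoHelperForward {
    // Hypothetical convenience wrapper around the idiom used in validateMLN.
    public static List<INDArray> feedForwardWithoutHelpers(MultiLayerNetwork net,
                                                           INDArray features, boolean train) {
        try {
            // Nd4jCpu$Environment.allowHelpers(false), via reflection
            LayerHelperValidationUtil.disableCppHelpers();
            return net.feedForward(features, train);
        } finally {
            // Always restore global helper state, even if the forward pass throws
            LayerHelperValidationUtil.enableCppHelpers();
        }
    }
}
```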


@ -1,107 +0,0 @@
package org.deeplearning4j;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.lang.reflect.Field;
import static junit.framework.TestCase.*;
public class TestBatchNormBp {
@Test
public void test(){
Nd4j.getRandom().setSeed(12345);
// INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 4, 4);
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
INDArray mean = in.mean(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
INDArray var = in.var(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
// INDArray gamma = Nd4j.ones(DataType.FLOAT, 3);
// INDArray beta = Nd4j.zeros(DataType.FLOAT, 3);
INDArray gamma = Nd4j.rand(DataType.FLOAT, 3);
INDArray beta = Nd4j.rand(DataType.FLOAT, 3);
double e = 1e-5;
INDArray dLdIn = in.ulike();
INDArray dLdm = mean.ulike();
INDArray dLdv = var.ulike();
INDArray dLdg = gamma.ulike();
INDArray dLdb = beta.ulike();
DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
.addInputs(in, mean, var, eps, gamma, beta)
.addIntegerArguments(
1, //Apply scale
1, //Apply beta
1) //Axis (NCHW)
.addFloatingPointArguments(e)
.addOutputs(dLdIn, dLdm, dLdv, dLdg, dLdb)
.build();
Nd4j.exec(op);
System.out.println(dLdIn);
}
@Test
public void compareImpls() throws Exception {
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
INDArray mean = in.mean(0, 2, 3).reshape(1,3);
INDArray var = in.var(0, 2, 3).reshape(1,3);
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
double e = 1e-3;
INDArray dLdIn = in.ulike();
INDArray dLdm = mean.ulike();
INDArray dLdv = var.ulike();
INDArray dLdg = gamma.ulike();
INDArray dLdb = beta.ulike();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.inferenceWorkspaceMode(WorkspaceMode.NONE)
.trainingWorkspaceMode(WorkspaceMode.NONE)
.list()
.layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
assertNotNull(bn.getHelper());
Field f = bn.getClass().getDeclaredField("helper");
f.setAccessible(true);
f.set(bn, null);
assertNull(bn.getHelper());
MKLDNNBatchNormHelper h = new MKLDNNBatchNormHelper(DataType.FLOAT);
net.output(in, true);
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
h.preOutput(in, true, new long[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient,INDArray> pmkl = h.backpropGradient(in, eps, new long[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces());
INDArray dldin_dl4j = p.getSecond();
System.out.println("dl4j == mkldnn: " + p.getSecond().equals(pmkl.getSecond()));
}
}


@ -70,9 +70,20 @@ public class MinimalSameDiffDense extends SameDiffLayer {
@Override
public void initializeParameters(Map<String, INDArray> params) {
String b = DefaultParamInitializer.BIAS_KEY;
if(paramWeightInit != null && paramWeightInit.containsKey(b)){
paramWeightInit.get(b).init(nIn, nOut, params.get(b).shape(), 'c', params.get(b));
} else {
params.get(DefaultParamInitializer.BIAS_KEY).assign(0);
}
String w = DefaultParamInitializer.WEIGHT_KEY;
if(paramWeightInit != null && paramWeightInit.containsKey(w)){
paramWeightInit.get(w).init(nIn, nOut, params.get(w).shape(), 'c', params.get(w));
} else {
initWeights(nIn, nOut, weightInit, params.get(DefaultParamInitializer.WEIGHT_KEY));
}
}
//OPTIONAL methods:
// public void setNIn(InputType inputType, boolean override)
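The branch added above is the general contract for these SameDiff test layers: if a per-parameter `IWeightInit` was supplied, it initializes the parameter array in place; otherwise the old default applies. A small self-contained sketch of the `IWeightInit.init` call signature used there (shapes and initializer choice are arbitrary example values):

```java
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class WeightInitInPlace {
    public static void main(String[] args) {
        IWeightInit wi = new WeightInitXavier();   // the IWeightInit behind WeightInit.XAVIER
        INDArray w = Nd4j.createUninitialized(DataType.FLOAT, 4, 8);
        // Arguments: fanIn, fanOut, shape, order, target view; fills w in place
        wi.init(4, 8, new long[]{4, 8}, 'c', w);
        System.out.println(w);
    }
}
```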


@ -109,17 +109,21 @@ public class SameDiffConv extends SameDiffLayer {
@Override
public void initializeParameters(Map<String, INDArray> params) {
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
+double fanIn = nIn * kernel[0] * kernel[1];
+double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);
for (Map.Entry<String, INDArray> e : params.entrySet()) {
+if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){
+paramWeightInit.get(e.getKey()).init(fanIn, fanOut, e.getValue().shape(), 'c', e.getValue());
+} else {
if (ConvolutionParamInitializer.BIAS_KEY.equals(e.getKey())) {
e.getValue().assign(0);
} else {
-double fanIn = nIn * kernel[0] * kernel[1];
-double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);
WeightInitUtil.initWeights(fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
}
+}
}
}
}
@Override
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
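For concreteness, the fan computation that now happens once per layer rather than inside the per-weight-array branch: with example values `nIn = 3`, `nOut = 8`, a 3x3 kernel and stride 1, `fanIn = 3*3*3 = 27` and `fanOut = 8*3*3 / (1*1) = 72`. A tiny check of that arithmetic:

```java
public class FanCalc {
    public static void main(String[] args) {
        int nIn = 3, nOut = 8;
        int[] kernel = {3, 3};
        int[] stride = {1, 1};
        double fanIn = nIn * kernel[0] * kernel[1];                                      // 27.0
        double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]); // 72.0
        System.out.println(fanIn + " " + fanOut);
    }
}
```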


@ -88,6 +88,9 @@ public class SameDiffDense extends SameDiffLayer {
@Override
public void initializeParameters(Map<String,INDArray> params){
for(Map.Entry<String,INDArray> e : params.entrySet()){
if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){
paramWeightInit.get(e.getKey()).init(nIn, nOut, e.getValue().shape(), 'c', e.getValue());
} else {
if(DefaultParamInitializer.BIAS_KEY.equals(e.getKey())){
e.getValue().assign(0.0);
} else {
@ -96,6 +99,7 @@ public class SameDiffDense extends SameDiffLayer {
}
}
}
}
@Override
public SDVariable defineLayer(SameDiff sd, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {


@ -50,6 +50,7 @@ import static org.junit.Assume.assumeTrue;
public class ValidateMKLDNN extends BaseDL4JTest {
@Test
public void validateConvSubsampling() throws Exception {
//Only run test if using nd4j-native backend
@ -268,6 +269,7 @@ public class ValidateMKLDNN extends BaseDL4JTest {
@Test
public void compareBatchNormBackward() throws Exception {
assumeTrue(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native"));
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
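The skip added above uses JUnit 4's assumption mechanism: on any backend whose class name does not contain "native" (i.e. CUDA), the test is marked skipped rather than failed, since MKL-DNN helpers exist only on the CPU backend. The same check as a reusable guard (a sketch; the helper method name is hypothetical):

```java
import static org.junit.Assume.assumeTrue;

import org.nd4j.linalg.factory.Nd4j;

public class BackendGuards {
    // Hypothetical helper: call at the top of a test that only makes sense on nd4j-native.
    public static void skipUnlessCpuBackend() {
        assumeTrue(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native"));
    }
}
```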


@ -339,7 +339,13 @@ public class RegressionTest100b4 extends BaseDL4JTest {
INDArray outAct = net.output(in);
//19 layers - CPU vs. GPU difference accumulates notably, but appears to be correct
if(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native")){
assertEquals(outExp, outAct);
} else {
boolean eq = outExp.equalsWithEps(outAct, 0.1);
assertTrue(eq);
}
}
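This is the "Tweak regression test precision for CUDA" item from the commit message: exact equality is kept on the CPU backend, while CUDA gets an epsilon comparison because small floating-point differences accumulate across the 19 layers. The same idiom as a reusable assertion (a sketch; the helper name is hypothetical):

```java
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class BackendAwareAsserts {
    public static void assertEqualsWithBackendTolerance(INDArray expected, INDArray actual, double cudaEps) {
        boolean cpu = Nd4j.getBackend().getClass().getName().toLowerCase().contains("native");
        if (cpu) {
            assertEquals(expected, actual);                       // exact match on nd4j-native
        } else {
            assertTrue(expected.equalsWithEps(actual, cudaEps));  // tolerate accumulated FP drift on CUDA
        }
    }
}
```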
@Test


@ -24,101 +24,11 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.util.List;
/**
 * A wrapper for a dataset to sample from.
 * This will randomly sample from the given dataset.
 * @author Adam Gibson
 */
-public class SamplingDataSetIterator implements DataSetIterator {
-/**
- *
- */
-private static final long serialVersionUID = -2700563801361726914L;
-private DataSet sampleFrom;
-private int batchSize;
-private int totalNumberSamples;
-private int numTimesSampled;
-@Getter
-private DataSetPreProcessor preProcessor;
/**
 *
 * @param sampleFrom the dataset to sample from
 * @param batchSize the batch size to sample
 * @param totalNumberSamples the sample size
+ * @deprecated Use {@link org.nd4j.linalg.dataset.api.iterator.SamplingDataSetIterator}
 */
+@Deprecated
+public class SamplingDataSetIterator extends org.nd4j.linalg.dataset.api.iterator.SamplingDataSetIterator {
public SamplingDataSetIterator(DataSet sampleFrom, int batchSize, int totalNumberSamples) {
-super();
-this.sampleFrom = sampleFrom;
-this.batchSize = batchSize;
-this.totalNumberSamples = totalNumberSamples;
+super(sampleFrom, batchSize, totalNumberSamples);
}
-@Override
-public boolean hasNext() {
-return numTimesSampled < totalNumberSamples;
-}
-@Override
-public DataSet next() {
-DataSet ret = sampleFrom.sample(batchSize);
-numTimesSampled += batchSize;
-return ret;
-}
-@Override
-public void remove() {
-throw new UnsupportedOperationException();
-}
-@Override
-public int inputColumns() {
-return sampleFrom.numInputs();
-}
-@Override
-public int totalOutcomes() {
-return sampleFrom.numOutcomes();
-}
-@Override
-public boolean resetSupported() {
-return true;
-}
-@Override
-public boolean asyncSupported() {
-return true;
-}
-@Override
-public void reset() {
-numTimesSampled = 0;
-}
-@Override
-public int batch() {
-return batchSize;
-}
-@Override
-public void setPreProcessor(DataSetPreProcessor preProcessor) {
-this.preProcessor = preProcessor;
-}
-@Override
-public List<String> getLabels() {
-return null;
-}
-@Override
-public DataSet next(int num) {
-DataSet ret = sampleFrom.sample(num);
-numTimesSampled++;
-return ret;
-}
}
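Migration for this deprecation is mechanical: the nd4j class takes the same constructor arguments, so only the import changes. A self-contained sketch with toy data:

```java
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

public class SamplingIteratorMigration {
    public static void main(String[] args) {
        // 100 random examples, 4 features, 2 labels: toy data for illustration
        DataSet sampleFrom = new DataSet(Nd4j.rand(DataType.FLOAT, 100, 4),
                                         Nd4j.rand(DataType.FLOAT, 100, 2));
        // Replacement for the deprecated DL4J class: same (dataset, batchSize, totalSamples) args
        DataSetIterator iter =
                new org.nd4j.linalg.dataset.api.iterator.SamplingDataSetIterator(sampleFrom, 32, 320);
        while (iter.hasNext()) {
            DataSet batch = iter.next();  // each call randomly samples a mini-batch of 32
        }
    }
}
```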


@ -17,6 +17,7 @@
package org.deeplearning4j.nn.modelimport.keras;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.hdf5.*;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Loader;
@ -32,7 +33,6 @@ import java.lang.Exception;
import java.util.ArrayList;
import java.util.List;
import org.bytedeco.hdf5.*;
import static org.bytedeco.hdf5.global.hdf5.*;
/**


@ -17,7 +17,6 @@
package org.deeplearning4j.nn.modelimport.keras;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;


@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.PReLULayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -27,9 +26,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.PReLUParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.HashMap;
@ -79,14 +77,12 @@ public class KerasPReLU extends KerasLayer {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, ALPHA_CONSTRAINT, conf, kerasMajorVersion);
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, ALPHA_INIT,
+IWeightInit init = getWeightInitFromConfig(layerConfig, ALPHA_INIT,
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
-Distribution distribution = init.getSecond();
long[] axes = getSharedAxes(layerConfig);
PReLULayer.Builder builder = new PReLULayer.Builder().sharedAxes(axes)
-.weightInit(weightInit.getWeightInitFunction(distribution)).name(layerName);
+.weightInit(init).name(layerName);
if (weightConstraint != null){
builder.constrainWeights(weightConstraint);
}

View File

@ -17,14 +17,12 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.weights.IWeightInit;
import java.util.Map;
@ -83,15 +81,13 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
+IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
-Distribution distribution = init.getSecond();
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
-.weightInit(weightInit.getWeightInitFunction(distribution))
+.weightInit(init)
.dilation(getDilationRate(layerConfig, 1, conf, true)[0])
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))

View File

@ -17,14 +17,12 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.weights.IWeightInit;
import java.util.Map;
@ -84,14 +82,13 @@ public class KerasAtrousConvolution2D extends KerasConvolution {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
+IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
-.weightInit(weightInit.getWeightInitFunction())
+.weightInit(init)
.dilation(getDilationRate(layerConfig, 2, conf, true))
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))


@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
@ -30,7 +29,6 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.removeDefaultWeights;


@ -22,7 +22,6 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
@ -30,10 +29,9 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap;
import java.util.Map;
@ -94,15 +92,13 @@ public class KerasConvolution1D extends KerasConvolution {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
+IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
-Distribution distribution = init.getSecond();
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
-.weightInit(weightInit.getWeightInitFunction(distribution))
+.weightInit(init)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])


@ -21,14 +21,12 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.weights.IWeightInit;
import java.util.Map;
@ -87,10 +85,8 @@ public class KerasConvolution2D extends KerasConvolution {
numTrainableParams = hasBias ? 2 : 1;
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
+IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
-Distribution distribution = init.getSecond();
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
@ -100,7 +96,7 @@ public class KerasConvolution2D extends KerasConvolution {
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
-.weightInit(weightInit.getWeightInitFunction(distribution))
+.weightInit(init)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))


@ -21,15 +21,13 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.weights.IWeightInit;
import java.util.Map;
@ -88,10 +86,8 @@ public class KerasConvolution3D extends KerasConvolution {
numTrainableParams = hasBias ? 2 : 1;
int[] dilationRate = getDilationRate(layerConfig, 3, conf, false);
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
+IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
-Distribution distribution = init.getSecond();
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
@ -101,7 +97,7 @@ public class KerasConvolution3D extends KerasConvolution {
Convolution3D.Builder builder = new Convolution3D.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
-.weightInit(weightInit.getWeightInitFunction(distribution))
+.weightInit(init)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 3, conf, kerasMajorVersion))

View File

@ -20,14 +20,12 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.weights.IWeightInit;
import java.util.Map;
@ -86,10 +84,8 @@ public class KerasDeconvolution2D extends KerasConvolution {
numTrainableParams = hasBias ? 2 : 1;
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
+IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
-Distribution distribution = init.getSecond();
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
@ -99,7 +95,7 @@ public class KerasDeconvolution2D extends KerasConvolution {
Deconvolution2D.Builder builder = new Deconvolution2D.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
-.weightInit(weightInit.getWeightInitFunction(distribution))
+.weightInit(init)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))


@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -30,9 +29,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils;
import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.Collections;
import java.util.HashMap;
@ -126,10 +124,8 @@ public class KerasDepthwiseConvolution2D extends KerasConvolution {
numTrainableParams = hasBias ? 2 : 1;
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
-Pair<WeightInit, Distribution> depthWiseInit = getWeightInitFromConfig(layerConfig,
+IWeightInit depthWiseInit = getWeightInitFromConfig(layerConfig,
conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit depthWeightInit = depthWiseInit.getFirst();
-Distribution depthDistribution = depthWiseInit.getSecond();
val nIn = getNInFromConfig(previousLayers);
@ -152,7 +148,7 @@ public class KerasDepthwiseConvolution2D extends KerasConvolution {
.nIn(nIn)
.nOut(nIn * depthMultiplier)
.activation(getIActivationFromConfig(layerConfig, conf))
-.weightInit(depthWeightInit.getWeightInitFunction(depthDistribution))
+.weightInit(depthWiseInit)
.depthMultiplier(depthMultiplier)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))


@ -20,7 +20,6 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
@ -28,9 +27,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils;
import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap;
import java.util.Map;
@ -93,17 +91,13 @@ public class KerasSeparableConvolution2D extends KerasConvolution {
int depthMultiplier = getDepthMultiplier(layerConfig, conf);
-Pair<WeightInit, Distribution> depthWiseInit = getWeightInitFromConfig(layerConfig,
+IWeightInit depthWiseInit = getWeightInitFromConfig(layerConfig,
conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit depthWeightInit = depthWiseInit.getFirst();
-Distribution depthDistribution = depthWiseInit.getSecond();
-Pair<WeightInit, Distribution> pointWiseInit = getWeightInitFromConfig(layerConfig,
+IWeightInit pointWiseInit = getWeightInitFromConfig(layerConfig,
conf.getLAYER_FIELD_POINT_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit pointWeightInit = pointWiseInit.getFirst();
-Distribution pointDistribution = pointWiseInit.getSecond();
-if (depthWeightInit != pointWeightInit || depthDistribution != pointDistribution)
+if ( !depthWiseInit.getClass().equals(pointWiseInit.getClass()) )
if (enforceTrainingConfig)
throw new UnsupportedKerasConfigurationException(
"Specifying different initialization for depth- and point-wise weights not supported.");
@ -126,7 +120,7 @@ public class KerasSeparableConvolution2D extends KerasConvolution {
SeparableConvolution2D.Builder builder = new SeparableConvolution2D.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
-.weightInit(depthWeightInit.getWeightInitFunction(depthDistribution))
+.weightInit(depthWiseInit)
.depthMultiplier(depthMultiplier)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))


@ -17,7 +17,6 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.conf.layers.Upsampling3D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;


@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;


@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap;
import java.util.Map;
@ -95,15 +93,13 @@ public class KerasDense extends KerasLayer {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
+IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
-Distribution distribution = init.getSecond();
DenseLayer.Builder builder = new DenseLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf))
.dropOut(this.dropout).activation(getIActivationFromConfig(layerConfig, conf))
-.weightInit(weightInit.getWeightInitFunction(distribution))
+.weightInit(init)
.biasInit(0.0)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.hasBias(hasBias);
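The pattern repeated across these Keras import layers: `getWeightInitFromConfig` now returns a ready-to-use `IWeightInit`, so the enum-plus-distribution bookkeeping (`init.getFirst()`, `init.getSecond()`, `getWeightInitFunction(distribution)`) disappears and the builder takes the initializer directly. The locally-connected layers further down use a per-parameter overload, `weightInit(String paramKey, IWeightInit)`, instead. The two ways to obtain an `IWeightInit`, as a sketch with arbitrary example values:

```java
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;

public class WeightInitBridge {
    public static void main(String[] args) {
        // Enum-backed: the old WeightInit values bridge to the new interface
        IWeightInit xavier = WeightInit.XAVIER.getWeightInitFunction();
        // Distribution-backed: replaces the old WeightInit.DISTRIBUTION + Distribution pair
        IWeightInit fromDist = new WeightInitDistribution(new NormalDistribution(0, 0.05));
        DenseLayer layer = new DenseLayer.Builder().nIn(10).nOut(5)
                .weightInit(fromDist)   // builders now accept IWeightInit directly
                .build();
    }
}
```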


@ -22,7 +22,6 @@ import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;


@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;


@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -26,7 +25,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.List;
import java.util.Map;


@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -30,11 +29,10 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap;
import java.util.Map;
@ -106,10 +104,8 @@ public class KerasEmbedding extends KerasLayer {
"in DL4J, apply masking as a pre-processing step to your input." +
"See http://deeplearning4j.org/docs/latest/deeplearning4j-nn-recurrent#masking for more on this.");
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(),
+IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
-Distribution distribution = init.getSecond();
LayerConstraint embeddingConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_EMBEDDINGS_CONSTRAINT(), conf, kerasMajorVersion);
@ -121,7 +117,7 @@ public class KerasEmbedding extends KerasLayer {
.inferInputLength(inferInputLength)
.nOut(getNOutFromConfig(layerConfig, conf))
.dropOut(this.dropout).activation(Activation.IDENTITY)
-.weightInit(weightInit.getWeightInitFunction(distribution))
+.weightInit(init)
.biasInit(0.0)
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization)


@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LocallyConnected1D;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap;
import java.util.Map;
@ -90,11 +88,8 @@ public class KerasLocallyConnected1D extends KerasConvolution {
numTrainableParams = hasBias ? 2 : 1;
int[] dilationRate = getDilationRate(layerConfig, 1, conf, false);
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
+IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
-// TODO: take care of distribution and bias init
-//Distribution distribution = init.getSecond();
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
@ -104,7 +99,7 @@ public class KerasLocallyConnected1D extends KerasConvolution {
LocallyConnected1D.Builder builder = new LocallyConnected1D.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getActivationFromConfig(layerConfig, conf))
-.weightInit(weightInit)
+.weightInit(conf.getKERAS_PARAM_NAME_W(), init)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])


@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap;
import java.util.Map;
@ -39,9 +37,7 @@ import java.util.Map;
import static org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolutionUtils.*;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasActivationUtils.getActivationFromConfig;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils.getWeightInitFromConfig;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getHasBiasFromConfig;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getNOutFromConfig;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.removeDefaultWeights;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.*;
/**
@ -92,11 +88,9 @@ public class KerasLocallyConnected2D extends KerasConvolution {
numTrainableParams = hasBias ? 2 : 1;
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
+IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
-// TODO: take care of distribution and bias init
-//Distribution distribution = init.getSecond();
+// TODO: take care of bias init
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
@ -106,7 +100,7 @@ public class KerasLocallyConnected2D extends KerasConvolution {
LocallyConnected2D.Builder builder = new LocallyConnected2D.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getActivationFromConfig(layerConfig, conf))
-.weightInit(weightInit)
+.weightInit(conf.getKERAS_PARAM_NAME_W(), init)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))

View File

@ -31,7 +31,6 @@ import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;


@ -22,7 +22,6 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.LSTM;
@ -35,7 +34,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -151,15 +150,11 @@ public class KerasLSTM extends KerasLayer {
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
super(layerConfig, enforceTrainingConfig);
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
+IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
-Distribution distribution = init.getSecond();
-Pair<WeightInit, Distribution> recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
+IWeightInit recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit recurrentWeightInit = recurrentInit.getFirst();
-Distribution recurrentDistribution = recurrentInit.getSecond();
boolean hasBias = getHasBiasFromConfig(layerConfig, conf);
@ -186,8 +181,8 @@ public class KerasLSTM extends KerasLayer {
.nOut(getNOutFromConfig(layerConfig, conf))
.dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
-.weightInit(weightInit.getWeightInitFunction(distribution))
-.weightInitRecurrent(recurrentWeightInit.getWeightInitFunction(recurrentDistribution))
+.weightInit(init)
+.weightInitRecurrent(recurrentInit)
.biasInit(0.0) // TODO: this is incorrect
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization);
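Recurrent layers get the same treatment twice over: one `IWeightInit` for the input weights and one for the recurrent weights, passed to `weightInit(...)` and `weightInitRecurrent(...)` respectively. A sketch of the resulting builder usage (sizes and initializer choices are arbitrary example values):

```java
import org.deeplearning4j.nn.conf.distribution.OrthogonalDistribution;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;

public class RecurrentInitExample {
    public static void main(String[] args) {
        LSTM lstm = new LSTM.Builder()
                .nIn(32).nOut(64)
                .weightInit(WeightInit.XAVIER.getWeightInitFunction())  // input weights
                .weightInitRecurrent(new WeightInitDistribution(
                        new OrthogonalDistribution(1.0)))               // recurrent weights
                .build();
    }
}
```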


@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.Layer;
@ -34,7 +33,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
@ -124,15 +123,11 @@ public class KerasSimpleRnn extends KerasLayer {
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
super(layerConfig, enforceTrainingConfig);
-Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
+IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit weightInit = init.getFirst();
-Distribution distribution = init.getSecond();
-Pair<WeightInit, Distribution> recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
+IWeightInit recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
-WeightInit recurrentWeightInit = recurrentInit.getFirst();
-Distribution recurrentDistribution = recurrentInit.getSecond();
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
this.returnSequences = (Boolean) innerConfig.get(conf.getLAYER_FIELD_RETURN_SEQUENCES());
@ -154,8 +149,8 @@ public class KerasSimpleRnn extends KerasLayer {
.nOut(getNOutFromConfig(layerConfig, conf))
.dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
-.weightInit(weightInit.getWeightInitFunction(distribution))
-.weightInitRecurrent(recurrentWeightInit.getWeightInitFunction(recurrentDistribution))
+.weightInit(init)
+.weightInitRecurrent(recurrentInit)
.biasInit(0.0)
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization);


@ -20,9 +20,7 @@ import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import lombok.Data;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
@ -31,7 +29,6 @@ import org.nd4j.linalg.primitives.Pair;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.List;
import java.util.Map;


@ -22,9 +22,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonProperty;
/**


@ -19,17 +19,15 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessors;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;


@ -20,9 +20,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;


@ -1,28 +1,15 @@
package org.deeplearning4j.nn.modelimport.keras.utils;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.validation.Nd4jCommonValidator;
import org.nd4j.validation.ValidationResult;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
/**
* A utility for validating serialized Keras sequential and functional models for import into DL4J


@ -21,7 +21,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.*;
import java.util.Map;


@ -21,8 +21,7 @@ import org.deeplearning4j.nn.conf.distribution.*;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.weights.*;
import java.util.HashMap;
import java.util.Map;
@ -42,7 +41,7 @@ public class KerasInitilizationUtils {
* @return DL4J weight initialization enum
* @see WeightInit
*/
-public static Pair<WeightInit, Distribution> mapWeightInitialization(String kerasInit,
+public static IWeightInit mapWeightInitialization(String kerasInit,
KerasLayerConfiguration conf,
Map<String, Object> initConfig,
int kerasMajorVersion)
@ -50,68 +49,63 @@ public class KerasInitilizationUtils {
// TODO: Identity and VarianceScaling need "scale" factor
-WeightInit init = null;
-Distribution dist = null;
if (kerasInit != null) {
if (kerasInit.equals(conf.getINIT_GLOROT_NORMAL()) ||
kerasInit.equals(conf.getINIT_GLOROT_NORMAL_ALIAS())) {
-init = WeightInit.XAVIER;
+return WeightInit.XAVIER.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_GLOROT_UNIFORM()) ||
kerasInit.equals(conf.getINIT_GLOROT_UNIFORM_ALIAS())) {
-init = WeightInit.XAVIER_UNIFORM;
+return WeightInit.XAVIER_UNIFORM.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_LECUN_NORMAL()) ||
kerasInit.equals(conf.getINIT_LECUN_NORMAL_ALIAS())) {
-init = WeightInit.LECUN_NORMAL;
+return WeightInit.LECUN_NORMAL.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_LECUN_UNIFORM()) ||
kerasInit.equals(conf.getINIT_LECUN_UNIFORM_ALIAS())) {
-init = WeightInit.LECUN_UNIFORM;
+return WeightInit.LECUN_UNIFORM.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_HE_NORMAL()) ||
kerasInit.equals(conf.getINIT_HE_NORMAL_ALIAS())) {
-init = WeightInit.RELU;
+return WeightInit.RELU.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_HE_UNIFORM()) ||
kerasInit.equals(conf.getINIT_HE_UNIFORM_ALIAS())) {
-init = WeightInit.RELU_UNIFORM;
+return WeightInit.RELU_UNIFORM.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_ONE()) ||
kerasInit.equals(conf.getINIT_ONES()) ||
kerasInit.equals(conf.getINIT_ONES_ALIAS())) {
-init = WeightInit.ONES;
+return WeightInit.ONES.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_ZERO()) ||
kerasInit.equals(conf.getINIT_ZEROS()) ||
kerasInit.equals(conf.getINIT_ZEROS_ALIAS())) {
-init = WeightInit.ZERO;
+return WeightInit.ZERO.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_UNIFORM()) ||
kerasInit.equals(conf.getINIT_RANDOM_UNIFORM()) ||
kerasInit.equals(conf.getINIT_RANDOM_UNIFORM_ALIAS())) {
if (kerasMajorVersion == 2) {
double minVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MINVAL());
double maxVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MAXVAL());
-dist = new UniformDistribution(minVal, maxVal);
+return new WeightInitDistribution(new UniformDistribution(minVal, maxVal));
} else {
double scale = 0.05;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
-dist = new UniformDistribution(-scale, scale);
+return new WeightInitDistribution(new UniformDistribution(-scale, scale));
}
-init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_NORMAL()) ||
kerasInit.equals(conf.getINIT_RANDOM_NORMAL()) ||
kerasInit.equals(conf.getINIT_RANDOM_NORMAL_ALIAS())) {
if (kerasMajorVersion == 2) {
double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
-dist = new NormalDistribution(mean, stdDev);
+return new WeightInitDistribution(new NormalDistribution(mean, stdDev));
} else {
double scale = 0.05;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
-dist = new NormalDistribution(0, scale);
+return new WeightInitDistribution(new NormalDistribution(0, scale));
}
-init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_CONSTANT()) ||
kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) {
double value = (double) initConfig.get(conf.getLAYER_FIELD_INIT_VALUE());
-dist = new ConstantDistribution(value);
-init = WeightInit.DISTRIBUTION;
+return new WeightInitDistribution(new ConstantDistribution(value));
} else if (kerasInit.equals(conf.getINIT_ORTHOGONAL()) ||
kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) {
if (kerasMajorVersion == 2) {
@ -121,34 +115,38 @@ public class KerasInitilizationUtils {
} catch (Exception e) {
gain = (int) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
}
-dist = new OrthogonalDistribution(gain);
+return new WeightInitDistribution(new OrthogonalDistribution(gain));
} else {
double scale = 1.1;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
-dist = new OrthogonalDistribution(scale);
+return new WeightInitDistribution(new OrthogonalDistribution(scale));
}
-init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL()) ||
kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL_ALIAS())) {
double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
dist = new TruncatedNormalDistribution(mean, stdDev);
init = WeightInit.DISTRIBUTION;
return new WeightInitDistribution(new TruncatedNormalDistribution(mean, stdDev));
} else if (kerasInit.equals(conf.getINIT_IDENTITY()) ||
kerasInit.equals(conf.getINIT_IDENTITY_ALIAS())) {
if (kerasMajorVersion == 2) {
double gain = (double) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
if (gain != 1.)
log.warn("Scaled identity weight init not supported, setting gain=1");
if (gain != 1.0) {
return new WeightInitIdentity(gain);
} else {
return new WeightInitIdentity();
}
} else {
double scale = 1.;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
if (scale != 1.)
log.warn("Scaled identity weight init not supported, setting scale=1");
if (scale != 1.0) {
return new WeightInitIdentity(scale);
} else {
return new WeightInitIdentity();
}
}
init = WeightInit.IDENTITY;
} else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) {
double scale;
try {
@ -156,32 +154,27 @@ public class KerasInitilizationUtils {
} catch (Exception e) {
scale = (int) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
}
if (scale != 1.)
log.warn("Scaled identity weight init not supported, setting scale=1");
String mode = (String) initConfig.get(conf.getLAYER_FIELD_INIT_MODE());
String distribution = (String) initConfig.get(conf.getLAYER_FIELD_INIT_DISTRIBUTION());
switch (mode) {
case "fan_in":
if (distribution.equals("normal")) {
init = WeightInit.VAR_SCALING_NORMAL_FAN_IN;
return new WeightInitVarScalingNormalFanIn(scale);
} else {
init = WeightInit.VAR_SCALING_UNIFORM_FAN_IN;
return new WeightInitVarScalingUniformFanIn(scale);
}
break;
case "fan_out":
if (distribution.equals("normal")) {
init = WeightInit.VAR_SCALING_NORMAL_FAN_OUT;
return new WeightInitVarScalingNormalFanOut(scale);
} else {
init = WeightInit.VAR_SCALING_UNIFORM_FAN_OUT;
return new WeightInitVarScalingUniformFanOut(scale);
}
break;
case "fan_avg":
if (distribution.equals("normal")) {
init = WeightInit.VAR_SCALING_NORMAL_FAN_AVG;
return new WeightInitVarScalingNormalFanAvg(scale);
} else {
init = WeightInit.VAR_SCALING_UNIFORM_FAN_AVG;
return new WeightInitVarScalingUniformFanAvg(scale);
}
break;
default:
throw new InvalidKerasConfigurationException("Initialization argument 'mode' has to be either " +
"fan_in, fan_out or fan_avg");
@ -190,7 +183,7 @@ public class KerasInitilizationUtils {
throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit);
}
}
return new Pair<>(init, dist);
throw new IllegalStateException("Error getting Keras weight initialization");
}
/**
@ -202,7 +195,7 @@ public class KerasInitilizationUtils {
* @throws InvalidKerasConfigurationException Invalid Keras config
* @throws UnsupportedKerasConfigurationException Unsupported Keras config
*/
public static Pair<WeightInit, Distribution> getWeightInitFromConfig(Map<String, Object> layerConfig, String initField,
public static IWeightInit getWeightInitFromConfig(Map<String, Object> layerConfig, String initField,
boolean enforceTrainingConfig,
KerasLayerConfiguration conf,
int kerasMajorVersion)
@ -225,14 +218,14 @@ public class KerasInitilizationUtils {
throw new UnsupportedKerasConfigurationException("Incomplete initialization class");
}
}
Pair<WeightInit, Distribution> init;
IWeightInit init;
try {
init = mapWeightInitialization(kerasInit, conf, initMap, kerasMajorVersion);
} catch (UnsupportedKerasConfigurationException e) {
if (enforceTrainingConfig)
throw e;
else {
init = new Pair<>(WeightInit.XAVIER, null);
init = new WeightInitXavier();
log.warn("Unknown weight initializer " + kerasInit + " (Using XAVIER instead).");
}
}

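For orientation, a minimal caller-side sketch of the reworked API (the initConfig values are hypothetical; conf is assumed to be a KerasLayerConfiguration in scope, and the checked Keras-config exceptions still need handling):

    // Sketch only: the mapping now returns an IWeightInit directly,
    // instead of a Pair<WeightInit, Distribution> the caller had to unpack.
    Map<String, Object> initConfig = new HashMap<>();
    initConfig.put(conf.getLAYER_FIELD_INIT_MEAN(), 0.0);
    initConfig.put(conf.getLAYER_FIELD_INIT_STDDEV(), 0.05);
    IWeightInit init = KerasInitilizationUtils.mapWeightInitialization(
            conf.getINIT_RANDOM_NORMAL(), conf, initConfig, 2);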
View File

@ -21,7 +21,6 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;

View File

@ -16,7 +16,6 @@
package org.deeplearning4j.nn.modelimport.keras;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.nd4j.linalg.learning.regularization.L1Regularization;
@ -25,7 +24,6 @@ import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class KerasTestUtils {

View File

@ -22,8 +22,6 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.util.Nd4jValidator;
import org.nd4j.resources.Resources;
import org.nd4j.validation.ValidationResult;

View File

@ -21,7 +21,6 @@ import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.nn.layers.recurrent.LSTM;
import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
@ -30,7 +29,6 @@ import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;

View File

@ -24,7 +24,6 @@ import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources;
import java.io.InputStream;

View File

@ -30,11 +30,9 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;

View File

@ -25,6 +25,8 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDense;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitIdentity;
import org.deeplearning4j.nn.weights.WeightInitVarScalingNormalFanIn;
import org.junit.Test;
import java.util.HashMap;
@ -94,11 +96,11 @@ public class KerasInitilizationTest extends BaseDL4JTest {
WeightInit.RELU_UNIFORM.getWeightInitFunction(),
WeightInit.ONES.getWeightInitFunction(),
WeightInit.ZERO.getWeightInitFunction(),
WeightInit.IDENTITY.getWeightInitFunction(),
new WeightInitIdentity(0.2),
WeightInit.DISTRIBUTION.getWeightInitFunction(new NormalDistribution(mean, stdDev)),
WeightInit.DISTRIBUTION.getWeightInitFunction(new OrthogonalDistribution(gain)),
WeightInit.DISTRIBUTION.getWeightInitFunction(new ConstantDistribution(value)),
WeightInit.VAR_SCALING_NORMAL_FAN_IN.getWeightInitFunction()};
new WeightInitVarScalingNormalFanIn(0.2)};
}
private Distribution[] dl4jDistributions() {

View File

@ -17,22 +17,16 @@
package org.deeplearning4j.nn.modelimport.keras.configurations;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources;
import java.io.File;
import java.io.IOException;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertNotNull;
/**

View File

@ -31,7 +31,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources;
import java.io.File;

View File

@ -24,22 +24,19 @@ import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.gradientcheck.GradientCheckUtil;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.recurrent.LSTM;
import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.modelimport.keras.*;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
@ -47,27 +44,25 @@ import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.resources.Resources;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.util.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
/**
* Unit tests for end-to-end Keras model import.

View File

@ -21,7 +21,6 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
@ -31,11 +30,8 @@ import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
/**
* Import previously stored YOLO9000 Keras net from https://github.com/allanzelener/YAD2K.

View File

@ -26,7 +26,6 @@ import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources;
import java.io.File;

View File

@ -27,16 +27,11 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasAtrousC
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Test;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/**
* @author Max Pumperla

View File

@ -28,9 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolu
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Test;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.ArrayList;
import java.util.HashMap;
@ -39,7 +36,6 @@ import java.util.Map;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/**
* @author Max Pumperla

View File

@ -24,7 +24,6 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping1D;
import org.junit.Test;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

View File

@ -16,13 +16,11 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping3D;
import org.junit.Test;

View File

@ -30,15 +30,11 @@ import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Test;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/**
* @author Max Pumperla

View File

@ -17,18 +17,14 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
import org.deeplearning4j.nn.conf.layers.Upsampling1D;
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D;
import org.junit.Test;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;

View File

@ -17,13 +17,11 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.conf.layers.ZeroPadding1DLayer;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding1D;
import org.junit.Test;
import java.util.ArrayList;

View File

@ -17,12 +17,10 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding3D;
import org.junit.Test;

View File

@ -26,16 +26,11 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Test;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/**
* @author Max Pumperla

View File

@ -24,10 +24,12 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.junit.Test;
import java.util.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;

View File

@ -24,11 +24,11 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.util.*;

View File

@ -26,11 +26,7 @@ import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import static org.junit.Assert.assertEquals;

View File

@ -20,7 +20,6 @@ import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LocallyConnected1D;
import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
@ -31,10 +30,8 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/**

View File

@ -27,15 +27,14 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/**
* @author Max Pumperla

View File

@ -19,7 +19,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.pooling;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;

View File

@ -33,14 +33,13 @@ import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/**
* @author Max Pumperla

View File

@ -16,15 +16,12 @@
package org.deeplearning4j.nn.modelimport.keras.optimizers;
import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.e2e.KerasModelEndToEndTest;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.util.DL4JFileUtils;
import org.junit.Test;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources;
import java.io.File;
@ -32,8 +29,6 @@ import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import static java.io.File.createTempFile;
public class OptimizerImport extends BaseDL4JTest {
@Test

View File

@ -18,9 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.sequence;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer;
import org.junit.Test;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources;
import java.io.IOException;

View File

@ -19,15 +19,11 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.text;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.junit.Test;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
/**
* Import Keras Tokenizer

View File

@ -20,7 +20,6 @@ import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

View File

@ -29,7 +29,6 @@ import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources;
import java.io.File;

View File

@ -16,13 +16,11 @@
package org.deeplearning4j.nn.conf.layers.samediff;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.*;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.autodiff.samediff.SDVariable;
@ -32,6 +30,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
/**
@ -58,10 +57,12 @@ import java.util.Map;
public abstract class SameDiffLayer extends AbstractSameDiffLayer {
protected WeightInit weightInit;
protected Map<String,IWeightInit> paramWeightInit;
protected SameDiffLayer(Builder builder) {
super(builder);
this.weightInit = builder.weightInit;
this.paramWeightInit = builder.paramWeightInit;
}
protected SameDiffLayer() {
@ -115,6 +116,7 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer {
public static abstract class Builder<T extends Builder<T>> extends AbstractSameDiffLayer.Builder<T> {
protected WeightInit weightInit = WeightInit.XAVIER;
protected Map<String,IWeightInit> paramWeightInit;
/**
* @param weightInit Weight initialization to use for the layer
@ -123,5 +125,12 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer {
this.setWeightInit(weightInit);
return (T) this;
}
public T weightInit(@NonNull String param, @NonNull IWeightInit weightInit){
if(paramWeightInit == null)
paramWeightInit = new HashMap<>();
paramWeightInit.put(param, weightInit);
return (T) this;
}
}
}

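A short usage sketch of the new per-parameter override; MyLayer here is a hypothetical SameDiffLayer subclass with a parameter named "W":

    // Sketch only: layer-wide default plus a per-parameter IWeightInit override.
    MyLayer layer = new MyLayer.Builder()
            .weightInit(WeightInit.XAVIER)                 // default for all parameters
            .weightInit("W", new WeightInitIdentity(0.5))  // override for parameter "W" only
            .build();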
View File

@ -16,11 +16,14 @@
package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Arrays;
@ -32,9 +35,17 @@ import java.util.Arrays;
*
* @author Adam Gibson
*/
@EqualsAndHashCode
@Data
@NoArgsConstructor
public class WeightInitIdentity implements IWeightInit {
private Double scale;
public WeightInitIdentity(@JsonProperty("scale") Double scale){
this.scale = scale;
}
@Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
if (shape[0] != shape[1]) {
@ -59,6 +70,11 @@ public class WeightInitIdentity implements IWeightInit {
} else {
ret = Nd4j.createUninitialized(shape, order).assign(Nd4j.eye(shape[0]));
}
if(scale != null){
ret.muli(scale);
}
INDArray flat = Nd4j.toFlattened(order, ret);
paramView.assign(flat);
return paramView.reshape(order, shape);
@ -82,13 +98,16 @@ public class WeightInitIdentity implements IWeightInit {
indArrayIndices[i] = NDArrayIndex.point(shape[i] / 2);
}
paramView.assign(Nd4j.zeros(paramView.shape()));
paramView.assign(0);
final INDArray params = paramView.reshape(order, shape);
for (int i = 0; i < shape[0]; i++) {
indArrayIndices[0] = NDArrayIndex.point(i);
indArrayIndices[1] = NDArrayIndex.point(i);
params.put(indArrayIndices, Nd4j.ones(1));
}
if(scale != null){
params.muli(scale);
}
return params;
}
}

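To illustrate the new scale parameter, a minimal sketch (the array creation is an assumption; the init signature is as in the diff above):

    // Sketch only: 3x3 scaled identity. paramView stands in for the flattened
    // parameter view a layer would normally supply.
    INDArray paramView = Nd4j.create(DataType.FLOAT, 1, 9);
    INDArray w = new WeightInitIdentity(2.0).init(3, 3, new long[]{3, 3}, 'c', paramView);
    // w is 2.0 * I: diagonal entries 2.0, all other entries 0.0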
View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.weights;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution;
import org.nd4j.linalg.factory.Nd4j;
@ -146,14 +147,13 @@ public class WeightInitUtil {
paramView.assign(flat);
break;
case VAR_SCALING_NORMAL_FAN_IN:
// TODO: needs to be truncated normal to match keras.
Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn));
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(1.0 / fanIn)));
break;
case VAR_SCALING_NORMAL_FAN_OUT:
Nd4j.randn(paramView).divi(FastMath.sqrt(fanOut));
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(1.0 / fanOut)));
break;
case VAR_SCALING_NORMAL_FAN_AVG:
Nd4j.randn(paramView).divi(FastMath.sqrt((fanIn + fanOut) / 2));
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(2.0 / (fanIn + fanOut))));
break;
case VAR_SCALING_UNIFORM_FAN_IN:
double scalingFanIn = 3.0 / Math.sqrt(fanIn);

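In summary, the three fan-based normal variants now draw from a truncated normal with mean 0 and the following standard deviations (read directly from the code above):

    \sigma_{fan\_in} = \sqrt{1/f_{in}}, \qquad \sigma_{fan\_out} = \sqrt{1/f_{out}}, \qquad \sigma_{fan\_avg} = \sqrt{2/(f_{in} + f_{out})}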
View File

@ -16,22 +16,39 @@
package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.factory.Nd4j;
/**
* Gaussian distribution with mean 0, variance 1.0/((fanIn + fanOut)/2)
* Truncated Gaussian distribution with mean 0, variance 1.0/((fanIn + fanOut)/2)
*
* @author Adam Gibson
*/
@EqualsAndHashCode
@Data
@NoArgsConstructor
public class WeightInitVarScalingNormalFanAvg implements IWeightInit {
private Double scale;
public WeightInitVarScalingNormalFanAvg(Double scale){
this.scale = scale;
}
@Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
Nd4j.randn(paramView).divi(FastMath.sqrt((fanIn + fanOut) / 2));
double std;
if(scale == null){
std = Math.sqrt(2.0 / (fanIn + fanOut));
} else {
std = Math.sqrt(2.0 * scale / (fanIn + fanOut));
}
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
return paramView.reshape(order, shape);
}
}

View File

@ -16,23 +16,38 @@
package org.deeplearning4j.nn.weights;
import lombok.EqualsAndHashCode;
import org.apache.commons.math3.util.FastMath;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.factory.Nd4j;
/**
* Gaussian distribution with mean 0, variance 1.0/(fanIn)
* Truncated normal distribution with mean 0, variance {@code 1.0/(fanIn)}<br>
* If a scale is provided, use variance {@code scale/(fanIn)} instead
*
* @author Adam Gibson
*/
@EqualsAndHashCode
@Data
@NoArgsConstructor
public class WeightInitVarScalingNormalFanIn implements IWeightInit {
private Double scale;
public WeightInitVarScalingNormalFanIn(Double scale){
this.scale = scale;
}
@Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
// TODO: needs to be truncated normal to match keras.
Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn));
double std;
if(scale == null){
std = Math.sqrt(1.0 / fanIn);
} else {
std = Math.sqrt(scale / fanIn);
}
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
return paramView.reshape(order, shape);
}
}

View File

@ -16,22 +16,40 @@
package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.factory.Nd4j;
/**
* Gaussian distribution with mean 0, variance 1.0/(fanOut)
* Truncated normal distribution with mean 0, variance 1.0/(fanOut)<br>
* If a scale is provided, variance is scale / fanOut
*
* @author Adam Gibson
*/
@EqualsAndHashCode
@Data
@NoArgsConstructor
public class WeightInitVarScalingNormalFanOut implements IWeightInit {
private Double scale;
public WeightInitVarScalingNormalFanOut(Double scale){
this.scale = scale;
}
@Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
Nd4j.randn(paramView).divi(FastMath.sqrt(fanOut));
double std;
if(scale == null){
std = Math.sqrt(1.0 / fanOut);
} else {
std = Math.sqrt(scale / fanOut);
}
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
return paramView.reshape(order, shape);
}
}

View File

@ -16,7 +16,9 @@
package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -25,12 +27,22 @@ import org.nd4j.linalg.factory.Nd4j;
*
* @author Adam Gibson
*/
@EqualsAndHashCode
@Data
@NoArgsConstructor
public class WeightInitVarScalingUniformFanAvg implements IWeightInit {
private Double scale;
public WeightInitVarScalingUniformFanAvg(Double scale){
this.scale = scale;
}
@Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2);
if(scale != null)
scalingFanAvg *= scale;
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg));
return paramView.reshape(order, shape);
}

View File

@ -16,21 +16,34 @@
package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
* Uniform U[-a,a] with a=3.0/(fanIn)
* Uniform U[-a,a] with a=3.0/sqrt(fanIn)<br>
* If a scale is provided, a = 3.0 * scale / sqrt(fanIn)
*
* @author Adam Gibson
*/
@EqualsAndHashCode
@NoArgsConstructor
@Data
public class WeightInitVarScalingUniformFanIn implements IWeightInit {
private Double scale;
public WeightInitVarScalingUniformFanIn(Double scale){
this.scale = scale;
}
@Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
double scalingFanIn = 3.0 / Math.sqrt(fanIn);
if(scale != null)
scalingFanIn *= scale;
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn));
return paramView.reshape(order, shape);
}

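A worked instance of the bound above, with assumed values fanIn = 300 and scale = 2.0:

    // Sketch only: a = 3.0 * scale / sqrt(fanIn) = 3.0 * 2.0 / sqrt(300) ≈ 0.346,
    // so weights are drawn from U[-0.346, 0.346].
    IWeightInit wi = new WeightInitVarScalingUniformFanIn(2.0);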
View File

@ -16,21 +16,33 @@
package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
* Uniform U[-a,a] with a=3.0/(fanOut)
* Uniform U[-a,a] with a=3.0/sqrt(fanOut)<br>
* If a scale is provided, a = 3.0 * scale / sqrt(fanOut)
*
* @author Adam Gibson
*/
@EqualsAndHashCode
@Data
@NoArgsConstructor
public class WeightInitVarScalingUniformFanOut implements IWeightInit {
private Double scale;
public WeightInitVarScalingUniformFanOut(Double scale){
this.scale = scale;
}
@Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
double scalingFanOut = 3.0 / Math.sqrt(fanOut);
if(scale != null)
scalingFanOut *= scale;
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut));
return paramView.reshape(order, shape);
}

View File

@ -25,7 +25,7 @@ elseif (APPLE)
elseif(WIN32)
set(X86_BUILD true)
if (CUDA_BLAS)
set(CMAKE_CXX_FLAGS_RELEASE " /O2 -D_RELEASE=true /wd4804")
set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true /wd4804")
set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305")
else()
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")

View File

@ -3607,6 +3607,13 @@ public class Shape {
return ArrayUtil.prodLong(shape);
}
public static long lengthOf(int[] shape) {
if (shape.length == 0)
return 1L;
else
return ArrayUtil.prodLong(shape);
}
/**
* Calculate the length of the buffer required to store the given shape with the given strides
*

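A quick sketch of the edge case the new overload handles (values follow directly from the code above):

    // An empty (rank-0) shape is now treated as a scalar of length 1.
    long a = Shape.lengthOf(new int[]{});        // 1
    long b = Shape.lengthOf(new int[]{2, 3, 4}); // 24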
View File

@ -28,11 +28,6 @@ import java.util.List;
* @author Adam Gibson
*/
public class SamplingDataSetIterator implements DataSetIterator {
/**
*
*/
private static final long serialVersionUID = -2700563801361726914L;
private DataSet sampleFrom;
private int batchSize;
private int totalNumberSamples;
@ -145,6 +140,4 @@ public class SamplingDataSetIterator implements DataSetIterator {
numTimesSampled++;
return ret;
}
}

View File

@ -1164,26 +1164,15 @@ public class Nd4j {
* @param type the opType to create
* @return the created buffer
*/
public static DataBuffer createBuffer(int[] shape, DataType type) {
long length = ArrayUtil.prodLong(shape);
if (type == DataType.INT)
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createInt(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createInt(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (type == DataType.LONG)
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createLong(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (type == DataType.HALF)
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createHalf(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createHalf(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (type == DataType.DOUBLE)
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createDouble(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createDouble(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
else
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createFloat(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createFloat(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
public static DataBuffer createBuffer(@NonNull int[] shape, @NonNull DataType type) {
return createBuffer(ArrayUtil.toLongArray(shape), type);
}
/**
* See {@link #createBuffer(int[], DataType)}
*/
public static DataBuffer createBuffer(long[] shape, DataType type) {
long length = ArrayUtil.prodLong(shape);
public static DataBuffer createBuffer(@NonNull long[] shape, @NonNull DataType type) {
long length = Shape.lengthOf(shape);
switch (type) {
case BOOL:
@ -1229,14 +1218,14 @@ public class Nd4j {
* @return the created buffer.
*/
public static DataBuffer createBufferDetached(int[] shape, DataType type) {
return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type);
return createBufferDetachedImpl( Shape.lengthOf(shape), type);
}
/**
* See {@link #createBufferDetached(int[], DataType)}
*/
public static DataBuffer createBufferDetached(long[] shape, DataType type) {
return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type);
return createBufferDetachedImpl( Shape.lengthOf(shape), type);
}
// used by createBufferDetached(long[], DataType) and createBufferDetached(int[], DataType)