Fixes and pre-release QA ()

*  Keras import - support scaled identity weight init

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More Keras scaled weight init fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

*  Deprecate duplicate SamplingDataSetIterator class

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Remove /O2 optimization for faster CUDA build

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tweak regression test precision for CUDA

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix edge cases for buffer creation

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Update MKLDNN validation tests to new helper enable/disable settings

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Delete debugging class

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* MKLDNN test - add proper skip for CUDA backend

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Align WeightInitUtil with weight init classes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for SameDiff test layers weight init when using IWeightInit classes

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-16 17:04:29 +11:00 committed by GitHub
parent 1780dcc883
commit 09a827fb6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
85 changed files with 378 additions and 574 deletions
deeplearning4j
deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator
deeplearning4j-modelimport/src
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg
api/shape
dataset/api/iterator
factory

View File

@ -35,6 +35,7 @@ import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*; import java.util.*;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -63,6 +64,30 @@ public class LayerHelperValidationUtil {
private DataSetIterator data; private DataSetIterator data;
} }
public static void disableCppHelpers(){
try {
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
Method m = c.getMethod("getInstance");
Object instance = m.invoke(null);
Method m2 = c.getMethod("allowHelpers", boolean.class);
m2.invoke(instance, false);
} catch (Throwable t){
throw new RuntimeException(t);
}
}
public static void enableCppHelpers(){
try{
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
Method m = c.getMethod("getInstance");
Object instance = m.invoke(null);
Method m2 = c.getMethod("allowHelpers", boolean.class);
m2.invoke(instance, true);
} catch (Throwable t){
throw new RuntimeException(t);
}
}
public static void validateMLN(MultiLayerNetwork netOrig, TestCase t){ public static void validateMLN(MultiLayerNetwork netOrig, TestCase t){
assertNotNull(t.getAllowHelpersForClasses()); assertNotNull(t.getAllowHelpersForClasses());
assertFalse(t.getAllowHelpersForClasses().isEmpty()); assertFalse(t.getAllowHelpersForClasses().isEmpty());
@ -95,7 +120,13 @@ public class LayerHelperValidationUtil {
for (boolean train : new boolean[]{false, true}) { for (boolean train : new boolean[]{false, true}) {
assertEquals(net1NoHelper.params(), net2With.params()); assertEquals(net1NoHelper.params(), net2With.params());
String s = "Feed forward test - " + t.getTestName() + " - " + (train ? "Train: " : "Test: "); String s = "Feed forward test - " + t.getTestName() + " - " + (train ? "Train: " : "Test: ");
List<INDArray> ff1 = net1NoHelper.feedForward(t.getFeatures(), train); List<INDArray> ff1;
try {
disableCppHelpers();
ff1 = net1NoHelper.feedForward(t.getFeatures(), train);
} finally {
enableCppHelpers();
}
List<INDArray> ff2 = net2With.feedForward(t.getFeatures(), train); List<INDArray> ff2 = net2With.feedForward(t.getFeatures(), train);
List<String> paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet()); List<String> paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet());
Collections.sort(paramKeys); Collections.sort(paramKeys);
@ -131,7 +162,13 @@ public class LayerHelperValidationUtil {
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE); log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
} }
INDArray out1 = net1NoHelper.output(t.getFeatures(), train); INDArray out1;
try {
disableCppHelpers();
out1 = net1NoHelper.output(t.getFeatures(), train);
} finally {
enableCppHelpers();
}
INDArray out2 = net2With.output(t.getFeatures(), train); INDArray out2 = net2With.output(t.getFeatures(), train);
INDArray relError = relError(out1, out2, t.getMinAbsError()); INDArray relError = relError(out1, out2, t.getMinAbsError());
double maxRE = relError.maxNumber().doubleValue(); double maxRE = relError.maxNumber().doubleValue();
@ -148,7 +185,13 @@ public class LayerHelperValidationUtil {
Preconditions.checkNotNull(t.getLabels(), "Labels are not set (null)"); Preconditions.checkNotNull(t.getLabels(), "Labels are not set (null)");
log.info("Validation - checking scores"); log.info("Validation - checking scores");
double s1 = net1NoHelper.score(new DataSet(t.getFeatures(), t.getLabels())); double s1;
try {
disableCppHelpers();
s1 = net1NoHelper.score(new DataSet(t.getFeatures(), t.getLabels()));
} finally {
enableCppHelpers();
}
double s2 = net2With.score(new DataSet(t.getFeatures(), t.getLabels())); double s2 = net2With.score(new DataSet(t.getFeatures(), t.getLabels()));
double re = relError(s1, s2); double re = relError(s1, s2);
@ -168,7 +211,12 @@ public class LayerHelperValidationUtil {
net2With.setInput(t.getFeatures()); net2With.setInput(t.getFeatures());
net2With.setLabels(t.getLabels()); net2With.setLabels(t.getLabels());
net1NoHelper.computeGradientAndScore(); try {
disableCppHelpers();
net1NoHelper.computeGradientAndScore();
} finally {
enableCppHelpers();
}
net2With.computeGradientAndScore(); net2With.computeGradientAndScore();
List<String> paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet()); List<String> paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet());

View File

@ -1,107 +0,0 @@
package org.deeplearning4j;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.lang.reflect.Field;
import static junit.framework.TestCase.*;
public class TestBatchNormBp {
@Test
public void test(){
Nd4j.getRandom().setSeed(12345);
// INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 4, 4);
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
INDArray mean = in.mean(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
INDArray var = in.var(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
// INDArray gamma = Nd4j.ones(DataType.FLOAT, 3);
// INDArray beta = Nd4j.zeros(DataType.FLOAT, 3);
INDArray gamma = Nd4j.rand(DataType.FLOAT, 3);
INDArray beta = Nd4j.rand(DataType.FLOAT, 3);
double e = 1e-5;
INDArray dLdIn = in.ulike();
INDArray dLdm = mean.ulike();
INDArray dLdv = var.ulike();
INDArray dLdg = gamma.ulike();
INDArray dLdb = beta.ulike();
DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
.addInputs(in, mean, var, eps, gamma, beta)
.addIntegerArguments(
1, //Apply scale
1, //Apply beta
1) //Axis (NCHW)
.addFloatingPointArguments(e)
.addOutputs(dLdIn, dLdm, dLdv, dLdg, dLdb)
.build();
Nd4j.exec(op);
System.out.println(dLdIn);
}
@Test
public void compareImpls() throws Exception {
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
INDArray mean = in.mean(0, 2, 3).reshape(1,3);
INDArray var = in.var(0, 2, 3).reshape(1,3);
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
double e = 1e-3;
INDArray dLdIn = in.ulike();
INDArray dLdm = mean.ulike();
INDArray dLdv = var.ulike();
INDArray dLdg = gamma.ulike();
INDArray dLdb = beta.ulike();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.inferenceWorkspaceMode(WorkspaceMode.NONE)
.trainingWorkspaceMode(WorkspaceMode.NONE)
.list()
.layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
assertNotNull(bn.getHelper());
Field f = bn.getClass().getDeclaredField("helper");
f.setAccessible(true);
f.set(bn, null);
assertNull(bn.getHelper());
MKLDNNBatchNormHelper h = new MKLDNNBatchNormHelper(DataType.FLOAT);
net.output(in, true);
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
h.preOutput(in, true, new long[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient,INDArray> pmkl = h.backpropGradient(in, eps, new long[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces());
INDArray dldin_dl4j = p.getSecond();
System.out.println("dl4j == mkldnn: " + p.getSecond().equals(pmkl.getSecond()));
}
}

View File

@ -70,8 +70,19 @@ public class MinimalSameDiffDense extends SameDiffLayer {
@Override @Override
public void initializeParameters(Map<String, INDArray> params) { public void initializeParameters(Map<String, INDArray> params) {
params.get(DefaultParamInitializer.BIAS_KEY).assign(0); String b = DefaultParamInitializer.BIAS_KEY;
initWeights(nIn, nOut, weightInit, params.get(DefaultParamInitializer.WEIGHT_KEY)); if(paramWeightInit != null && paramWeightInit.containsKey(b)){
paramWeightInit.get(b).init(nIn, nOut, params.get(b).shape(), 'c', params.get(b));
} else {
params.get(DefaultParamInitializer.BIAS_KEY).assign(0);
}
String w = DefaultParamInitializer.WEIGHT_KEY;
if(paramWeightInit != null && paramWeightInit.containsKey(w)){
paramWeightInit.get(w).init(nIn, nOut, params.get(w).shape(), 'c', params.get(w));
} else {
initWeights(nIn, nOut, weightInit, params.get(DefaultParamInitializer.WEIGHT_KEY));
}
} }
//OPTIONAL methods: //OPTIONAL methods:

View File

@ -109,13 +109,17 @@ public class SameDiffConv extends SameDiffLayer {
@Override @Override
public void initializeParameters(Map<String, INDArray> params) { public void initializeParameters(Map<String, INDArray> params) {
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
double fanIn = nIn * kernel[0] * kernel[1];
double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);
for (Map.Entry<String, INDArray> e : params.entrySet()) { for (Map.Entry<String, INDArray> e : params.entrySet()) {
if (ConvolutionParamInitializer.BIAS_KEY.equals(e.getKey())) { if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){
e.getValue().assign(0); paramWeightInit.get(e.getKey()).init(fanIn, fanOut, e.getValue().shape(), 'c', e.getValue());
} else { } else {
double fanIn = nIn * kernel[0] * kernel[1]; if (ConvolutionParamInitializer.BIAS_KEY.equals(e.getKey())) {
double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]); e.getValue().assign(0);
WeightInitUtil.initWeights(fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c', e.getValue()); } else {
WeightInitUtil.initWeights(fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
}
} }
} }
} }

View File

@ -88,11 +88,15 @@ public class SameDiffDense extends SameDiffLayer {
@Override @Override
public void initializeParameters(Map<String,INDArray> params){ public void initializeParameters(Map<String,INDArray> params){
for(Map.Entry<String,INDArray> e : params.entrySet()){ for(Map.Entry<String,INDArray> e : params.entrySet()){
if(DefaultParamInitializer.BIAS_KEY.equals(e.getKey())){ if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){
e.getValue().assign(0.0); paramWeightInit.get(e.getKey()).init(nIn, nOut, e.getValue().shape(), 'c', e.getValue());
} else { } else {
//Normally use 'c' order, but use 'f' for direct comparison to DL4J DenseLayer if(DefaultParamInitializer.BIAS_KEY.equals(e.getKey())){
WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, weightInit, null, 'f', e.getValue()); e.getValue().assign(0.0);
} else {
//Normally use 'c' order, but use 'f' for direct comparison to DL4J DenseLayer
WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, weightInit, null, 'f', e.getValue());
}
} }
} }
} }

View File

@ -50,6 +50,7 @@ import static org.junit.Assume.assumeTrue;
public class ValidateMKLDNN extends BaseDL4JTest { public class ValidateMKLDNN extends BaseDL4JTest {
@Test @Test
public void validateConvSubsampling() throws Exception { public void validateConvSubsampling() throws Exception {
//Only run test if using nd4j-native backend //Only run test if using nd4j-native backend
@ -268,6 +269,7 @@ public class ValidateMKLDNN extends BaseDL4JTest {
@Test @Test
public void compareBatchNormBackward() throws Exception { public void compareBatchNormBackward() throws Exception {
assumeTrue(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native"));
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15); INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);

View File

@ -339,7 +339,13 @@ public class RegressionTest100b4 extends BaseDL4JTest {
INDArray outAct = net.output(in); INDArray outAct = net.output(in);
assertEquals(outExp, outAct); //19 layers - CPU vs. GPU difference accumulates notably, but appears to be correct
if(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native")){
assertEquals(outExp, outAct);
} else {
boolean eq = outExp.equalsWithEps(outAct, 0.1);
assertTrue(eq);
}
} }
@Test @Test

View File

@ -24,101 +24,11 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.util.List; import java.util.List;
/** /**
* A wrapper for a dataset to sample from. * @deprecated Use {@link org.nd4j.linalg.dataset.api.iterator.SamplingDataSetIterator}
* This will randomly sample from the given dataset.
* @author Adam GIbson
*/ */
public class SamplingDataSetIterator implements DataSetIterator { @Deprecated
public class SamplingDataSetIterator extends org.nd4j.linalg.dataset.api.iterator.SamplingDataSetIterator {
/**
*
*/
private static final long serialVersionUID = -2700563801361726914L;
private DataSet sampleFrom;
private int batchSize;
private int totalNumberSamples;
private int numTimesSampled;
@Getter
private DataSetPreProcessor preProcessor;
/**
*
* @param sampleFrom the dataset to sample from
* @param batchSize the batch size to sample
* @param totalNumberSamples the sample size
*/
public SamplingDataSetIterator(DataSet sampleFrom, int batchSize, int totalNumberSamples) { public SamplingDataSetIterator(DataSet sampleFrom, int batchSize, int totalNumberSamples) {
super(); super(sampleFrom, batchSize, totalNumberSamples);
this.sampleFrom = sampleFrom;
this.batchSize = batchSize;
this.totalNumberSamples = totalNumberSamples;
} }
@Override
public boolean hasNext() {
return numTimesSampled < totalNumberSamples;
}
@Override
public DataSet next() {
DataSet ret = sampleFrom.sample(batchSize);
numTimesSampled += batchSize;
return ret;
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
@Override
public int inputColumns() {
return sampleFrom.numInputs();
}
@Override
public int totalOutcomes() {
return sampleFrom.numOutcomes();
}
@Override
public boolean resetSupported() {
return true;
}
@Override
public boolean asyncSupported() {
return true;
}
@Override
public void reset() {
numTimesSampled = 0;
}
@Override
public int batch() {
return batchSize;
}
@Override
public void setPreProcessor(DataSetPreProcessor preProcessor) {
this.preProcessor = preProcessor;
}
@Override
public List<String> getLabels() {
return null;
}
@Override
public DataSet next(int num) {
DataSet ret = sampleFrom.sample(num);
numTimesSampled++;
return ret;
}
} }

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.modelimport.keras; package org.deeplearning4j.nn.modelimport.keras;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.bytedeco.hdf5.*;
import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FloatPointer; import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Loader; import org.bytedeco.javacpp.Loader;
@ -32,7 +33,6 @@ import java.lang.Exception;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.bytedeco.hdf5.*;
import static org.bytedeco.hdf5.global.hdf5.*; import static org.bytedeco.hdf5.global.hdf5.*;
/** /**

View File

@ -17,7 +17,6 @@
package org.deeplearning4j.nn.modelimport.keras; package org.deeplearning4j.nn.modelimport.keras;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;

View File

@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.PReLULayer; import org.deeplearning4j.nn.conf.layers.PReLULayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -27,9 +26,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.PReLUParamInitializer; import org.deeplearning4j.nn.params.PReLUParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import java.util.HashMap; import java.util.HashMap;
@ -79,14 +77,12 @@ public class KerasPReLU extends KerasLayer {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, ALPHA_CONSTRAINT, conf, kerasMajorVersion); layerConfig, ALPHA_CONSTRAINT, conf, kerasMajorVersion);
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, ALPHA_INIT, IWeightInit init = getWeightInitFromConfig(layerConfig, ALPHA_INIT,
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst();
Distribution distribution = init.getSecond();
long[] axes = getSharedAxes(layerConfig); long[] axes = getSharedAxes(layerConfig);
PReLULayer.Builder builder = new PReLULayer.Builder().sharedAxes(axes) PReLULayer.Builder builder = new PReLULayer.Builder().sharedAxes(axes)
.weightInit(weightInit.getWeightInitFunction(distribution)).name(layerName); .weightInit(init).name(layerName);
if (weightConstraint != null){ if (weightConstraint != null){
builder.constrainWeights(weightConstraint); builder.constrainWeights(weightConstraint);
} }

View File

@ -17,14 +17,12 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional; package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.primitives.Pair;
import java.util.Map; import java.util.Map;
@ -83,15 +81,13 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst();
Distribution distribution = init.getSecond();
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName) Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf)) .activation(getIActivationFromConfig(layerConfig, conf))
.weightInit(weightInit.getWeightInitFunction(distribution)) .weightInit(init)
.dilation(getDilationRate(layerConfig, 1, conf, true)[0]) .dilation(getDilationRate(layerConfig, 1, conf, true)[0])
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))

View File

@ -17,14 +17,12 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional; package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.primitives.Pair;
import java.util.Map; import java.util.Map;
@ -84,14 +82,13 @@ public class KerasAtrousConvolution2D extends KerasConvolution {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst();
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName) ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf)) .activation(getIActivationFromConfig(layerConfig, conf))
.weightInit(weightInit.getWeightInitFunction()) .weightInit(init)
.dilation(getDilationRate(layerConfig, 2, conf, true)) .dilation(getDilationRate(layerConfig, 2, conf, true))
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))

View File

@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer; import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
@ -30,7 +29,6 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Set;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.removeDefaultWeights; import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.removeDefaultWeights;

View File

@ -22,7 +22,6 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
@ -30,10 +29,9 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer; import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -94,15 +92,13 @@ public class KerasConvolution1D extends KerasConvolution {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst();
Distribution distribution = init.getSecond();
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName) Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf)) .activation(getIActivationFromConfig(layerConfig, conf))
.weightInit(weightInit.getWeightInitFunction(distribution)) .weightInit(init)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0]) .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])

View File

@ -21,14 +21,12 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.primitives.Pair;
import java.util.Map; import java.util.Map;
@ -87,10 +85,8 @@ public class KerasConvolution2D extends KerasConvolution {
numTrainableParams = hasBias ? 2 : 1; numTrainableParams = hasBias ? 2 : 1;
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false); int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst();
Distribution distribution = init.getSecond();
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig( LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion); layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
@ -100,7 +96,7 @@ public class KerasConvolution2D extends KerasConvolution {
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName) ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf)) .activation(getIActivationFromConfig(layerConfig, conf))
.weightInit(weightInit.getWeightInitFunction(distribution)) .weightInit(init)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion)) .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))

View File

@ -21,15 +21,13 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.primitives.Pair;
import java.util.Map; import java.util.Map;
@ -88,10 +86,8 @@ public class KerasConvolution3D extends KerasConvolution {
numTrainableParams = hasBias ? 2 : 1; numTrainableParams = hasBias ? 2 : 1;
int[] dilationRate = getDilationRate(layerConfig, 3, conf, false); int[] dilationRate = getDilationRate(layerConfig, 3, conf, false);
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst();
Distribution distribution = init.getSecond();
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig( LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion); layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
@ -101,7 +97,7 @@ public class KerasConvolution3D extends KerasConvolution {
Convolution3D.Builder builder = new Convolution3D.Builder().name(this.layerName) Convolution3D.Builder builder = new Convolution3D.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf)) .activation(getIActivationFromConfig(layerConfig, conf))
.weightInit(weightInit.getWeightInitFunction(distribution)) .weightInit(init)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 3, conf, kerasMajorVersion)) .kernelSize(getKernelSizeFromConfig(layerConfig, 3, conf, kerasMajorVersion))

View File

@ -20,14 +20,12 @@ import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Deconvolution2D; import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.primitives.Pair;
import java.util.Map; import java.util.Map;
@ -86,10 +84,8 @@ public class KerasDeconvolution2D extends KerasConvolution {
numTrainableParams = hasBias ? 2 : 1; numTrainableParams = hasBias ? 2 : 1;
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false); int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst();
Distribution distribution = init.getSecond();
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig( LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion); layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
@ -99,7 +95,7 @@ public class KerasDeconvolution2D extends KerasConvolution {
Deconvolution2D.Builder builder = new Deconvolution2D.Builder().name(this.layerName) Deconvolution2D.Builder builder = new Deconvolution2D.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf)) .activation(getIActivationFromConfig(layerConfig, conf))
.weightInit(weightInit.getWeightInitFunction(distribution)) .weightInit(init)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion)) .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))

View File

@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D; import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -30,9 +29,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils;
import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer; import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -126,10 +124,8 @@ public class KerasDepthwiseConvolution2D extends KerasConvolution {
numTrainableParams = hasBias ? 2 : 1; numTrainableParams = hasBias ? 2 : 1;
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false); int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
Pair<WeightInit, Distribution> depthWiseInit = getWeightInitFromConfig(layerConfig, IWeightInit depthWiseInit = getWeightInitFromConfig(layerConfig,
conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit depthWeightInit = depthWiseInit.getFirst();
Distribution depthDistribution = depthWiseInit.getSecond();
val nIn = getNInFromConfig(previousLayers); val nIn = getNInFromConfig(previousLayers);
@ -152,7 +148,7 @@ public class KerasDepthwiseConvolution2D extends KerasConvolution {
.nIn(nIn) .nIn(nIn)
.nOut(nIn * depthMultiplier) .nOut(nIn * depthMultiplier)
.activation(getIActivationFromConfig(layerConfig, conf)) .activation(getIActivationFromConfig(layerConfig, conf))
.weightInit(depthWeightInit.getWeightInitFunction(depthDistribution)) .weightInit(depthWiseInit)
.depthMultiplier(depthMultiplier) .depthMultiplier(depthMultiplier)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))

View File

@ -20,7 +20,6 @@ import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D; import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
@ -28,9 +27,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils;
import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer; import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -93,17 +91,13 @@ public class KerasSeparableConvolution2D extends KerasConvolution {
int depthMultiplier = getDepthMultiplier(layerConfig, conf); int depthMultiplier = getDepthMultiplier(layerConfig, conf);
Pair<WeightInit, Distribution> depthWiseInit = getWeightInitFromConfig(layerConfig, IWeightInit depthWiseInit = getWeightInitFromConfig(layerConfig,
conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit depthWeightInit = depthWiseInit.getFirst();
Distribution depthDistribution = depthWiseInit.getSecond();
Pair<WeightInit, Distribution> pointWiseInit = getWeightInitFromConfig(layerConfig, IWeightInit pointWiseInit = getWeightInitFromConfig(layerConfig,
conf.getLAYER_FIELD_POINT_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); conf.getLAYER_FIELD_POINT_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit pointWeightInit = pointWiseInit.getFirst();
Distribution pointDistribution = pointWiseInit.getSecond();
if (depthWeightInit != pointWeightInit || depthDistribution != pointDistribution) if ( !depthWiseInit.getClass().equals(pointWiseInit.getClass()) )
if (enforceTrainingConfig) if (enforceTrainingConfig)
throw new UnsupportedKerasConfigurationException( throw new UnsupportedKerasConfigurationException(
"Specifying different initialization for depth- and point-wise weights not supported."); "Specifying different initialization for depth- and point-wise weights not supported.");
@ -126,7 +120,7 @@ public class KerasSeparableConvolution2D extends KerasConvolution {
SeparableConvolution2D.Builder builder = new SeparableConvolution2D.Builder().name(this.layerName) SeparableConvolution2D.Builder builder = new SeparableConvolution2D.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf)) .activation(getIActivationFromConfig(layerConfig, conf))
.weightInit(depthWeightInit.getWeightInitFunction(depthDistribution)) .weightInit(depthWiseInit)
.depthMultiplier(depthMultiplier) .depthMultiplier(depthMultiplier)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))

View File

@ -17,7 +17,6 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional; package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.conf.layers.Upsampling3D; import org.deeplearning4j.nn.conf.layers.Upsampling3D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;

View File

@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer; import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;

View File

@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -95,15 +93,13 @@ public class KerasDense extends KerasLayer {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst();
Distribution distribution = init.getSecond();
DenseLayer.Builder builder = new DenseLayer.Builder().name(this.layerName) DenseLayer.Builder builder = new DenseLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)) .nOut(getNOutFromConfig(layerConfig, conf))
.dropOut(this.dropout).activation(getIActivationFromConfig(layerConfig, conf)) .dropOut(this.dropout).activation(getIActivationFromConfig(layerConfig, conf))
.weightInit(weightInit.getWeightInitFunction(distribution)) .weightInit(init)
.biasInit(0.0) .biasInit(0.0)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.hasBias(hasBias); .hasBias(hasBias);

View File

@ -22,7 +22,6 @@ import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional; import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;

View File

@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector; import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;

View File

@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.val; import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -26,7 +25,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;

View File

@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer; import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -30,11 +29,10 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -106,10 +104,8 @@ public class KerasEmbedding extends KerasLayer {
"in DL4J, apply masking as a pre-processing step to your input." + "in DL4J, apply masking as a pre-processing step to your input." +
"See http://deeplearning4j.org/docs/latest/deeplearning4j-nn-recurrent#masking for more on this."); "See http://deeplearning4j.org/docs/latest/deeplearning4j-nn-recurrent#masking for more on this.");
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst();
Distribution distribution = init.getSecond();
LayerConstraint embeddingConstraint = KerasConstraintUtils.getConstraintsFromConfig( LayerConstraint embeddingConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_EMBEDDINGS_CONSTRAINT(), conf, kerasMajorVersion); layerConfig, conf.getLAYER_FIELD_EMBEDDINGS_CONSTRAINT(), conf, kerasMajorVersion);
@ -121,7 +117,7 @@ public class KerasEmbedding extends KerasLayer {
.inferInputLength(inferInputLength) .inferInputLength(inferInputLength)
.nOut(getNOutFromConfig(layerConfig, conf)) .nOut(getNOutFromConfig(layerConfig, conf))
.dropOut(this.dropout).activation(Activation.IDENTITY) .dropOut(this.dropout).activation(Activation.IDENTITY)
.weightInit(weightInit.getWeightInitFunction(distribution)) .weightInit(init)
.biasInit(0.0) .biasInit(0.0)
.l1(this.weightL1Regularization) .l1(this.weightL1Regularization)
.l2(this.weightL2Regularization) .l2(this.weightL2Regularization)

View File

@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LocallyConnected1D; import org.deeplearning4j.nn.conf.layers.LocallyConnected1D;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer; import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -90,11 +88,8 @@ public class KerasLocallyConnected1D extends KerasConvolution {
numTrainableParams = hasBias ? 2 : 1; numTrainableParams = hasBias ? 2 : 1;
int[] dilationRate = getDilationRate(layerConfig, 1, conf, false); int[] dilationRate = getDilationRate(layerConfig, 1, conf, false);
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst();
// TODO: take care of distribution and bias init
//Distribution distribution = init.getSecond();
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig( LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion); layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
@ -104,7 +99,7 @@ public class KerasLocallyConnected1D extends KerasConvolution {
LocallyConnected1D.Builder builder = new LocallyConnected1D.Builder().name(this.layerName) LocallyConnected1D.Builder builder = new LocallyConnected1D.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getActivationFromConfig(layerConfig, conf)) .activation(getActivationFromConfig(layerConfig, conf))
.weightInit(weightInit) .weightInit(conf.getKERAS_PARAM_NAME_W(), init)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0]) .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])

View File

@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LocallyConnected2D; import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer; import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -39,9 +37,7 @@ import java.util.Map;
import static org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolutionUtils.*; import static org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolutionUtils.*;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasActivationUtils.getActivationFromConfig; import static org.deeplearning4j.nn.modelimport.keras.utils.KerasActivationUtils.getActivationFromConfig;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils.getWeightInitFromConfig; import static org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils.getWeightInitFromConfig;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getHasBiasFromConfig; import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.*;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getNOutFromConfig;
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.removeDefaultWeights;
/** /**
@ -92,11 +88,9 @@ public class KerasLocallyConnected2D extends KerasConvolution {
numTrainableParams = hasBias ? 2 : 1; numTrainableParams = hasBias ? 2 : 1;
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false); int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst(); // TODO: take care of bias init
// TODO: take care of distribution and bias init
//Distribution distribution = init.getSecond();
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig( LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion); layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
@ -106,7 +100,7 @@ public class KerasLocallyConnected2D extends KerasConvolution {
LocallyConnected2D.Builder builder = new LocallyConnected2D.Builder().name(this.layerName) LocallyConnected2D.Builder builder = new LocallyConnected2D.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getActivationFromConfig(layerConfig, conf)) .activation(getActivationFromConfig(layerConfig, conf))
.weightInit(weightInit) .weightInit(conf.getKERAS_PARAM_NAME_W(), init)
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion)) .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))

View File

@ -31,7 +31,6 @@ import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;

View File

@ -22,7 +22,6 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.LSTM;
@ -35,7 +34,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.LSTMParamInitializer; import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -151,15 +150,11 @@ public class KerasLSTM extends KerasLayer {
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
super(layerConfig, enforceTrainingConfig); super(layerConfig, enforceTrainingConfig);
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst();
Distribution distribution = init.getSecond();
Pair<WeightInit, Distribution> recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(), IWeightInit recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit recurrentWeightInit = recurrentInit.getFirst();
Distribution recurrentDistribution = recurrentInit.getSecond();
boolean hasBias = getHasBiasFromConfig(layerConfig, conf); boolean hasBias = getHasBiasFromConfig(layerConfig, conf);
@ -186,8 +181,8 @@ public class KerasLSTM extends KerasLayer {
.nOut(getNOutFromConfig(layerConfig, conf)) .nOut(getNOutFromConfig(layerConfig, conf))
.dropOut(this.dropout) .dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf)) .activation(getIActivationFromConfig(layerConfig, conf))
.weightInit(weightInit.getWeightInitFunction(distribution)) .weightInit(init)
.weightInitRecurrent(recurrentWeightInit.getWeightInitFunction(recurrentDistribution)) .weightInitRecurrent(recurrentInit)
.biasInit(0.0) // TODO: this is incorrect .biasInit(0.0) // TODO: this is incorrect
.l1(this.weightL1Regularization) .l1(this.weightL1Regularization)
.l2(this.weightL2Regularization); .l2(this.weightL2Regularization);

View File

@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.Layer;
@ -34,7 +33,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer; import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
@ -124,15 +123,11 @@ public class KerasSimpleRnn extends KerasLayer {
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
super(layerConfig, enforceTrainingConfig); super(layerConfig, enforceTrainingConfig);
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit weightInit = init.getFirst();
Distribution distribution = init.getSecond();
Pair<WeightInit, Distribution> recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(), IWeightInit recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
WeightInit recurrentWeightInit = recurrentInit.getFirst();
Distribution recurrentDistribution = recurrentInit.getSecond();
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
this.returnSequences = (Boolean) innerConfig.get(conf.getLAYER_FIELD_RETURN_SEQUENCES()); this.returnSequences = (Boolean) innerConfig.get(conf.getLAYER_FIELD_RETURN_SEQUENCES());
@ -154,8 +149,8 @@ public class KerasSimpleRnn extends KerasLayer {
.nOut(getNOutFromConfig(layerConfig, conf)) .nOut(getNOutFromConfig(layerConfig, conf))
.dropOut(this.dropout) .dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf)) .activation(getIActivationFromConfig(layerConfig, conf))
.weightInit(weightInit.getWeightInitFunction(distribution)) .weightInit(init)
.weightInitRecurrent(recurrentWeightInit.getWeightInitFunction(recurrentDistribution)) .weightInitRecurrent(recurrentInit)
.biasInit(0.0) .biasInit(0.0)
.l1(this.weightL1Regularization) .l1(this.weightL1Regularization)
.l2(this.weightL2Regularization); .l2(this.weightL2Regularization);

View File

@ -20,9 +20,7 @@ import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken; import com.google.gson.reflect.TypeToken;
import lombok.Data; import lombok.Data;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -31,7 +29,6 @@ import org.nd4j.linalg.primitives.Pair;
import java.io.IOException; import java.io.IOException;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;

View File

@ -22,9 +22,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;
/** /**

View File

@ -19,17 +19,15 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessors;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;

View File

@ -20,9 +20,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.shade.jackson.annotation.JsonCreator; import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;

View File

@ -1,28 +1,15 @@
package org.deeplearning4j.nn.modelimport.keras.utils; package org.deeplearning4j.nn.modelimport.keras.utils;
import lombok.NonNull; import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive; import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.validation.Nd4jCommonValidator; import org.nd4j.validation.Nd4jCommonValidator;
import org.nd4j.validation.ValidationResult; import org.nd4j.validation.ValidationResult;
import java.io.BufferedReader;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
/** /**
* A utility for validating serialized Keras sequential and functional models for import into DL4J * A utility for validating serialized Keras sequential and functional models for import into DL4J

View File

@ -21,7 +21,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.*;
import java.util.Map; import java.util.Map;

View File

@ -21,8 +21,7 @@ import org.deeplearning4j.nn.conf.distribution.*;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.*;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -42,76 +41,71 @@ public class KerasInitilizationUtils {
* @return DL4J weight initialization enum * @return DL4J weight initialization enum
* @see WeightInit * @see WeightInit
*/ */
public static Pair<WeightInit, Distribution> mapWeightInitialization(String kerasInit, public static IWeightInit mapWeightInitialization(String kerasInit,
KerasLayerConfiguration conf, KerasLayerConfiguration conf,
Map<String, Object> initConfig, Map<String, Object> initConfig,
int kerasMajorVersion) int kerasMajorVersion)
throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException { throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
// TODO: Identity and VarianceScaling need "scale" factor // TODO: Identity and VarianceScaling need "scale" factor
WeightInit init = null;
Distribution dist = null;
if (kerasInit != null) { if (kerasInit != null) {
if (kerasInit.equals(conf.getINIT_GLOROT_NORMAL()) || if (kerasInit.equals(conf.getINIT_GLOROT_NORMAL()) ||
kerasInit.equals(conf.getINIT_GLOROT_NORMAL_ALIAS())) { kerasInit.equals(conf.getINIT_GLOROT_NORMAL_ALIAS())) {
init = WeightInit.XAVIER; return WeightInit.XAVIER.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_GLOROT_UNIFORM()) || } else if (kerasInit.equals(conf.getINIT_GLOROT_UNIFORM()) ||
kerasInit.equals(conf.getINIT_GLOROT_UNIFORM_ALIAS())) { kerasInit.equals(conf.getINIT_GLOROT_UNIFORM_ALIAS())) {
init = WeightInit.XAVIER_UNIFORM; return WeightInit.XAVIER_UNIFORM.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_LECUN_NORMAL()) || } else if (kerasInit.equals(conf.getINIT_LECUN_NORMAL()) ||
kerasInit.equals(conf.getINIT_LECUN_NORMAL_ALIAS())) { kerasInit.equals(conf.getINIT_LECUN_NORMAL_ALIAS())) {
init = WeightInit.LECUN_NORMAL; return WeightInit.LECUN_NORMAL.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_LECUN_UNIFORM()) || } else if (kerasInit.equals(conf.getINIT_LECUN_UNIFORM()) ||
kerasInit.equals(conf.getINIT_LECUN_UNIFORM_ALIAS())) { kerasInit.equals(conf.getINIT_LECUN_UNIFORM_ALIAS())) {
init = WeightInit.LECUN_UNIFORM; return WeightInit.LECUN_UNIFORM.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_HE_NORMAL()) || } else if (kerasInit.equals(conf.getINIT_HE_NORMAL()) ||
kerasInit.equals(conf.getINIT_HE_NORMAL_ALIAS())) { kerasInit.equals(conf.getINIT_HE_NORMAL_ALIAS())) {
init = WeightInit.RELU; return WeightInit.RELU.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_HE_UNIFORM()) || } else if (kerasInit.equals(conf.getINIT_HE_UNIFORM()) ||
kerasInit.equals(conf.getINIT_HE_UNIFORM_ALIAS())) { kerasInit.equals(conf.getINIT_HE_UNIFORM_ALIAS())) {
init = WeightInit.RELU_UNIFORM; return WeightInit.RELU_UNIFORM.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_ONE()) || } else if (kerasInit.equals(conf.getINIT_ONE()) ||
kerasInit.equals(conf.getINIT_ONES()) || kerasInit.equals(conf.getINIT_ONES()) ||
kerasInit.equals(conf.getINIT_ONES_ALIAS())) { kerasInit.equals(conf.getINIT_ONES_ALIAS())) {
init = WeightInit.ONES; return WeightInit.ONES.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_ZERO()) || } else if (kerasInit.equals(conf.getINIT_ZERO()) ||
kerasInit.equals(conf.getINIT_ZEROS()) || kerasInit.equals(conf.getINIT_ZEROS()) ||
kerasInit.equals(conf.getINIT_ZEROS_ALIAS())) { kerasInit.equals(conf.getINIT_ZEROS_ALIAS())) {
init = WeightInit.ZERO; return WeightInit.ZERO.getWeightInitFunction();
} else if (kerasInit.equals(conf.getINIT_UNIFORM()) || } else if (kerasInit.equals(conf.getINIT_UNIFORM()) ||
kerasInit.equals(conf.getINIT_RANDOM_UNIFORM()) || kerasInit.equals(conf.getINIT_RANDOM_UNIFORM()) ||
kerasInit.equals(conf.getINIT_RANDOM_UNIFORM_ALIAS())) { kerasInit.equals(conf.getINIT_RANDOM_UNIFORM_ALIAS())) {
if (kerasMajorVersion == 2) { if (kerasMajorVersion == 2) {
double minVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MINVAL()); double minVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MINVAL());
double maxVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MAXVAL()); double maxVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MAXVAL());
dist = new UniformDistribution(minVal, maxVal); return new WeightInitDistribution(new UniformDistribution(minVal, maxVal));
} else { } else {
double scale = 0.05; double scale = 0.05;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE())) if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE()); scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
dist = new UniformDistribution(-scale, scale); return new WeightInitDistribution(new UniformDistribution(-scale, scale));
} }
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_NORMAL()) || } else if (kerasInit.equals(conf.getINIT_NORMAL()) ||
kerasInit.equals(conf.getINIT_RANDOM_NORMAL()) || kerasInit.equals(conf.getINIT_RANDOM_NORMAL()) ||
kerasInit.equals(conf.getINIT_RANDOM_NORMAL_ALIAS())) { kerasInit.equals(conf.getINIT_RANDOM_NORMAL_ALIAS())) {
if (kerasMajorVersion == 2) { if (kerasMajorVersion == 2) {
double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN()); double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV()); double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
dist = new NormalDistribution(mean, stdDev); return new WeightInitDistribution(new NormalDistribution(mean, stdDev));
} else { } else {
double scale = 0.05; double scale = 0.05;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE())) if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE()); scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
dist = new NormalDistribution(0, scale); return new WeightInitDistribution(new NormalDistribution(0, scale));
} }
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_CONSTANT()) || } else if (kerasInit.equals(conf.getINIT_CONSTANT()) ||
kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) { kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) {
double value = (double) initConfig.get(conf.getLAYER_FIELD_INIT_VALUE()); double value = (double) initConfig.get(conf.getLAYER_FIELD_INIT_VALUE());
dist = new ConstantDistribution(value); return new WeightInitDistribution(new ConstantDistribution(value));
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_ORTHOGONAL()) || } else if (kerasInit.equals(conf.getINIT_ORTHOGONAL()) ||
kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) { kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) {
if (kerasMajorVersion == 2) { if (kerasMajorVersion == 2) {
@ -121,34 +115,38 @@ public class KerasInitilizationUtils {
} catch (Exception e) { } catch (Exception e) {
gain = (int) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN()); gain = (int) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
} }
dist = new OrthogonalDistribution(gain); return new WeightInitDistribution(new OrthogonalDistribution(gain));
} else { } else {
double scale = 1.1; double scale = 1.1;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE())) if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE()); scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
dist = new OrthogonalDistribution(scale); return new WeightInitDistribution(new OrthogonalDistribution(scale));
} }
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL()) || } else if (kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL()) ||
kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL_ALIAS())) { kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL_ALIAS())) {
double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN()); double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV()); double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
dist = new TruncatedNormalDistribution(mean, stdDev); return new WeightInitDistribution(new TruncatedNormalDistribution(mean, stdDev));
init = WeightInit.DISTRIBUTION;
} else if (kerasInit.equals(conf.getINIT_IDENTITY()) || } else if (kerasInit.equals(conf.getINIT_IDENTITY()) ||
kerasInit.equals(conf.getINIT_IDENTITY_ALIAS())) { kerasInit.equals(conf.getINIT_IDENTITY_ALIAS())) {
if (kerasMajorVersion == 2) { if (kerasMajorVersion == 2) {
double gain = (double) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN()); double gain = (double) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
if (gain != 1.) if (gain != 1.0)
log.warn("Scaled identity weight init not supported, setting gain=1"); if (gain != 1.0) {
return new WeightInitIdentity(gain);
} else {
return new WeightInitIdentity();
}
} else { } else {
double scale = 1.; double scale = 1.;
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE())) if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE()); scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
if (scale != 1.) if (scale != 1.0) {
log.warn("Scaled identity weight init not supported, setting scale=1"); return new WeightInitIdentity(scale);
} else {
return new WeightInitIdentity();
}
} }
init = WeightInit.IDENTITY;
} else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) { } else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) {
double scale; double scale;
try { try {
@ -156,32 +154,27 @@ public class KerasInitilizationUtils {
} catch (Exception e) { } catch (Exception e) {
scale = (int) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE()); scale = (int) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
} }
if (scale != 1.)
log.warn("Scaled identity weight init not supported, setting scale=1");
String mode = (String) initConfig.get(conf.getLAYER_FIELD_INIT_MODE()); String mode = (String) initConfig.get(conf.getLAYER_FIELD_INIT_MODE());
String distribution = (String) initConfig.get(conf.getLAYER_FIELD_INIT_DISTRIBUTION()); String distribution = (String) initConfig.get(conf.getLAYER_FIELD_INIT_DISTRIBUTION());
switch (mode) { switch (mode) {
case "fan_in": case "fan_in":
if (distribution.equals("normal")) { if (distribution.equals("normal")) {
init = WeightInit.VAR_SCALING_NORMAL_FAN_IN; return new WeightInitVarScalingNormalFanIn(scale);
} else { } else {
init = WeightInit.VAR_SCALING_UNIFORM_FAN_IN; return new WeightInitVarScalingUniformFanIn(scale);
} }
break;
case "fan_out": case "fan_out":
if (distribution.equals("normal")) { if (distribution.equals("normal")) {
init = WeightInit.VAR_SCALING_NORMAL_FAN_OUT; return new WeightInitVarScalingNormalFanOut(scale);
} else { } else {
init = WeightInit.VAR_SCALING_UNIFORM_FAN_OUT; return new WeightInitVarScalingUniformFanOut(scale);
} }
break;
case "fan_avg": case "fan_avg":
if (distribution.equals("normal")) { if (distribution.equals("normal")) {
init = WeightInit.VAR_SCALING_NORMAL_FAN_AVG; return new WeightInitVarScalingNormalFanAvg(scale);
} else { } else {
init = WeightInit.VAR_SCALING_UNIFORM_FAN_AVG; return new WeightInitVarScalingUniformFanAvg(scale);
} }
break;
default: default:
throw new InvalidKerasConfigurationException("Initialization argument 'mode' has to be either " + throw new InvalidKerasConfigurationException("Initialization argument 'mode' has to be either " +
"fan_in, fan_out or fan_avg"); "fan_in, fan_out or fan_avg");
@ -190,7 +183,7 @@ public class KerasInitilizationUtils {
throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit); throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit);
} }
} }
return new Pair<>(init, dist); throw new IllegalStateException("Error getting Keras weight initialization");
} }
/** /**
@ -202,7 +195,7 @@ public class KerasInitilizationUtils {
* @throws InvalidKerasConfigurationException Invalid Keras config * @throws InvalidKerasConfigurationException Invalid Keras config
* @throws UnsupportedKerasConfigurationException Unsupported Keras config * @throws UnsupportedKerasConfigurationException Unsupported Keras config
*/ */
public static Pair<WeightInit, Distribution> getWeightInitFromConfig(Map<String, Object> layerConfig, String initField, public static IWeightInit getWeightInitFromConfig(Map<String, Object> layerConfig, String initField,
boolean enforceTrainingConfig, boolean enforceTrainingConfig,
KerasLayerConfiguration conf, KerasLayerConfiguration conf,
int kerasMajorVersion) int kerasMajorVersion)
@ -225,14 +218,14 @@ public class KerasInitilizationUtils {
throw new UnsupportedKerasConfigurationException("Incomplete initialization class"); throw new UnsupportedKerasConfigurationException("Incomplete initialization class");
} }
} }
Pair<WeightInit, Distribution> init; IWeightInit init;
try { try {
init = mapWeightInitialization(kerasInit, conf, initMap, kerasMajorVersion); init = mapWeightInitialization(kerasInit, conf, initMap, kerasMajorVersion);
} catch (UnsupportedKerasConfigurationException e) { } catch (UnsupportedKerasConfigurationException e) {
if (enforceTrainingConfig) if (enforceTrainingConfig)
throw e; throw e;
else { else {
init = new Pair<>(WeightInit.XAVIER, null); init = new WeightInitXavier();
log.warn("Unknown weight initializer " + kerasInit + " (Using XAVIER instead)."); log.warn("Unknown weight initializer " + kerasInit + " (Using XAVIER instead).");
} }
} }

View File

@ -21,7 +21,6 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive; import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;

View File

@ -16,7 +16,6 @@
package org.deeplearning4j.nn.modelimport.keras; package org.deeplearning4j.nn.modelimport.keras;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.nd4j.linalg.learning.regularization.L1Regularization; import org.nd4j.linalg.learning.regularization.L1Regularization;
@ -25,7 +24,6 @@ import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
public class KerasTestUtils { public class KerasTestUtils {

View File

@ -22,8 +22,6 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.util.Nd4jValidator;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import org.nd4j.validation.ValidationResult; import org.nd4j.validation.ValidationResult;

View File

@ -21,7 +21,6 @@ import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.nn.layers.recurrent.LSTM; import org.deeplearning4j.nn.layers.recurrent.LSTM;
import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer; import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
@ -30,7 +29,6 @@ import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Assert;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;

View File

@ -24,7 +24,6 @@ import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.InputStream; import java.io.InputStream;

View File

@ -30,11 +30,9 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.File; import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.Arrays; import java.util.Arrays;

View File

@ -25,6 +25,8 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDense; import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDense;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitIdentity;
import org.deeplearning4j.nn.weights.WeightInitVarScalingNormalFanIn;
import org.junit.Test; import org.junit.Test;
import java.util.HashMap; import java.util.HashMap;
@ -94,11 +96,11 @@ public class KerasInitilizationTest extends BaseDL4JTest {
WeightInit.RELU_UNIFORM.getWeightInitFunction(), WeightInit.RELU_UNIFORM.getWeightInitFunction(),
WeightInit.ONES.getWeightInitFunction(), WeightInit.ONES.getWeightInitFunction(),
WeightInit.ZERO.getWeightInitFunction(), WeightInit.ZERO.getWeightInitFunction(),
WeightInit.IDENTITY.getWeightInitFunction(), new WeightInitIdentity(0.2),
WeightInit.DISTRIBUTION.getWeightInitFunction(new NormalDistribution(mean, stdDev)), WeightInit.DISTRIBUTION.getWeightInitFunction(new NormalDistribution(mean, stdDev)),
WeightInit.DISTRIBUTION.getWeightInitFunction(new OrthogonalDistribution(gain)), WeightInit.DISTRIBUTION.getWeightInitFunction(new OrthogonalDistribution(gain)),
WeightInit.DISTRIBUTION.getWeightInitFunction(new ConstantDistribution(value)), WeightInit.DISTRIBUTION.getWeightInitFunction(new ConstantDistribution(value)),
WeightInit.VAR_SCALING_NORMAL_FAN_IN.getWeightInitFunction()}; new WeightInitVarScalingNormalFanIn(0.2)};
} }
private Distribution[] dl4jDistributions() { private Distribution[] dl4jDistributions() {

View File

@ -17,22 +17,16 @@
package org.deeplearning4j.nn.modelimport.keras.configurations; package org.deeplearning4j.nn.modelimport.keras.configurations;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.File;
import java.io.IOException; import java.io.IOException;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
/** /**

View File

@ -31,7 +31,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.File; import java.io.File;

View File

@ -24,22 +24,19 @@ import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.gradientcheck.GradientCheckUtil; import org.deeplearning4j.gradientcheck.GradientCheckUtil;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer; import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.recurrent.LSTM; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer; import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.*; import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder; import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
@ -47,27 +44,25 @@ import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.*; import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.File; import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.URL; import java.net.URL;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.StandardCopyOption; import java.nio.file.StandardCopyOption;
import java.util.*; import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/** /**
* Unit tests for end-to-end Keras model import. * Unit tests for end-to-end Keras model import.

View File

@ -21,7 +21,6 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth;
import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.transferlearning.TransferLearning;
@ -31,11 +30,8 @@ import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
/** /**
* Import previously stored YOLO9000 Keras net from https://github.com/allanzelener/YAD2K. * Import previously stored YOLO9000 Keras net from https://github.com/allanzelener/YAD2K.

View File

@ -26,7 +26,6 @@ import org.junit.Ignore;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.File; import java.io.File;

View File

@ -27,16 +27,11 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasAtrousC
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/** /**
* @author Max Pumperla * @author Max Pumperla

View File

@ -28,9 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolu
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
@ -39,7 +36,6 @@ import java.util.Map;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/** /**
* @author Max Pumperla * @author Max Pumperla

View File

@ -24,7 +24,6 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping1D; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping1D;
import org.junit.Test; import org.junit.Test;
import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;

View File

@ -16,13 +16,11 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolution; package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping3D; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping3D;
import org.junit.Test; import org.junit.Test;

View File

@ -30,15 +30,11 @@ import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Test; import org.junit.Test;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/** /**
* @author Max Pumperla * @author Max Pumperla

View File

@ -17,18 +17,14 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolution; package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
import org.deeplearning4j.nn.conf.layers.Upsampling1D; import org.deeplearning4j.nn.conf.layers.Upsampling1D;
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling1D; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D;
import org.junit.Test; import org.junit.Test;
import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;

View File

@ -17,13 +17,11 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolution; package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
import org.deeplearning4j.nn.conf.layers.Upsampling2D; import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.conf.layers.ZeroPadding1DLayer;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding1D;
import org.junit.Test; import org.junit.Test;
import java.util.ArrayList; import java.util.ArrayList;

View File

@ -17,12 +17,10 @@
package org.deeplearning4j.nn.modelimport.keras.layers.convolution; package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer; import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding3D; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding3D;
import org.junit.Test; import org.junit.Test;

View File

@ -26,16 +26,11 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/** /**
* @author Max Pumperla * @author Max Pumperla

View File

@ -24,10 +24,12 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.junit.Test; import org.junit.Test;
import java.util.*; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;

View File

@ -24,11 +24,11 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.util.*; import java.util.*;

View File

@ -26,11 +26,7 @@ import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList; import java.util.*;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;

View File

@ -20,7 +20,6 @@ import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LocallyConnected1D; import org.deeplearning4j.nn.conf.layers.LocallyConnected1D;
import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
@ -31,10 +30,8 @@ import org.junit.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
/** /**

View File

@ -27,15 +27,14 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.*; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/** /**
* @author Max Pumperla * @author Max Pumperla

View File

@ -19,7 +19,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.pooling;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.PoolingType; import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer; import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;

View File

@ -33,14 +33,13 @@ import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.*; import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/** /**
* @author Max Pumperla * @author Max Pumperla

View File

@ -16,15 +16,12 @@
package org.deeplearning4j.nn.modelimport.keras.optimizers; package org.deeplearning4j.nn.modelimport.keras.optimizers;
import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.e2e.KerasModelEndToEndTest;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder; import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.util.DL4JFileUtils; import org.deeplearning4j.util.DL4JFileUtils;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.File; import java.io.File;
@ -32,8 +29,6 @@ import java.io.InputStream;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.StandardCopyOption; import java.nio.file.StandardCopyOption;
import static java.io.File.createTempFile;
public class OptimizerImport extends BaseDL4JTest { public class OptimizerImport extends BaseDL4JTest {
@Test @Test

View File

@ -18,9 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.sequence;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.IOException; import java.io.IOException;

View File

@ -19,15 +19,11 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.text;
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.IOException; import java.io.IOException;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.*;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
/** /**
* Import Keras Tokenizer * Import Keras Tokenizer

View File

@ -20,7 +20,6 @@ import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;

View File

@ -29,7 +29,6 @@ import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.File; import java.io.File;

View File

@ -16,13 +16,11 @@
package org.deeplearning4j.nn.conf.layers.samediff; package org.deeplearning4j.nn.conf.layers.samediff;
import lombok.Data; import lombok.*;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -32,6 +30,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
/** /**
@ -58,10 +57,12 @@ import java.util.Map;
public abstract class SameDiffLayer extends AbstractSameDiffLayer { public abstract class SameDiffLayer extends AbstractSameDiffLayer {
protected WeightInit weightInit; protected WeightInit weightInit;
protected Map<String,IWeightInit> paramWeightInit;
protected SameDiffLayer(Builder builder) { protected SameDiffLayer(Builder builder) {
super(builder); super(builder);
this.weightInit = builder.weightInit; this.weightInit = builder.weightInit;
this.paramWeightInit = builder.paramWeightInit;
} }
protected SameDiffLayer() { protected SameDiffLayer() {
@ -115,6 +116,7 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer {
public static abstract class Builder<T extends Builder<T>> extends AbstractSameDiffLayer.Builder<T> { public static abstract class Builder<T extends Builder<T>> extends AbstractSameDiffLayer.Builder<T> {
protected WeightInit weightInit = WeightInit.XAVIER; protected WeightInit weightInit = WeightInit.XAVIER;
protected Map<String,IWeightInit> paramWeightInit;
/** /**
* @param weightInit Weight initialization to use for the layer * @param weightInit Weight initialization to use for the layer
@ -123,5 +125,12 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer {
this.setWeightInit(weightInit); this.setWeightInit(weightInit);
return (T) this; return (T) this;
} }
public T weightInit(@NonNull String param, @NonNull IWeightInit weightInit){
if(paramWeightInit == null)
paramWeightInit = new HashMap<>();
paramWeightInit.put(param, weightInit);
return (T) this;
}
} }
} }

View File

@ -16,11 +16,14 @@
package org.deeplearning4j.nn.weights; package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Arrays; import java.util.Arrays;
@ -32,9 +35,17 @@ import java.util.Arrays;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@EqualsAndHashCode @Data
@NoArgsConstructor
public class WeightInitIdentity implements IWeightInit { public class WeightInitIdentity implements IWeightInit {
private Double scale;
public WeightInitIdentity(@JsonProperty("scale") Double scale){
this.scale = scale;
}
@Override @Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
if (shape[0] != shape[1]) { if (shape[0] != shape[1]) {
@ -59,6 +70,11 @@ public class WeightInitIdentity implements IWeightInit {
} else { } else {
ret = Nd4j.createUninitialized(shape, order).assign(Nd4j.eye(shape[0])); ret = Nd4j.createUninitialized(shape, order).assign(Nd4j.eye(shape[0]));
} }
if(scale != null){
ret.muli(scale);
}
INDArray flat = Nd4j.toFlattened(order, ret); INDArray flat = Nd4j.toFlattened(order, ret);
paramView.assign(flat); paramView.assign(flat);
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
@ -82,13 +98,16 @@ public class WeightInitIdentity implements IWeightInit {
indArrayIndices[i] = NDArrayIndex.point(shape[i] / 2); indArrayIndices[i] = NDArrayIndex.point(shape[i] / 2);
} }
paramView.assign(Nd4j.zeros(paramView.shape())); paramView.assign(0);
final INDArray params =paramView.reshape(order, shape); final INDArray params =paramView.reshape(order, shape);
for (int i = 0; i < shape[0]; i++) { for (int i = 0; i < shape[0]; i++) {
indArrayIndices[0] = NDArrayIndex.point(i); indArrayIndices[0] = NDArrayIndex.point(i);
indArrayIndices[1] = NDArrayIndex.point(i); indArrayIndices[1] = NDArrayIndex.point(i);
params.put(indArrayIndices, Nd4j.ones(1)); params.put(indArrayIndices, Nd4j.ones(1));
} }
if(scale != null){
params.muli(scale);
}
return params; return params;
} }
} }

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.weights;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.api.rng.distribution.Distribution; import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution; import org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -146,14 +147,13 @@ public class WeightInitUtil {
paramView.assign(flat); paramView.assign(flat);
break; break;
case VAR_SCALING_NORMAL_FAN_IN: case VAR_SCALING_NORMAL_FAN_IN:
// TODO: needs to be truncated normal to match keras. Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(1.0 / fanIn)));
Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn));
break; break;
case VAR_SCALING_NORMAL_FAN_OUT: case VAR_SCALING_NORMAL_FAN_OUT:
Nd4j.randn(paramView).divi(FastMath.sqrt(fanOut)); Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(1.0 / fanOut)));
break; break;
case VAR_SCALING_NORMAL_FAN_AVG: case VAR_SCALING_NORMAL_FAN_AVG:
Nd4j.randn(paramView).divi(FastMath.sqrt((fanIn + fanOut) / 2)); Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(2.0 / (fanIn + fanOut))));
break; break;
case VAR_SCALING_UNIFORM_FAN_IN: case VAR_SCALING_UNIFORM_FAN_IN:
double scalingFanIn = 3.0 / Math.sqrt(fanIn); double scalingFanIn = 3.0 / Math.sqrt(fanIn);

View File

@ -16,22 +16,39 @@
package org.deeplearning4j.nn.weights; package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
/** /**
* Gaussian distribution with mean 0, variance 1.0/((fanIn + fanOut)/2) * Truncated aussian distribution with mean 0, variance 1.0/((fanIn + fanOut)/2)
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@EqualsAndHashCode @Data
@NoArgsConstructor
public class WeightInitVarScalingNormalFanAvg implements IWeightInit { public class WeightInitVarScalingNormalFanAvg implements IWeightInit {
private Double scale;
public WeightInitVarScalingNormalFanAvg(Double scale){
this.scale = scale;
}
@Override @Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
Nd4j.randn(paramView).divi(FastMath.sqrt((fanIn + fanOut) / 2)); double std;
if(scale == null){
std = Math.sqrt(2.0 / (fanIn + fanOut));
} else {
std = Math.sqrt(2.0 * scale / (fanIn + fanOut));
}
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
} }

View File

@ -16,23 +16,38 @@
package org.deeplearning4j.nn.weights; package org.deeplearning4j.nn.weights;
import lombok.EqualsAndHashCode; import lombok.Data;
import org.apache.commons.math3.util.FastMath; import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
/** /**
* Gaussian distribution with mean 0, variance 1.0/(fanIn) * Gaussian distribution with mean 0, variance {@code 1.0/(fanIn)}<br>
* If a scale is provided, use variance {@code scale/(fanIn)} instead
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@EqualsAndHashCode @Data
@NoArgsConstructor
public class WeightInitVarScalingNormalFanIn implements IWeightInit { public class WeightInitVarScalingNormalFanIn implements IWeightInit {
private Double scale;
public WeightInitVarScalingNormalFanIn(Double scale){
this.scale = scale;
}
@Override @Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
// TODO: needs to be truncated normal to match keras. double std;
Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn)); if(scale == null){
std = Math.sqrt(1.0 / fanIn);
} else {
std = Math.sqrt(scale / fanIn);
}
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
} }

View File

@ -16,22 +16,40 @@
package org.deeplearning4j.nn.weights; package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
/** /**
* Gaussian distribution with mean 0, variance 1.0/(fanOut) * Truncated normal distribution with mean 0, variance 1.0/(fanOut)<br>
* If a scale is provided, variance is scale / fanOut
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@EqualsAndHashCode @Data
@NoArgsConstructor
public class WeightInitVarScalingNormalFanOut implements IWeightInit { public class WeightInitVarScalingNormalFanOut implements IWeightInit {
private Double scale;
public WeightInitVarScalingNormalFanOut(Double scale){
this.scale = scale;
}
@Override @Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
Nd4j.randn(paramView).divi(FastMath.sqrt(fanOut)); double std;
if(scale == null){
std = Math.sqrt(1.0 / fanOut);
} else {
std = Math.sqrt(scale / fanOut);
}
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
} }

View File

@ -16,7 +16,9 @@
package org.deeplearning4j.nn.weights; package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -25,12 +27,22 @@ import org.nd4j.linalg.factory.Nd4j;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@EqualsAndHashCode @Data
@NoArgsConstructor
public class WeightInitVarScalingUniformFanAvg implements IWeightInit { public class WeightInitVarScalingUniformFanAvg implements IWeightInit {
private Double scale;
public WeightInitVarScalingUniformFanAvg(Double scale){
this.scale = scale;
}
@Override @Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2); double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2);
if(scale != null)
scalingFanAvg *= scale;
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg)); Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }

View File

@ -16,21 +16,34 @@
package org.deeplearning4j.nn.weights; package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
/** /**
* Uniform U[-a,a] with a=3.0/(fanIn) * Uniform U[-a,a] with a=3.0/(fanIn)<br>
* If a scale is provided, a = 3.0 * scale / (fanIn)
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@EqualsAndHashCode @NoArgsConstructor
@Data
public class WeightInitVarScalingUniformFanIn implements IWeightInit { public class WeightInitVarScalingUniformFanIn implements IWeightInit {
private Double scale;
public WeightInitVarScalingUniformFanIn(Double scale){
this.scale = scale;
}
@Override @Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
double scalingFanIn = 3.0 / Math.sqrt(fanIn); double scalingFanIn = 3.0 / Math.sqrt(fanIn);
if(scale != null)
scalingFanIn *= scale;
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn)); Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }

View File

@ -16,21 +16,33 @@
package org.deeplearning4j.nn.weights; package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
/** /**
* Uniform U[-a,a] with a=3.0/(fanOut) * Uniform U[-a,a] with a=3.0/(fanOut)<br>
* If a scale is provided, a = 3.0 * scale / fanOut
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@EqualsAndHashCode @Data
@NoArgsConstructor
public class WeightInitVarScalingUniformFanOut implements IWeightInit { public class WeightInitVarScalingUniformFanOut implements IWeightInit {
private Double scale;
public WeightInitVarScalingUniformFanOut(Double scale){
this.scale = scale;
}
@Override @Override
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
double scalingFanOut = 3.0 / Math.sqrt(fanOut); double scalingFanOut = 3.0 / Math.sqrt(fanOut);
if(scale != null)
scalingFanOut *= scale;
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut)); Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }

View File

@ -25,7 +25,7 @@ elseif (APPLE)
elseif(WIN32) elseif(WIN32)
set(X86_BUILD true) set(X86_BUILD true)
if (CUDA_BLAS) if (CUDA_BLAS)
set(CMAKE_CXX_FLAGS_RELEASE " /O2 -D_RELEASE=true /wd4804") set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true /wd4804")
set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305") set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305")
else() else()
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true") set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")

View File

@ -3607,6 +3607,13 @@ public class Shape {
return ArrayUtil.prodLong(shape); return ArrayUtil.prodLong(shape);
} }
public static long lengthOf(int[] shape) {
if (shape.length == 0)
return 1L;
else
return ArrayUtil.prodLong(shape);
}
/** /**
* Calculate the length of the buffer required to store the given shape with the given strides * Calculate the length of the buffer required to store the given shape with the given strides
* *

View File

@ -28,11 +28,6 @@ import java.util.List;
* @author Adam Gibson * @author Adam Gibson
*/ */
public class SamplingDataSetIterator implements DataSetIterator { public class SamplingDataSetIterator implements DataSetIterator {
/**
*
*/
private static final long serialVersionUID = -2700563801361726914L;
private DataSet sampleFrom; private DataSet sampleFrom;
private int batchSize; private int batchSize;
private int totalNumberSamples; private int totalNumberSamples;
@ -145,6 +140,4 @@ public class SamplingDataSetIterator implements DataSetIterator {
numTimesSampled++; numTimesSampled++;
return ret; return ret;
} }
} }

View File

@ -1164,26 +1164,15 @@ public class Nd4j {
* @param type the opType to create * @param type the opType to create
* @return the created buffer * @return the created buffer
*/ */
public static DataBuffer createBuffer(int[] shape, DataType type) { public static DataBuffer createBuffer(@NonNull int[] shape, @NonNull DataType type) {
long length = ArrayUtil.prodLong(shape); return createBuffer(ArrayUtil.toLongArray(shape), type);
if (type == DataType.INT)
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createInt(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createInt(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (type == DataType.LONG)
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createLong(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (type == DataType.HALF)
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createHalf(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createHalf(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (type == DataType.DOUBLE)
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createDouble(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createDouble(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
else
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createFloat(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createFloat(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
} }
/** /**
* See {@link #createBuffer(int[], DataType)} * See {@link #createBuffer(int[], DataType)}
*/ */
public static DataBuffer createBuffer(long[] shape, DataType type) { public static DataBuffer createBuffer(@NonNull long[] shape, @NonNull DataType type) {
long length = ArrayUtil.prodLong(shape); long length = Shape.lengthOf(shape);
switch (type) { switch (type) {
case BOOL: case BOOL:
@ -1229,14 +1218,14 @@ public class Nd4j {
* @return the created buffer. * @return the created buffer.
*/ */
public static DataBuffer createBufferDetached(int[] shape, DataType type) { public static DataBuffer createBufferDetached(int[] shape, DataType type) {
return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type); return createBufferDetachedImpl( Shape.lengthOf(shape), type);
} }
/** /**
* See {@link #createBufferDetached(int[], DataType)} * See {@link #createBufferDetached(int[], DataType)}
*/ */
public static DataBuffer createBufferDetached(long[] shape, DataType type) { public static DataBuffer createBufferDetached(long[] shape, DataType type) {
return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type); return createBufferDetachedImpl( Shape.lengthOf(shape), type);
} }
// used by createBufferDetached(long[] DataType) and createBufferDetached(int[] , DataType) // used by createBufferDetached(long[] DataType) and createBufferDetached(int[] , DataType)