From 09a827fb6dcea3f810bd857483e84b4e28501784 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 16 Nov 2019 17:04:29 +1100 Subject: [PATCH] Fixes and pre-release QA (#51) * #8395 Keras import - support scaled identity weight init Signed-off-by: AlexDBlack * More Keras scaled weight init fixes Signed-off-by: AlexDBlack * #8352 Deprecate duplicate SamplingDataSetIterator class Signed-off-by: AlexDBlack * Remove /O2 optimization for faster CUDA build Signed-off-by: AlexDBlack * Tweak regression test precision for CUDA Signed-off-by: AlexDBlack * Fix edge cases for buffer creation Signed-off-by: AlexDBlack * Update MKLDNN validation tests to new helper enable/disable settings Signed-off-by: AlexDBlack * Delete debugging class Signed-off-by: AlexDBlack * MKLDNN test - add proper skip for CUDA backend Signed-off-by: AlexDBlack * Align WeightInitUtil with weight init classes Signed-off-by: AlexDBlack * Fix for SameDiff test layers weight init when using IWeightInit classes Signed-off-by: AlexDBlack --- .../LayerHelperValidationUtil.java | 56 ++++++++- .../org/deeplearning4j/TestBatchNormBp.java | 107 ------------------ .../testlayers/MinimalSameDiffDense.java | 15 ++- .../samediff/testlayers/SameDiffConv.java | 14 ++- .../samediff/testlayers/SameDiffDense.java | 12 +- .../nn/mkldnn/ValidateMKLDNN.java | 2 + .../regressiontest/RegressionTest100b4.java | 8 +- .../iterator/SamplingDataSetIterator.java | 98 +--------------- .../nn/modelimport/keras/Hdf5Archive.java | 2 +- .../keras/KerasSequentialModel.java | 1 - .../advanced/activations/KerasPReLU.java | 10 +- .../KerasAtrousConvolution1D.java | 10 +- .../KerasAtrousConvolution2D.java | 9 +- .../convolutional/KerasConvolution.java | 2 - .../convolutional/KerasConvolution1D.java | 10 +- .../convolutional/KerasConvolution2D.java | 10 +- .../convolutional/KerasConvolution3D.java | 10 +- .../convolutional/KerasDeconvolution2D.java | 10 +- .../KerasDepthwiseConvolution2D.java | 10 +- .../KerasSeparableConvolution2D.java | 16 
+-- .../convolutional/KerasUpsampling3D.java | 1 - .../convolutional/KerasZeroPadding3D.java | 1 - .../keras/layers/core/KerasDense.java | 10 +- .../keras/layers/core/KerasFlatten.java | 1 - .../keras/layers/core/KerasRepeatVector.java | 1 - .../keras/layers/core/KerasReshape.java | 2 - .../layers/embeddings/KerasEmbedding.java | 10 +- .../layers/local/KerasLocallyConnected1D.java | 11 +- .../layers/local/KerasLocallyConnected2D.java | 16 +-- .../KerasBatchNormalization.java | 1 - .../keras/layers/recurrent/KerasLSTM.java | 15 +-- .../layers/recurrent/KerasSimpleRnn.java | 15 +-- .../sequence/TimeSeriesGenerator.java | 3 - .../KerasFlattenRnnPreprocessor.java | 3 +- .../preprocessors/ReshapePreprocessor.java | 4 +- ...ensorFlowCnnToFeedForwardPreProcessor.java | 2 +- .../keras/utils/DL4JKerasModelValidator.java | 13 --- .../keras/utils/KerasActivationUtils.java | 1 - .../keras/utils/KerasInitilizationUtils.java | 91 +++++++-------- .../keras/utils/KerasModelUtils.java | 1 - .../nn/modelimport/keras/KerasTestUtils.java | 2 - .../nn/modelimport/keras/MiscTests.java | 2 - .../configurations/FullModelComparisons.java | 2 - .../Keras1ModelConfigurationTest.java | 1 - .../Keras2ModelConfigurationTest.java | 2 - .../KerasInitilizationTest.java | 6 +- .../configurations/KerasModelImportTest.java | 6 - .../keras/e2e/KerasLambdaTest.java | 1 - .../keras/e2e/KerasModelEndToEndTest.java | 23 ++-- .../keras/e2e/KerasYolo9000PredictTest.java | 4 - .../keras/e2e/KerasYolo9000Test.java | 1 - .../KerasAtrousConvolution1DTest.java | 5 - .../convolution/KerasConvolution3DTest.java | 4 - .../convolution/KerasCropping1DTest.java | 1 - .../convolution/KerasCropping3DTest.java | 2 - .../KerasDepthwiseConvolution2DTest.java | 4 - .../convolution/KerasUpsampling1DTest.java | 4 - .../convolution/KerasUpsampling2DTest.java | 2 - .../convolution/KerasZeroPadding3DTest.java | 2 - .../keras/layers/core/KerasDenseTest.java | 5 - .../keras/layers/core/KerasPermuteTest.java | 6 +- 
.../keras/layers/core/KerasReshapeTest.java | 2 +- .../layers/embeddings/KerasEmbeddingTest.java | 6 +- .../local/KerasLocallyConnected1DTest.java | 3 - .../local/KerasLocallyConnected2DTest.java | 9 +- .../layers/pooling/KerasPooling3DTest.java | 1 - .../keras/layers/recurrent/KerasLSTMTest.java | 9 +- .../keras/optimizers/OptimizerImport.java | 5 - .../TimeSeriesGeneratorImportTest.java | 2 - .../text/TokenizerImportTest.java | 6 +- .../preprocessing/text/TokenizerTest.java | 1 - .../weights/KerasWeightSettingTests.java | 1 - .../conf/layers/samediff/SameDiffLayer.java | 17 ++- .../nn/weights/WeightInitIdentity.java | 23 +++- .../nn/weights/WeightInitUtil.java | 8 +- .../WeightInitVarScalingNormalFanAvg.java | 23 +++- .../WeightInitVarScalingNormalFanIn.java | 27 ++++- .../WeightInitVarScalingNormalFanOut.java | 24 +++- .../WeightInitVarScalingUniformFanAvg.java | 14 ++- .../WeightInitVarScalingUniformFanIn.java | 17 ++- .../WeightInitVarScalingUniformFanOut.java | 16 ++- libnd4j/CMakeLists.txt | 2 +- .../java/org/nd4j/linalg/api/shape/Shape.java | 7 ++ .../api/iterator/SamplingDataSetIterator.java | 7 -- .../java/org/nd4j/linalg/factory/Nd4j.java | 23 +--- 85 files changed, 378 insertions(+), 574 deletions(-) delete mode 100644 deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestBatchNormBp.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java index 59ef8c28e..e3923c4ff 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java @@ -35,6 +35,7 @@ import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.ops.transforms.Transforms; import java.lang.reflect.Field; +import java.lang.reflect.Method; import 
java.util.*; import static org.junit.Assert.*; @@ -63,6 +64,30 @@ public class LayerHelperValidationUtil { private DataSetIterator data; } + public static void disableCppHelpers(){ + try { + Class c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment"); + Method m = c.getMethod("getInstance"); + Object instance = m.invoke(null); + Method m2 = c.getMethod("allowHelpers", boolean.class); + m2.invoke(instance, false); + } catch (Throwable t){ + throw new RuntimeException(t); + } + } + + public static void enableCppHelpers(){ + try{ + Class c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment"); + Method m = c.getMethod("getInstance"); + Object instance = m.invoke(null); + Method m2 = c.getMethod("allowHelpers", boolean.class); + m2.invoke(instance, true); + } catch (Throwable t){ + throw new RuntimeException(t); + } + } + public static void validateMLN(MultiLayerNetwork netOrig, TestCase t){ assertNotNull(t.getAllowHelpersForClasses()); assertFalse(t.getAllowHelpersForClasses().isEmpty()); @@ -95,7 +120,13 @@ public class LayerHelperValidationUtil { for (boolean train : new boolean[]{false, true}) { assertEquals(net1NoHelper.params(), net2With.params()); String s = "Feed forward test - " + t.getTestName() + " - " + (train ? 
"Train: " : "Test: "); - List ff1 = net1NoHelper.feedForward(t.getFeatures(), train); + List ff1; + try { + disableCppHelpers(); + ff1 = net1NoHelper.feedForward(t.getFeatures(), train); + } finally { + enableCppHelpers(); + } List ff2 = net2With.feedForward(t.getFeatures(), train); List paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet()); Collections.sort(paramKeys); @@ -131,7 +162,13 @@ public class LayerHelperValidationUtil { log.info("Forward pass, max relative error: " + layerName + " - " + maxRE); } - INDArray out1 = net1NoHelper.output(t.getFeatures(), train); + INDArray out1; + try { + disableCppHelpers(); + out1 = net1NoHelper.output(t.getFeatures(), train); + } finally { + enableCppHelpers(); + } INDArray out2 = net2With.output(t.getFeatures(), train); INDArray relError = relError(out1, out2, t.getMinAbsError()); double maxRE = relError.maxNumber().doubleValue(); @@ -148,7 +185,13 @@ public class LayerHelperValidationUtil { Preconditions.checkNotNull(t.getLabels(), "Labels are not set (null)"); log.info("Validation - checking scores"); - double s1 = net1NoHelper.score(new DataSet(t.getFeatures(), t.getLabels())); + double s1; + try { + disableCppHelpers(); + s1 = net1NoHelper.score(new DataSet(t.getFeatures(), t.getLabels())); + } finally { + enableCppHelpers(); + } double s2 = net2With.score(new DataSet(t.getFeatures(), t.getLabels())); double re = relError(s1, s2); @@ -168,7 +211,12 @@ public class LayerHelperValidationUtil { net2With.setInput(t.getFeatures()); net2With.setLabels(t.getLabels()); - net1NoHelper.computeGradientAndScore(); + try { + disableCppHelpers(); + net1NoHelper.computeGradientAndScore(); + } finally { + enableCppHelpers(); + } net2With.computeGradientAndScore(); List paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet()); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestBatchNormBp.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestBatchNormBp.java 
deleted file mode 100644 index f34ce65f0..000000000 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestBatchNormBp.java +++ /dev/null @@ -1,107 +0,0 @@ -package org.deeplearning4j; - -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.WorkspaceMode; -import org.deeplearning4j.nn.conf.layers.BatchNormalization; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.Pair; - -import java.lang.reflect.Field; - -import static junit.framework.TestCase.*; - -public class TestBatchNormBp { - - @Test - public void test(){ - Nd4j.getRandom().setSeed(12345); -// INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 4, 4); - INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15); - INDArray mean = in.mean(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3); - INDArray var = in.var(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3); - INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape()); -// INDArray gamma = Nd4j.ones(DataType.FLOAT, 3); -// INDArray beta = Nd4j.zeros(DataType.FLOAT, 3); - INDArray gamma = Nd4j.rand(DataType.FLOAT, 3); - INDArray beta = Nd4j.rand(DataType.FLOAT, 3); - double e = 1e-5; - - INDArray dLdIn = in.ulike(); - INDArray dLdm = mean.ulike(); - INDArray dLdv = var.ulike(); - INDArray dLdg = gamma.ulike(); - INDArray dLdb = beta.ulike(); - - DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp") - .addInputs(in, mean, var, eps, gamma, beta) - .addIntegerArguments( - 1, //Apply scale - 1, //Apply beta - 1) //Axis (NCHW) - 
.addFloatingPointArguments(e) - .addOutputs(dLdIn, dLdm, dLdv, dLdg, dLdb) - .build(); - - Nd4j.exec(op); - System.out.println(dLdIn); - } - - @Test - public void compareImpls() throws Exception { - - Nd4j.getRandom().setSeed(12345); - INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15); - INDArray mean = in.mean(0, 2, 3).reshape(1,3); - INDArray var = in.var(0, 2, 3).reshape(1,3); - INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape()); - INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3); - INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3); - double e = 1e-3; - - INDArray dLdIn = in.ulike(); - INDArray dLdm = mean.ulike(); - INDArray dLdv = var.ulike(); - INDArray dLdg = gamma.ulike(); - INDArray dLdb = beta.ulike(); - - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .inferenceWorkspaceMode(WorkspaceMode.NONE) - .trainingWorkspaceMode(WorkspaceMode.NONE) - .list() - .layer(new BatchNormalization.Builder().nIn(3).nOut(3).build()) - .build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0); - assertNotNull(bn.getHelper()); - Field f = bn.getClass().getDeclaredField("helper"); - f.setAccessible(true); - f.set(bn, null); - assertNull(bn.getHelper()); - - - MKLDNNBatchNormHelper h = new MKLDNNBatchNormHelper(DataType.FLOAT); - - net.output(in, true); - bn.setInput(in, LayerWorkspaceMgr.noWorkspaces()); - Pair p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); - - h.preOutput(in, true, new long[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces()); - Pair pmkl = h.backpropGradient(in, eps, new long[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces()); - - INDArray dldin_dl4j = p.getSecond(); - - System.out.println("dl4j == mkldnn: " + p.getSecond().equals(pmkl.getSecond())); - } - -} diff --git 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java index 1b8e7ded9..9cbbccaa7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java @@ -70,8 +70,19 @@ public class MinimalSameDiffDense extends SameDiffLayer { @Override public void initializeParameters(Map params) { - params.get(DefaultParamInitializer.BIAS_KEY).assign(0); - initWeights(nIn, nOut, weightInit, params.get(DefaultParamInitializer.WEIGHT_KEY)); + String b = DefaultParamInitializer.BIAS_KEY; + if(paramWeightInit != null && paramWeightInit.containsKey(b)){ + paramWeightInit.get(b).init(nIn, nOut, params.get(b).shape(), 'c', params.get(b)); + } else { + params.get(DefaultParamInitializer.BIAS_KEY).assign(0); + } + + String w = DefaultParamInitializer.WEIGHT_KEY; + if(paramWeightInit != null && paramWeightInit.containsKey(w)){ + paramWeightInit.get(w).init(nIn, nOut, params.get(w).shape(), 'c', params.get(w)); + } else { + initWeights(nIn, nOut, weightInit, params.get(DefaultParamInitializer.WEIGHT_KEY)); + } } //OPTIONAL methods: diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java index 778b95dc7..1be09182c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java @@ -109,13 +109,17 @@ public class SameDiffConv extends 
SameDiffLayer { @Override public void initializeParameters(Map params) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + double fanIn = nIn * kernel[0] * kernel[1]; + double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]); for (Map.Entry e : params.entrySet()) { - if (ConvolutionParamInitializer.BIAS_KEY.equals(e.getKey())) { - e.getValue().assign(0); + if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){ + paramWeightInit.get(e.getKey()).init(fanIn, fanOut, e.getValue().shape(), 'c', e.getValue()); } else { - double fanIn = nIn * kernel[0] * kernel[1]; - double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]); - WeightInitUtil.initWeights(fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c', e.getValue()); + if (ConvolutionParamInitializer.BIAS_KEY.equals(e.getKey())) { + e.getValue().assign(0); + } else { + WeightInitUtil.initWeights(fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c', e.getValue()); + } } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java index 3da6e8f1c..630b6059c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java @@ -88,11 +88,15 @@ public class SameDiffDense extends SameDiffLayer { @Override public void initializeParameters(Map params){ for(Map.Entry e : params.entrySet()){ - if(DefaultParamInitializer.BIAS_KEY.equals(e.getKey())){ - e.getValue().assign(0.0); + if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){ + paramWeightInit.get(e.getKey()).init(nIn, nOut, e.getValue().shape(), 'c', e.getValue()); } else { - 
//Normally use 'c' order, but use 'f' for direct comparison to DL4J DenseLayer - WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, weightInit, null, 'f', e.getValue()); + if(DefaultParamInitializer.BIAS_KEY.equals(e.getKey())){ + e.getValue().assign(0.0); + } else { + //Normally use 'c' order, but use 'f' for direct comparison to DL4J DenseLayer + WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, weightInit, null, 'f', e.getValue()); + } } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java index f65e48f44..7013311ba 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java @@ -50,6 +50,7 @@ import static org.junit.Assume.assumeTrue; public class ValidateMKLDNN extends BaseDL4JTest { + @Test public void validateConvSubsampling() throws Exception { //Only run test if using nd4j-native backend @@ -268,6 +269,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { @Test public void compareBatchNormBackward() throws Exception { + assumeTrue(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native")); Nd4j.getRandom().setSeed(12345); INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java index d1112899f..a4883ea07 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java @@ -339,7 +339,13 @@ public class 
RegressionTest100b4 extends BaseDL4JTest { INDArray outAct = net.output(in); - assertEquals(outExp, outAct); + //19 layers - CPU vs. GPU difference accumulates notably, but appears to be correct + if(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native")){ + assertEquals(outExp, outAct); + } else { + boolean eq = outExp.equalsWithEps(outAct, 0.1); + assertTrue(eq); + } } @Test diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/SamplingDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/SamplingDataSetIterator.java index 62ee85407..32e4c61d3 100755 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/SamplingDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/SamplingDataSetIterator.java @@ -24,101 +24,11 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import java.util.List; /** - * A wrapper for a dataset to sample from. - * This will randomly sample from the given dataset. 
- * @author Adam GIbson + * @deprecated Use {@link org.nd4j.linalg.dataset.api.iterator.SamplingDataSetIterator} */ -public class SamplingDataSetIterator implements DataSetIterator { - - /** - * - */ - private static final long serialVersionUID = -2700563801361726914L; - private DataSet sampleFrom; - private int batchSize; - private int totalNumberSamples; - private int numTimesSampled; - @Getter - private DataSetPreProcessor preProcessor; - - /** - * - * @param sampleFrom the dataset to sample from - * @param batchSize the batch size to sample - * @param totalNumberSamples the sample size - */ +@Deprecated +public class SamplingDataSetIterator extends org.nd4j.linalg.dataset.api.iterator.SamplingDataSetIterator { public SamplingDataSetIterator(DataSet sampleFrom, int batchSize, int totalNumberSamples) { - super(); - this.sampleFrom = sampleFrom; - this.batchSize = batchSize; - this.totalNumberSamples = totalNumberSamples; + super(sampleFrom, batchSize, totalNumberSamples); } - - @Override - public boolean hasNext() { - return numTimesSampled < totalNumberSamples; - } - - @Override - public DataSet next() { - DataSet ret = sampleFrom.sample(batchSize); - numTimesSampled += batchSize; - return ret; - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - - @Override - public int inputColumns() { - return sampleFrom.numInputs(); - } - - @Override - public int totalOutcomes() { - return sampleFrom.numOutcomes(); - } - - @Override - public boolean resetSupported() { - return true; - } - - @Override - public boolean asyncSupported() { - return true; - } - - @Override - public void reset() { - numTimesSampled = 0; - } - - @Override - public int batch() { - return batchSize; - } - - @Override - public void setPreProcessor(DataSetPreProcessor preProcessor) { - this.preProcessor = preProcessor; - } - - @Override - public List getLabels() { - return null; - } - - - @Override - public DataSet next(int num) { - DataSet ret = 
sampleFrom.sample(num); - numTimesSampled++; - return ret; - } - - - } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/Hdf5Archive.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/Hdf5Archive.java index 83d138d5c..a5ea8efca 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/Hdf5Archive.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/Hdf5Archive.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras; import lombok.extern.slf4j.Slf4j; +import org.bytedeco.hdf5.*; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.FloatPointer; import org.bytedeco.javacpp.Loader; @@ -32,7 +33,6 @@ import java.lang.Exception; import java.util.ArrayList; import java.util.List; -import org.bytedeco.hdf5.*; import static org.bytedeco.hdf5.global.hdf5.*; /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.java index 529cf729c..d163c0776 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.java @@ -17,7 +17,6 @@ package org.deeplearning4j.nn.modelimport.keras; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; diff --git 
a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasPReLU.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasPReLU.java index 15de6fc53..8877d8b5a 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasPReLU.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasPReLU.java @@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.PReLULayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -27,9 +26,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.params.PReLUParamInitializer; -import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import java.util.HashMap; @@ -79,14 +77,12 @@ public class KerasPReLU extends KerasLayer { LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, ALPHA_CONSTRAINT, conf, kerasMajorVersion); - Pair init = getWeightInitFromConfig(layerConfig, ALPHA_INIT, + IWeightInit init = getWeightInitFromConfig(layerConfig, ALPHA_INIT, enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = 
init.getFirst(); - Distribution distribution = init.getSecond(); long[] axes = getSharedAxes(layerConfig); PReLULayer.Builder builder = new PReLULayer.Builder().sharedAxes(axes) - .weightInit(weightInit.getWeightInitFunction(distribution)).name(layerName); + .weightInit(init).name(layerName); if (weightConstraint != null){ builder.constrainWeights(weightConstraint); } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution1D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution1D.java index b7fa269f7..d7a4ab699 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution1D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution1D.java @@ -17,14 +17,12 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolutional; import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; -import org.deeplearning4j.nn.weights.WeightInit; -import org.nd4j.linalg.primitives.Pair; +import org.deeplearning4j.nn.weights.IWeightInit; import java.util.Map; @@ -83,15 +81,13 @@ public class KerasAtrousConvolution1D extends KerasConvolution { LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, 
kerasMajorVersion); - Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = init.getFirst(); - Distribution distribution = init.getSecond(); Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getIActivationFromConfig(layerConfig, conf)) - .weightInit(weightInit.getWeightInitFunction(distribution)) + .weightInit(init) .dilation(getDilationRate(layerConfig, 1, conf, true)[0]) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution2D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution2D.java index aa602bb3c..dd374992a 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution2D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution2D.java @@ -17,14 +17,12 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolutional; import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import 
org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; -import org.deeplearning4j.nn.weights.WeightInit; -import org.nd4j.linalg.primitives.Pair; +import org.deeplearning4j.nn.weights.IWeightInit; import java.util.Map; @@ -84,14 +82,13 @@ public class KerasAtrousConvolution2D extends KerasConvolution { LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); - Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = init.getFirst(); ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getIActivationFromConfig(layerConfig, conf)) - .weightInit(weightInit.getWeightInitFunction()) + .weightInit(init) .dilation(getDilationRate(layerConfig, 2, conf, true)) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.java index c4e66f6ef..f1d2f0210 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.java @@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.ArrayUtils; import 
org.deeplearning4j.nn.modelimport.keras.KerasLayer; -import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.params.ConvolutionParamInitializer; @@ -30,7 +29,6 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.HashMap; import java.util.Map; -import java.util.Set; import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.removeDefaultWeights; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java index 33512eb33..3da88d3b1 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java @@ -22,7 +22,6 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.ArrayUtils; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; @@ -30,10 +29,9 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import 
org.deeplearning4j.nn.params.ConvolutionParamInitializer; -import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.Pair; import java.util.HashMap; import java.util.Map; @@ -94,15 +92,13 @@ public class KerasConvolution1D extends KerasConvolution { LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); - Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = init.getFirst(); - Distribution distribution = init.getSecond(); Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getIActivationFromConfig(layerConfig, conf)) - .weightInit(weightInit.getWeightInitFunction(distribution)) + .weightInit(init) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0]) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java index 3c1d9f7d2..e9c74e78c 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java +++ 
b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java @@ -21,14 +21,12 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; -import org.deeplearning4j.nn.weights.WeightInit; -import org.nd4j.linalg.primitives.Pair; +import org.deeplearning4j.nn.weights.IWeightInit; import java.util.Map; @@ -87,10 +85,8 @@ public class KerasConvolution2D extends KerasConvolution { numTrainableParams = hasBias ? 
2 : 1; int[] dilationRate = getDilationRate(layerConfig, 2, conf, false); - Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = init.getFirst(); - Distribution distribution = init.getSecond(); LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion); @@ -100,7 +96,7 @@ public class KerasConvolution2D extends KerasConvolution { ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getIActivationFromConfig(layerConfig, conf)) - .weightInit(weightInit.getWeightInitFunction(distribution)) + .weightInit(init) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion)) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution3D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution3D.java index 8da12a726..ccd776306 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution3D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution3D.java @@ -21,15 +21,13 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.distribution.Distribution; 
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; -import org.deeplearning4j.nn.weights.WeightInit; -import org.nd4j.linalg.primitives.Pair; +import org.deeplearning4j.nn.weights.IWeightInit; import java.util.Map; @@ -88,10 +86,8 @@ public class KerasConvolution3D extends KerasConvolution { numTrainableParams = hasBias ? 2 : 1; int[] dilationRate = getDilationRate(layerConfig, 3, conf, false); - Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = init.getFirst(); - Distribution distribution = init.getSecond(); LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion); @@ -101,7 +97,7 @@ public class KerasConvolution3D extends KerasConvolution { Convolution3D.Builder builder = new Convolution3D.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getIActivationFromConfig(layerConfig, conf)) - .weightInit(weightInit.getWeightInitFunction(distribution)) + .weightInit(init) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 3, conf, kerasMajorVersion)) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDeconvolution2D.java 
b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDeconvolution2D.java index 33e02ae6f..92d9f3af8 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDeconvolution2D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDeconvolution2D.java @@ -20,14 +20,12 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Deconvolution2D; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; -import org.deeplearning4j.nn.weights.WeightInit; -import org.nd4j.linalg.primitives.Pair; +import org.deeplearning4j.nn.weights.IWeightInit; import java.util.Map; @@ -86,10 +84,8 @@ public class KerasDeconvolution2D extends KerasConvolution { numTrainableParams = hasBias ? 
2 : 1; int[] dilationRate = getDilationRate(layerConfig, 2, conf, false); - Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = init.getFirst(); - Distribution distribution = init.getSecond(); LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion); @@ -99,7 +95,7 @@ public class KerasDeconvolution2D extends KerasConvolution { Deconvolution2D.Builder builder = new Deconvolution2D.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getIActivationFromConfig(layerConfig, conf)) - .weightInit(weightInit.getWeightInitFunction(distribution)) + .weightInit(init) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion)) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDepthwiseConvolution2D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDepthwiseConvolution2D.java index f27d3ff08..c72de75a6 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDepthwiseConvolution2D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDepthwiseConvolution2D.java @@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.distribution.Distribution; 
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -30,9 +29,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils; import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer; -import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.primitives.Pair; import java.util.Collections; import java.util.HashMap; @@ -126,10 +124,8 @@ public class KerasDepthwiseConvolution2D extends KerasConvolution { numTrainableParams = hasBias ? 2 : 1; int[] dilationRate = getDilationRate(layerConfig, 2, conf, false); - Pair depthWiseInit = getWeightInitFromConfig(layerConfig, + IWeightInit depthWiseInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit depthWeightInit = depthWiseInit.getFirst(); - Distribution depthDistribution = depthWiseInit.getSecond(); val nIn = getNInFromConfig(previousLayers); @@ -152,7 +148,7 @@ public class KerasDepthwiseConvolution2D extends KerasConvolution { .nIn(nIn) .nOut(nIn * depthMultiplier) .activation(getIActivationFromConfig(layerConfig, conf)) - .weightInit(depthWeightInit.getWeightInitFunction(depthDistribution)) + .weightInit(depthWiseInit) .depthMultiplier(depthMultiplier) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasSeparableConvolution2D.java 
b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasSeparableConvolution2D.java index 67eba9bf1..cd052bbb7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasSeparableConvolution2D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasSeparableConvolution2D.java @@ -20,7 +20,6 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; @@ -28,9 +27,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils; import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer; -import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.primitives.Pair; import java.util.HashMap; import java.util.Map; @@ -93,17 +91,13 @@ public class KerasSeparableConvolution2D extends KerasConvolution { int depthMultiplier = getDepthMultiplier(layerConfig, conf); - Pair depthWiseInit = getWeightInitFromConfig(layerConfig, + IWeightInit depthWiseInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit depthWeightInit = depthWiseInit.getFirst(); - Distribution depthDistribution = depthWiseInit.getSecond(); - Pair pointWiseInit = 
getWeightInitFromConfig(layerConfig, + IWeightInit pointWiseInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_POINT_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit pointWeightInit = pointWiseInit.getFirst(); - Distribution pointDistribution = pointWiseInit.getSecond(); - if (depthWeightInit != pointWeightInit || depthDistribution != pointDistribution) + if ( !depthWiseInit.getClass().equals(pointWiseInit.getClass()) ) if (enforceTrainingConfig) throw new UnsupportedKerasConfigurationException( "Specifying different initialization for depth- and point-wise weights not supported."); @@ -126,7 +120,7 @@ public class KerasSeparableConvolution2D extends KerasConvolution { SeparableConvolution2D.Builder builder = new SeparableConvolution2D.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getIActivationFromConfig(layerConfig, conf)) - .weightInit(depthWeightInit.getWeightInitFunction(depthDistribution)) + .weightInit(depthWiseInit) .depthMultiplier(depthMultiplier) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasUpsampling3D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasUpsampling3D.java index a9c1054f1..98aabb3ee 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasUpsampling3D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasUpsampling3D.java @@ -17,7 +17,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolutional; import org.deeplearning4j.nn.conf.inputs.InputType; -import 
org.deeplearning4j.nn.conf.layers.Upsampling2D; import org.deeplearning4j.nn.conf.layers.Upsampling3D; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasZeroPadding3D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasZeroPadding3D.java index 387b826f5..7c840d301 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasZeroPadding3D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasZeroPadding3D.java @@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer; -import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDense.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDense.java index d840370d8..296b5dabf 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDense.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDense.java @@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode; 
import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.params.DefaultParamInitializer; -import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.primitives.Pair; import java.util.HashMap; import java.util.Map; @@ -95,15 +93,13 @@ public class KerasDense extends KerasLayer { LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); - Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = init.getFirst(); - Distribution distribution = init.getSecond(); DenseLayer.Builder builder = new DenseLayer.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)) .dropOut(this.dropout).activation(getIActivationFromConfig(layerConfig, conf)) - .weightInit(weightInit.getWeightInitFunction(distribution)) + .weightInit(init) .biasInit(0.0) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .hasBias(hasBias); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java 
b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java index d2aeb75c3..e0a6628a2 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java @@ -22,7 +22,6 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; -import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVector.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVector.java index 45f9ddadd..41254e221 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVector.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVector.java @@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.DropoutLayer; import org.deeplearning4j.nn.conf.layers.misc.RepeatVector; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import 
org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java index 1275cf5a9..6a5e1ff2a 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java @@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import lombok.val; -import org.apache.commons.lang3.ArrayUtils; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -26,7 +25,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; -import org.nd4j.linalg.util.ArrayUtil; import java.util.List; import java.util.Map; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java index 2a34f707c..1ee13c0b0 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java @@ -21,7 +21,6 @@ import 
lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -30,11 +29,10 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.params.DefaultParamInitializer; -import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.Pair; import java.util.HashMap; import java.util.Map; @@ -106,10 +104,8 @@ public class KerasEmbedding extends KerasLayer { "in DL4J, apply masking as a pre-processing step to your input." 
+ "See http://deeplearning4j.org/docs/latest/deeplearning4j-nn-recurrent#masking for more on this."); - Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(), + IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = init.getFirst(); - Distribution distribution = init.getSecond(); LayerConstraint embeddingConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_EMBEDDINGS_CONSTRAINT(), conf, kerasMajorVersion); @@ -121,7 +117,7 @@ public class KerasEmbedding extends KerasLayer { .inferInputLength(inferInputLength) .nOut(getNOutFromConfig(layerConfig, conf)) .dropOut(this.dropout).activation(Activation.IDENTITY) - .weightInit(weightInit.getWeightInitFunction(distribution)) + .weightInit(init) .biasInit(0.0) .l1(this.weightL1Regularization) .l2(this.weightL2Regularization) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1D.java index f08e462ca..d6fed55fe 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1D.java @@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.LocallyConnected1D; import 
org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; @@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.params.ConvolutionParamInitializer; -import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.primitives.Pair; import java.util.HashMap; import java.util.Map; @@ -90,11 +88,8 @@ public class KerasLocallyConnected1D extends KerasConvolution { numTrainableParams = hasBias ? 2 : 1; int[] dilationRate = getDilationRate(layerConfig, 1, conf, false); - Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = init.getFirst(); - // TODO: take care of distribution and bias init - //Distribution distribution = init.getSecond(); LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion); @@ -104,7 +99,7 @@ public class KerasLocallyConnected1D extends KerasConvolution { LocallyConnected1D.Builder builder = new LocallyConnected1D.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getActivationFromConfig(layerConfig, conf)) - .weightInit(weightInit) + .weightInit(conf.getKERAS_PARAM_NAME_W(), init) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0]) diff --git 
a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2D.java index 5c2ab641b..550c20d01 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2D.java @@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.LocallyConnected2D; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; @@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.params.ConvolutionParamInitializer; -import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.primitives.Pair; import java.util.HashMap; import java.util.Map; @@ -39,9 +37,7 @@ import java.util.Map; import static org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolutionUtils.*; import static org.deeplearning4j.nn.modelimport.keras.utils.KerasActivationUtils.getActivationFromConfig; import static org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils.getWeightInitFromConfig; -import static 
org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getHasBiasFromConfig; -import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getNOutFromConfig; -import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.removeDefaultWeights; +import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.*; /** @@ -92,11 +88,9 @@ public class KerasLocallyConnected2D extends KerasConvolution { numTrainableParams = hasBias ? 2 : 1; int[] dilationRate = getDilationRate(layerConfig, 2, conf, false); - Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = init.getFirst(); - // TODO: take care of distribution and bias init - //Distribution distribution = init.getSecond(); + // TODO: take care of bias init LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig( layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion); @@ -106,7 +100,7 @@ public class KerasLocallyConnected2D extends KerasConvolution { LocallyConnected2D.Builder builder = new LocallyConnected2D.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getActivationFromConfig(layerConfig, conf)) - .weightInit(weightInit) + .weightInit(conf.getKERAS_PARAM_NAME_W(), init) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion)) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalization.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalization.java index 
ff8d4d91f..7f7d8dc4c 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalization.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalization.java @@ -31,7 +31,6 @@ import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.Set; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java index f04752936..7d5603261 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java @@ -22,7 +22,6 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.LSTM; @@ -35,7 +34,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.params.LSTMParamInitializer; -import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.IWeightInit; import 
org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -151,15 +150,11 @@ public class KerasLSTM extends KerasLayer { throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { super(layerConfig, enforceTrainingConfig); - Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = init.getFirst(); - Distribution distribution = init.getSecond(); - Pair recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(), + IWeightInit recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit recurrentWeightInit = recurrentInit.getFirst(); - Distribution recurrentDistribution = recurrentInit.getSecond(); boolean hasBias = getHasBiasFromConfig(layerConfig, conf); @@ -186,8 +181,8 @@ public class KerasLSTM extends KerasLayer { .nOut(getNOutFromConfig(layerConfig, conf)) .dropOut(this.dropout) .activation(getIActivationFromConfig(layerConfig, conf)) - .weightInit(weightInit.getWeightInitFunction(distribution)) - .weightInitRecurrent(recurrentWeightInit.getWeightInitFunction(recurrentDistribution)) + .weightInit(init) + .weightInitRecurrent(recurrentInit) .biasInit(0.0) // TODO: this is incorrect .l1(this.weightL1Regularization) .l2(this.weightL2Regularization); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java index 615405fae..6f5edf597 100644 --- 
a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java @@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.Layer; @@ -34,7 +33,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.params.SimpleRnnParamInitializer; -import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -124,15 +123,11 @@ public class KerasSimpleRnn extends KerasLayer { throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { super(layerConfig, enforceTrainingConfig); - Pair init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), + IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit weightInit = init.getFirst(); - Distribution distribution = init.getSecond(); - Pair recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(), + IWeightInit recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - WeightInit recurrentWeightInit = recurrentInit.getFirst(); - 
Distribution recurrentDistribution = recurrentInit.getSecond(); Map innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); this.returnSequences = (Boolean) innerConfig.get(conf.getLAYER_FIELD_RETURN_SEQUENCES()); @@ -154,8 +149,8 @@ public class KerasSimpleRnn extends KerasLayer { .nOut(getNOutFromConfig(layerConfig, conf)) .dropOut(this.dropout) .activation(getIActivationFromConfig(layerConfig, conf)) - .weightInit(weightInit.getWeightInitFunction(distribution)) - .weightInitRecurrent(recurrentWeightInit.getWeightInitFunction(recurrentDistribution)) + .weightInit(init) + .weightInitRecurrent(recurrentInit) .biasInit(0.0) .l1(this.weightL1Regularization) .l2(this.weightL2Regularization); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGenerator.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGenerator.java index 94498b976..2a81886e0 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGenerator.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGenerator.java @@ -20,9 +20,7 @@ import com.google.gson.Gson; import com.google.gson.reflect.TypeToken; import lombok.Data; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; -import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -31,7 +29,6 @@ import org.nd4j.linalg.primitives.Pair; import java.io.IOException; import java.nio.file.Files; import 
java.nio.file.Paths; -import java.util.HashMap; import java.util.List; import java.util.Map; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/KerasFlattenRnnPreprocessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/KerasFlattenRnnPreprocessor.java index 3e18ebe3e..25aa73a06 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/KerasFlattenRnnPreprocessor.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/KerasFlattenRnnPreprocessor.java @@ -22,9 +22,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; import org.deeplearning4j.nn.workspace.ArrayType; -import org.nd4j.linalg.api.ndarray.INDArray; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java index f94adf713..77c6369c5 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java @@ -19,17 +19,15 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessors; import lombok.Data; import lombok.EqualsAndHashCode; import 
lombok.extern.slf4j.Slf4j; - import lombok.val; import org.apache.commons.lang3.ArrayUtils; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; import org.deeplearning4j.nn.workspace.ArrayType; -import org.nd4j.linalg.api.ndarray.INDArray; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/TensorFlowCnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/TensorFlowCnnToFeedForwardPreProcessor.java index f80863a03..db7d2e990 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/TensorFlowCnnToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/TensorFlowCnnToFeedForwardPreProcessor.java @@ -20,9 +20,9 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; -import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.shade.jackson.annotation.JsonCreator; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git 
a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/DL4JKerasModelValidator.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/DL4JKerasModelValidator.java index cd4461082..2ace14aa3 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/DL4JKerasModelValidator.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/DL4JKerasModelValidator.java @@ -1,28 +1,15 @@ package org.deeplearning4j.nn.modelimport.keras.utils; import lombok.NonNull; -import org.apache.commons.io.IOUtils; -import org.deeplearning4j.nn.api.Model; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive; -import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.util.ModelSerializer; import org.nd4j.validation.Nd4jCommonValidator; import org.nd4j.validation.ValidationResult; -import java.io.BufferedReader; import java.io.File; -import java.io.IOException; -import java.io.InputStreamReader; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; import java.util.Collections; -import java.util.List; -import java.util.zip.ZipEntry; -import java.util.zip.ZipFile; /** * A utility for validating serialized Keras sequential and functional models for import into DL4J diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.java index bb2bb1ca0..f0ddfd912 100644 --- 
a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.java @@ -21,7 +21,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.activations.impl.*; import java.util.Map; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.java index b86b83be1..b4b5e6564 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.java @@ -21,8 +21,7 @@ import org.deeplearning4j.nn.conf.distribution.*; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; -import org.deeplearning4j.nn.weights.WeightInit; -import org.nd4j.linalg.primitives.Pair; +import org.deeplearning4j.nn.weights.*; import java.util.HashMap; import java.util.Map; @@ -42,76 +41,71 @@ public class KerasInitilizationUtils { * @return DL4J weight initialization enum * @see WeightInit */ - public static Pair mapWeightInitialization(String kerasInit, - KerasLayerConfiguration conf, - Map initConfig, - int kerasMajorVersion) + public 
static IWeightInit mapWeightInitialization(String kerasInit, + KerasLayerConfiguration conf, + Map initConfig, + int kerasMajorVersion) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException { // TODO: Identity and VarianceScaling need "scale" factor - WeightInit init = null; - Distribution dist = null; if (kerasInit != null) { if (kerasInit.equals(conf.getINIT_GLOROT_NORMAL()) || kerasInit.equals(conf.getINIT_GLOROT_NORMAL_ALIAS())) { - init = WeightInit.XAVIER; + return WeightInit.XAVIER.getWeightInitFunction(); } else if (kerasInit.equals(conf.getINIT_GLOROT_UNIFORM()) || kerasInit.equals(conf.getINIT_GLOROT_UNIFORM_ALIAS())) { - init = WeightInit.XAVIER_UNIFORM; + return WeightInit.XAVIER_UNIFORM.getWeightInitFunction(); } else if (kerasInit.equals(conf.getINIT_LECUN_NORMAL()) || kerasInit.equals(conf.getINIT_LECUN_NORMAL_ALIAS())) { - init = WeightInit.LECUN_NORMAL; + return WeightInit.LECUN_NORMAL.getWeightInitFunction(); } else if (kerasInit.equals(conf.getINIT_LECUN_UNIFORM()) || kerasInit.equals(conf.getINIT_LECUN_UNIFORM_ALIAS())) { - init = WeightInit.LECUN_UNIFORM; + return WeightInit.LECUN_UNIFORM.getWeightInitFunction(); } else if (kerasInit.equals(conf.getINIT_HE_NORMAL()) || kerasInit.equals(conf.getINIT_HE_NORMAL_ALIAS())) { - init = WeightInit.RELU; + return WeightInit.RELU.getWeightInitFunction(); } else if (kerasInit.equals(conf.getINIT_HE_UNIFORM()) || kerasInit.equals(conf.getINIT_HE_UNIFORM_ALIAS())) { - init = WeightInit.RELU_UNIFORM; + return WeightInit.RELU_UNIFORM.getWeightInitFunction(); } else if (kerasInit.equals(conf.getINIT_ONE()) || kerasInit.equals(conf.getINIT_ONES()) || kerasInit.equals(conf.getINIT_ONES_ALIAS())) { - init = WeightInit.ONES; + return WeightInit.ONES.getWeightInitFunction(); } else if (kerasInit.equals(conf.getINIT_ZERO()) || kerasInit.equals(conf.getINIT_ZEROS()) || kerasInit.equals(conf.getINIT_ZEROS_ALIAS())) { - init = WeightInit.ZERO; + return 
WeightInit.ZERO.getWeightInitFunction(); } else if (kerasInit.equals(conf.getINIT_UNIFORM()) || kerasInit.equals(conf.getINIT_RANDOM_UNIFORM()) || kerasInit.equals(conf.getINIT_RANDOM_UNIFORM_ALIAS())) { if (kerasMajorVersion == 2) { double minVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MINVAL()); double maxVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MAXVAL()); - dist = new UniformDistribution(minVal, maxVal); + return new WeightInitDistribution(new UniformDistribution(minVal, maxVal)); } else { double scale = 0.05; if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE())) scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE()); - dist = new UniformDistribution(-scale, scale); + return new WeightInitDistribution(new UniformDistribution(-scale, scale)); } - init = WeightInit.DISTRIBUTION; } else if (kerasInit.equals(conf.getINIT_NORMAL()) || kerasInit.equals(conf.getINIT_RANDOM_NORMAL()) || kerasInit.equals(conf.getINIT_RANDOM_NORMAL_ALIAS())) { if (kerasMajorVersion == 2) { double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN()); double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV()); - dist = new NormalDistribution(mean, stdDev); + return new WeightInitDistribution(new NormalDistribution(mean, stdDev)); } else { double scale = 0.05; if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE())) scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE()); - dist = new NormalDistribution(0, scale); + return new WeightInitDistribution(new NormalDistribution(0, scale)); } - init = WeightInit.DISTRIBUTION; } else if (kerasInit.equals(conf.getINIT_CONSTANT()) || kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) { double value = (double) initConfig.get(conf.getLAYER_FIELD_INIT_VALUE()); - dist = new ConstantDistribution(value); - init = WeightInit.DISTRIBUTION; + return new WeightInitDistribution(new ConstantDistribution(value)); } else if (kerasInit.equals(conf.getINIT_ORTHOGONAL()) || 
kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) { if (kerasMajorVersion == 2) { @@ -121,34 +115,38 @@ public class KerasInitilizationUtils { } catch (Exception e) { gain = (int) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN()); } - dist = new OrthogonalDistribution(gain); + return new WeightInitDistribution(new OrthogonalDistribution(gain)); } else { double scale = 1.1; if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE())) scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE()); - dist = new OrthogonalDistribution(scale); + return new WeightInitDistribution(new OrthogonalDistribution(scale)); } - init = WeightInit.DISTRIBUTION; } else if (kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL()) || kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL_ALIAS())) { double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN()); double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV()); - dist = new TruncatedNormalDistribution(mean, stdDev); - init = WeightInit.DISTRIBUTION; + return new WeightInitDistribution(new TruncatedNormalDistribution(mean, stdDev)); } else if (kerasInit.equals(conf.getINIT_IDENTITY()) || kerasInit.equals(conf.getINIT_IDENTITY_ALIAS())) { if (kerasMajorVersion == 2) { double gain = (double) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN()); - if (gain != 1.) - log.warn("Scaled identity weight init not supported, setting gain=1"); + // gain == 1.0 is the plain (unscaled) identity; every path below must return + if (gain != 1.0) { + return new WeightInitIdentity(gain); + } else { + return new WeightInitIdentity(); + } } else { double scale = 1.; if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE())) scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE()); - if (scale != 1.) 
- log.warn("Scaled identity weight init not supported, setting scale=1"); + if (scale != 1.0) { + return new WeightInitIdentity(scale); + } else { + return new WeightInitIdentity(); + } } - init = WeightInit.IDENTITY; } else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) { double scale; try { @@ -156,32 +154,27 @@ public class KerasInitilizationUtils { } catch (Exception e) { scale = (int) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE()); } - if (scale != 1.) - log.warn("Scaled identity weight init not supported, setting scale=1"); String mode = (String) initConfig.get(conf.getLAYER_FIELD_INIT_MODE()); String distribution = (String) initConfig.get(conf.getLAYER_FIELD_INIT_DISTRIBUTION()); switch (mode) { case "fan_in": if (distribution.equals("normal")) { - init = WeightInit.VAR_SCALING_NORMAL_FAN_IN; + return new WeightInitVarScalingNormalFanIn(scale); } else { - init = WeightInit.VAR_SCALING_UNIFORM_FAN_IN; + return new WeightInitVarScalingUniformFanIn(scale); } - break; case "fan_out": if (distribution.equals("normal")) { - init = WeightInit.VAR_SCALING_NORMAL_FAN_OUT; + return new WeightInitVarScalingNormalFanOut(scale); } else { - init = WeightInit.VAR_SCALING_UNIFORM_FAN_OUT; + return new WeightInitVarScalingUniformFanOut(scale); } - break; case "fan_avg": if (distribution.equals("normal")) { - init = WeightInit.VAR_SCALING_NORMAL_FAN_AVG; + return new WeightInitVarScalingNormalFanAvg(scale); } else { - init = WeightInit.VAR_SCALING_UNIFORM_FAN_AVG; + return new WeightInitVarScalingUniformFanAvg(scale); } - break; default: throw new InvalidKerasConfigurationException("Initialization argument 'mode' has to be either " + "fan_in, fan_out or fan_avg"); @@ -190,7 +183,7 @@ public class KerasInitilizationUtils { throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit); } } - return new Pair<>(init, dist); + throw new IllegalStateException("Error getting Keras weight initialization"); } /** @@ -202,7 +195,7 @@ public 
class KerasInitilizationUtils { * @throws InvalidKerasConfigurationException Invalid Keras config * @throws UnsupportedKerasConfigurationException Unsupported Keras config */ - public static Pair getWeightInitFromConfig(Map layerConfig, String initField, + public static IWeightInit getWeightInitFromConfig(Map layerConfig, String initField, boolean enforceTrainingConfig, KerasLayerConfiguration conf, int kerasMajorVersion) @@ -225,14 +218,14 @@ public class KerasInitilizationUtils { throw new UnsupportedKerasConfigurationException("Incomplete initialization class"); } } - Pair init; + IWeightInit init; try { init = mapWeightInitialization(kerasInit, conf, initMap, kerasMajorVersion); } catch (UnsupportedKerasConfigurationException e) { if (enforceTrainingConfig) throw e; else { - init = new Pair<>(WeightInit.XAVIER, null); + init = new WeightInitXavier(); log.warn("Unknown weight initializer " + kerasInit + " (Using XAVIER instead)."); } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java index f752b5b03..b33fda9f4 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java @@ -21,7 +21,6 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Model; -import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; diff --git 
a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/KerasTestUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/KerasTestUtils.java index bd6561d37..27aa340e8 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/KerasTestUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/KerasTestUtils.java @@ -16,7 +16,6 @@ package org.deeplearning4j.nn.modelimport.keras; -import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; import org.nd4j.linalg.learning.regularization.L1Regularization; @@ -25,7 +24,6 @@ import org.nd4j.linalg.learning.regularization.Regularization; import java.util.List; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; public class KerasTestUtils { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java index 5c288b21c..dcfd53518 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java @@ -22,8 +22,6 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; -import org.nd4j.linalg.io.ClassPathResource; -import org.nd4j.linalg.util.Nd4jValidator; import org.nd4j.resources.Resources; import org.nd4j.validation.ValidationResult; diff --git 
a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java index f0dfb3694..6043d7d48 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java @@ -21,7 +21,6 @@ import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.split.NumberedFileInputSplit; import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; - import org.deeplearning4j.nn.layers.recurrent.LSTM; import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; @@ -30,7 +29,6 @@ import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Assert; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java index 6dce1b714..554a2c2d1 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java +++ 
b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java @@ -24,7 +24,6 @@ import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Test; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.resources.Resources; import java.io.InputStream; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java index 162dc235a..81103d315 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java @@ -30,11 +30,9 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.resources.Resources; import java.io.File; -import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.util.Arrays; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java index 7072f1956..8ac231e12 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java 
+++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java @@ -25,6 +25,8 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDense; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.WeightInitIdentity; +import org.deeplearning4j.nn.weights.WeightInitVarScalingNormalFanIn; import org.junit.Test; import java.util.HashMap; @@ -94,11 +96,11 @@ public class KerasInitilizationTest extends BaseDL4JTest { WeightInit.RELU_UNIFORM.getWeightInitFunction(), WeightInit.ONES.getWeightInitFunction(), WeightInit.ZERO.getWeightInitFunction(), - WeightInit.IDENTITY.getWeightInitFunction(), + new WeightInitIdentity(0.2), WeightInit.DISTRIBUTION.getWeightInitFunction(new NormalDistribution(mean, stdDev)), WeightInit.DISTRIBUTION.getWeightInitFunction(new OrthogonalDistribution(gain)), WeightInit.DISTRIBUTION.getWeightInitFunction(new ConstantDistribution(value)), - WeightInit.VAR_SCALING_NORMAL_FAN_IN.getWeightInitFunction()}; + new WeightInitVarScalingNormalFanIn(0.2)}; } private Distribution[] dl4jDistributions() { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java index a015dc24f..b5d3c9ab6 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java @@ -17,22 +17,16 @@ package org.deeplearning4j.nn.modelimport.keras.configurations; import 
lombok.extern.slf4j.Slf4j; -import lombok.val; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.resources.Resources; -import java.io.File; import java.io.IOException; -import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertNotNull; /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java index 31611283f..97ae4318f 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java @@ -31,7 +31,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.resources.Resources; import java.io.File; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index 874931262..b33ff8d1f 100644 --- 
a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -24,22 +24,19 @@ import org.deeplearning4j.eval.ROCMultiClass; import org.deeplearning4j.gradientcheck.GradientCheckUtil; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.layers.IOutputLayer; -import org.deeplearning4j.nn.conf.layers.CnnLossLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.LossLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.layers.recurrent.LSTM; -import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer; -import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; -import org.deeplearning4j.nn.modelimport.keras.*; +import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; +import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive; +import org.deeplearning4j.nn.modelimport.keras.KerasModel; +import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder; import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.TransferLearning; -import org.deeplearning4j.nn.workspace.ArrayType; -import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -47,27 +44,25 @@ import org.junit.rules.TemporaryFolder; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.*; -import 
org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.resources.Resources; import java.io.File; -import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.net.URL; import java.nio.file.Files; import java.nio.file.StandardCopyOption; -import java.util.*; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Unit tests for end-to-end Keras model import. diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java index f1fcc3ded..8bd6e779d 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java @@ -21,7 +21,6 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; -import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; import org.deeplearning4j.nn.transferlearning.TransferLearning; 
@@ -31,11 +30,8 @@ import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.io.ClassPathResource; import java.io.File; -import java.nio.file.Files; -import java.nio.file.StandardCopyOption; /** * Import previously stored YOLO9000 Keras net from https://github.com/allanzelener/YAD2K. diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java index 403610c10..dcfe7bfda 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java @@ -26,7 +26,6 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.resources.Resources; import java.io.File; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java index 0a408ac83..eccaeb536 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java @@ -27,16 +27,11 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasAtrousC import 
org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.junit.Test; -import org.nd4j.linalg.learning.regularization.L1Regularization; -import org.nd4j.linalg.learning.regularization.L2Regularization; -import org.nd4j.linalg.learning.regularization.Regularization; import java.util.HashMap; -import java.util.List; import java.util.Map; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; /** * @author Max Pumperla diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java index 4737ec128..ff0ba8f3d 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java @@ -28,9 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolu import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.junit.Test; -import org.nd4j.linalg.learning.regularization.L1Regularization; -import org.nd4j.linalg.learning.regularization.L2Regularization; -import org.nd4j.linalg.learning.regularization.Regularization; import java.util.ArrayList; import java.util.HashMap; @@ -39,7 +36,6 @@ import java.util.Map; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; /** * @author Max Pumperla diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java 
b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java index f356f674f..1676f6136 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java @@ -24,7 +24,6 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping1D; import org.junit.Test; -import java.util.ArrayList; import java.util.HashMap; import java.util.Map; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java index 1a6f564b4..6ae3065b6 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java @@ -16,13 +16,11 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; -import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping2D; import 
org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping3D; import org.junit.Test; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java index a79fab8da..364c50e72 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java @@ -30,15 +30,11 @@ import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.junit.Test; import org.nd4j.base.Preconditions; -import org.nd4j.linalg.learning.regularization.L1Regularization; -import org.nd4j.linalg.learning.regularization.L2Regularization; -import org.nd4j.linalg.learning.regularization.Regularization; import java.util.*; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; /** * @author Max Pumperla diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java index 182054900..aec4278e2 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java @@ -17,18 +17,14 @@ package 
org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.Upsampling1D; -import org.deeplearning4j.nn.conf.layers.Upsampling2D; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling1D; -import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D; import org.junit.Test; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; import static org.junit.Assert.assertEquals; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java index 3c7b30b57..cea117f8f 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java @@ -17,13 +17,11 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.Upsampling2D; -import org.deeplearning4j.nn.conf.layers.ZeroPadding1DLayer; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import 
org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D; -import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding1D; import org.junit.Test; import java.util.ArrayList; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java index 779d9ce51..c0a60defd 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java @@ -17,12 +17,10 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer; -import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding2D; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding3D; import org.junit.Test; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java index 334ab96d3..cca2515a8 100644 --- 
a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java @@ -26,16 +26,11 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.junit.Test; -import org.nd4j.linalg.learning.regularization.L1Regularization; -import org.nd4j.linalg.learning.regularization.L2Regularization; -import org.nd4j.linalg.learning.regularization.Regularization; import java.util.HashMap; -import java.util.List; import java.util.Map; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; /** * @author Max Pumperla diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java index 42cb79cfb..1f2400426 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java @@ -24,10 +24,12 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor; -import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; import org.junit.Test; -import java.util.*; +import java.util.ArrayList; +import 
java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.junit.Assert.assertEquals; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java index dafafea1d..19d5ce623 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java @@ -24,11 +24,11 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Assert; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.util.*; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java index abeba3da7..b171e063f 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java @@ -26,11 +26,7 @@ import org.junit.Test; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import static org.junit.Assert.assertEquals; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java index f8dc975ea..428d5d99e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java @@ -20,7 +20,6 @@ import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.LocallyConnected1D; -import org.deeplearning4j.nn.conf.layers.LocallyConnected2D; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; @@ -31,10 +30,8 @@ import org.junit.Test; import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; -import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java index b38b8f783..1ea69e06a 100644 --- 
a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java @@ -27,15 +27,14 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; -import org.nd4j.linalg.learning.regularization.L1Regularization; -import org.nd4j.linalg.learning.regularization.L2Regularization; -import org.nd4j.linalg.learning.regularization.Regularization; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; /** * @author Max Pumperla diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java index cb6a66155..9026c7308 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java @@ -19,7 +19,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.pooling; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.PoolingType; import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer; -import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import 
org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java index e2d0b7a03..3b82f14ae 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java @@ -33,14 +33,13 @@ import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.junit.Assert; import org.junit.Test; -import org.nd4j.linalg.learning.regularization.L1Regularization; -import org.nd4j.linalg.learning.regularization.L2Regularization; -import org.nd4j.linalg.learning.regularization.Regularization; -import java.util.*; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; /** * @author Max Pumperla diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/optimizers/OptimizerImport.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/optimizers/OptimizerImport.java index 8819ca9b9..f2a693d9a 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/optimizers/OptimizerImport.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/optimizers/OptimizerImport.java @@ 
-16,15 +16,12 @@ package org.deeplearning4j.nn.modelimport.keras.optimizers; -import org.deeplearning4j.config.DL4JSystemProperties; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; -import org.deeplearning4j.nn.modelimport.keras.e2e.KerasModelEndToEndTest; import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder; import org.deeplearning4j.util.DL4JFileUtils; import org.junit.Test; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.resources.Resources; import java.io.File; @@ -32,8 +29,6 @@ import java.io.InputStream; import java.nio.file.Files; import java.nio.file.StandardCopyOption; -import static java.io.File.createTempFile; - public class OptimizerImport extends BaseDL4JTest { @Test diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java index 8753f772c..577e089f9 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java @@ -18,9 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.sequence; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; -import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer; import org.junit.Test; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.resources.Resources; import 
java.io.IOException; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java index f229ec813..45114685b 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java @@ -19,15 +19,11 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.text; import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.junit.Test; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.resources.Resources; import java.io.IOException; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Import Keras Tokenizer diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerTest.java index bbcd00372..a4fb6994b 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerTest.java @@ -20,7 +20,6 @@ import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest; import org.junit.Test; import 
org.nd4j.linalg.api.ndarray.INDArray; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java index 18cf3305d..7791e3417 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java @@ -29,7 +29,6 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.resources.Resources; import java.io.File; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java index d9655a58f..e17535acc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java @@ -16,13 +16,11 @@ package org.deeplearning4j.nn.conf.layers.samediff; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.Setter; +import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.TrainingListener; import 
org.nd4j.autodiff.samediff.SDVariable; @@ -32,6 +30,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import java.util.Collection; +import java.util.HashMap; import java.util.Map; /** @@ -58,10 +57,12 @@ import java.util.Map; public abstract class SameDiffLayer extends AbstractSameDiffLayer { protected WeightInit weightInit; + protected Map paramWeightInit; protected SameDiffLayer(Builder builder) { super(builder); this.weightInit = builder.weightInit; + this.paramWeightInit = builder.paramWeightInit; } protected SameDiffLayer() { @@ -115,6 +116,7 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer { public static abstract class Builder> extends AbstractSameDiffLayer.Builder { protected WeightInit weightInit = WeightInit.XAVIER; + protected Map paramWeightInit; /** * @param weightInit Weight initialization to use for the layer @@ -123,5 +125,12 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer { this.setWeightInit(weightInit); return (T) this; } + + public T weightInit(@NonNull String param, @NonNull IWeightInit weightInit){ + if(paramWeightInit == null) + paramWeightInit = new HashMap<>(); + paramWeightInit.put(param, weightInit); + return (T) this; + } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitIdentity.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitIdentity.java index 076fa2ac8..b25121cd3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitIdentity.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitIdentity.java @@ -16,11 +16,14 @@ package org.deeplearning4j.nn.weights; +import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; 
import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.shade.jackson.annotation.JsonProperty; import java.util.Arrays; @@ -32,9 +35,17 @@ import java.util.Arrays; * * @author Adam Gibson */ -@EqualsAndHashCode +@Data +@NoArgsConstructor public class WeightInitIdentity implements IWeightInit { + private Double scale; + + public WeightInitIdentity(@JsonProperty("scale") Double scale){ + this.scale = scale; + } + + @Override public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { if (shape[0] != shape[1]) { @@ -59,6 +70,11 @@ public class WeightInitIdentity implements IWeightInit { } else { ret = Nd4j.createUninitialized(shape, order).assign(Nd4j.eye(shape[0])); } + + if(scale != null){ + ret.muli(scale); + } + INDArray flat = Nd4j.toFlattened(order, ret); paramView.assign(flat); return paramView.reshape(order, shape); @@ -82,13 +98,16 @@ public class WeightInitIdentity implements IWeightInit { indArrayIndices[i] = NDArrayIndex.point(shape[i] / 2); } - paramView.assign(Nd4j.zeros(paramView.shape())); + paramView.assign(0); final INDArray params =paramView.reshape(order, shape); for (int i = 0; i < shape[0]; i++) { indArrayIndices[0] = NDArrayIndex.point(i); indArrayIndices[1] = NDArrayIndex.point(i); params.put(indArrayIndices, Nd4j.ones(1)); } + if(scale != null){ + params.muli(scale); + } return params; } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitUtil.java index b110bc5a0..17034d408 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitUtil.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitUtil.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.weights; import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; +import 
org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution; import org.nd4j.linalg.api.rng.distribution.Distribution; import org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution; import org.nd4j.linalg.factory.Nd4j; @@ -146,14 +147,13 @@ public class WeightInitUtil { paramView.assign(flat); break; case VAR_SCALING_NORMAL_FAN_IN: - // TODO: needs to be truncated normal to match keras. - Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn)); + Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(1.0 / fanIn))); break; case VAR_SCALING_NORMAL_FAN_OUT: - Nd4j.randn(paramView).divi(FastMath.sqrt(fanOut)); + Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(1.0 / fanOut))); break; case VAR_SCALING_NORMAL_FAN_AVG: - Nd4j.randn(paramView).divi(FastMath.sqrt((fanIn + fanOut) / 2)); + Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(2.0 / (fanIn + fanOut)))); break; case VAR_SCALING_UNIFORM_FAN_IN: double scalingFanIn = 3.0 / Math.sqrt(fanIn); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanAvg.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanAvg.java index 0be5af0e9..3b9698f10 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanAvg.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanAvg.java @@ -16,22 +16,39 @@ package org.deeplearning4j.nn.weights; +import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution; import org.nd4j.linalg.factory.Nd4j; /** - * Gaussian distribution with mean 0, variance 1.0/((fanIn + fanOut)/2) + * Truncated Gaussian distribution with
mean 0, variance 1.0/((fanIn + fanOut)/2) * * @author Adam Gibson */ -@EqualsAndHashCode +@Data +@NoArgsConstructor public class WeightInitVarScalingNormalFanAvg implements IWeightInit { + private Double scale; + + public WeightInitVarScalingNormalFanAvg(Double scale){ + this.scale = scale; + } + @Override public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { - Nd4j.randn(paramView).divi(FastMath.sqrt((fanIn + fanOut) / 2)); + double std; + if(scale == null){ + std = Math.sqrt(2.0 / (fanIn + fanOut)); + } else { + std = Math.sqrt(2.0 * scale / (fanIn + fanOut)); + } + + Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std)); return paramView.reshape(order, shape); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanIn.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanIn.java index 3f89ff015..dca457de3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanIn.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanIn.java @@ -16,23 +16,38 @@ package org.deeplearning4j.nn.weights; -import lombok.EqualsAndHashCode; -import org.apache.commons.math3.util.FastMath; +import lombok.Data; +import lombok.NoArgsConstructor; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution; import org.nd4j.linalg.factory.Nd4j; /** - * Gaussian distribution with mean 0, variance 1.0/(fanIn) + * Truncated normal distribution with mean 0, variance {@code 1.0/(fanIn)}
+ * If a scale is provided, use variance {@code scale/(fanIn)} instead * * @author Adam Gibson */ -@EqualsAndHashCode +@Data +@NoArgsConstructor public class WeightInitVarScalingNormalFanIn implements IWeightInit { + private Double scale; + + public WeightInitVarScalingNormalFanIn(Double scale){ + this.scale = scale; + } + @Override public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { - // TODO: needs to be truncated normal to match keras. - Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn)); + double std; + if(scale == null){ + std = Math.sqrt(1.0 / fanIn); + } else { + std = Math.sqrt(scale / fanIn); + } + + Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std)); return paramView.reshape(order, shape); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanOut.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanOut.java index 6369a19c6..0af43ac88 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanOut.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanOut.java @@ -16,22 +16,40 @@ package org.deeplearning4j.nn.weights; +import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution; import org.nd4j.linalg.factory.Nd4j; /** - * Gaussian distribution with mean 0, variance 1.0/(fanOut) + * Truncated normal distribution with mean 0, variance 1.0/(fanOut)
+ * If a scale is provided, variance is scale / fanOut * * @author Adam Gibson */ -@EqualsAndHashCode +@Data +@NoArgsConstructor public class WeightInitVarScalingNormalFanOut implements IWeightInit { + private Double scale; + + public WeightInitVarScalingNormalFanOut(Double scale){ + this.scale = scale; + } + @Override public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { - Nd4j.randn(paramView).divi(FastMath.sqrt(fanOut)); + double std; + if(scale == null){ + std = Math.sqrt(1.0 / fanOut); + } else { + std = Math.sqrt(scale / fanOut); + } + + Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std)); return paramView.reshape(order, shape); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingUniformFanAvg.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingUniformFanAvg.java index afb1a1dc8..f2e050e6e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingUniformFanAvg.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingUniformFanAvg.java @@ -16,7 +16,9 @@ package org.deeplearning4j.nn.weights; +import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -25,12 +27,22 @@ import org.nd4j.linalg.factory.Nd4j; * * @author Adam Gibson */ -@EqualsAndHashCode +@Data +@NoArgsConstructor public class WeightInitVarScalingUniformFanAvg implements IWeightInit { + private Double scale; + + public WeightInitVarScalingUniformFanAvg(Double scale){ + this.scale = scale; + } + @Override public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2); + if(scale != null) + scalingFanAvg *= scale; + 
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg)); return paramView.reshape(order, shape); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingUniformFanIn.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingUniformFanIn.java index 0cf26ecc6..7135394a7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingUniformFanIn.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingUniformFanIn.java @@ -16,21 +16,34 @@ package org.deeplearning4j.nn.weights; +import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; /** - * Uniform U[-a,a] with a=3.0/(fanIn) + * Uniform U[-a,a] with a=3.0/sqrt(fanIn)
+ * If a scale is provided, a = 3.0 * scale / sqrt(fanIn) * * @author Adam Gibson */ -@EqualsAndHashCode +@NoArgsConstructor +@Data public class WeightInitVarScalingUniformFanIn implements IWeightInit { + private Double scale; + + public WeightInitVarScalingUniformFanIn(Double scale){ + this.scale = scale; + } + @Override public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { double scalingFanIn = 3.0 / Math.sqrt(fanIn); + if(scale != null) + scalingFanIn *= scale; + Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn)); return paramView.reshape(order, shape); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingUniformFanOut.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingUniformFanOut.java index 2d3b116fc..09bf2053d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingUniformFanOut.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingUniformFanOut.java @@ -16,21 +16,33 @@ package org.deeplearning4j.nn.weights; +import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; /** - * Uniform U[-a,a] with a=3.0/(fanOut) + * Uniform U[-a,a] with a=3.0/sqrt(fanOut)
+ * If a scale is provided, a = 3.0 * scale / sqrt(fanOut) * * @author Adam Gibson */ -@EqualsAndHashCode +@Data +@NoArgsConstructor public class WeightInitVarScalingUniformFanOut implements IWeightInit { + private Double scale; + + public WeightInitVarScalingUniformFanOut(Double scale){ + this.scale = scale; + } + @Override public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) { double scalingFanOut = 3.0 / Math.sqrt(fanOut); + if(scale != null) + scalingFanOut *= scale; Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut)); return paramView.reshape(order, shape); } diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt index c563eda27..9ce9b46a3 100755 --- a/libnd4j/CMakeLists.txt +++ b/libnd4j/CMakeLists.txt @@ -25,7 +25,7 @@ elseif (APPLE) elseif(WIN32) set(X86_BUILD true) if (CUDA_BLAS) - set(CMAKE_CXX_FLAGS_RELEASE " /O2 -D_RELEASE=true /wd4804") + set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true /wd4804") set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305") else() set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true") diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java index 51711b3d2..44298ffa2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java @@ -3607,6 +3607,13 @@ public class Shape { return ArrayUtil.prodLong(shape); } + public static long lengthOf(int[] shape) { + if (shape.length == 0) + return 1L; + else + return ArrayUtil.prodLong(shape); + } + /** * Calculate the length of the buffer required to store the given shape with the given strides * diff --git
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/SamplingDataSetIterator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/SamplingDataSetIterator.java index c33b37565..cc6fba068 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/SamplingDataSetIterator.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/SamplingDataSetIterator.java @@ -28,11 +28,6 @@ import java.util.List; * @author Adam Gibson */ public class SamplingDataSetIterator implements DataSetIterator { - - /** - * - */ - private static final long serialVersionUID = -2700563801361726914L; private DataSet sampleFrom; private int batchSize; private int totalNumberSamples; @@ -145,6 +140,4 @@ public class SamplingDataSetIterator implements DataSetIterator { numTimesSampled++; return ret; } - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 25960a8a8..c95dc5ef2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -1164,26 +1164,15 @@ public class Nd4j { * @param type the opType to create * @return the created buffer */ - public static DataBuffer createBuffer(int[] shape, DataType type) { - long length = ArrayUtil.prodLong(shape); - - if (type == DataType.INT) - return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createInt(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createInt(length, true, Nd4j.getMemoryManager().getCurrentWorkspace()); - else if (type == DataType.LONG) - return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? 
DATA_BUFFER_FACTORY_INSTANCE.createLong(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createLong(length, true, Nd4j.getMemoryManager().getCurrentWorkspace()); - else if (type == DataType.HALF) - return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createHalf(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createHalf(length, true, Nd4j.getMemoryManager().getCurrentWorkspace()); - else if (type == DataType.DOUBLE) - return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createDouble(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createDouble(length, true, Nd4j.getMemoryManager().getCurrentWorkspace()); - else - return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createFloat(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createFloat(length, true, Nd4j.getMemoryManager().getCurrentWorkspace()); + public static DataBuffer createBuffer(@NonNull int[] shape, @NonNull DataType type) { + return createBuffer(ArrayUtil.toLongArray(shape), type); } /** * See {@link #createBuffer(int[], DataType)} */ - public static DataBuffer createBuffer(long[] shape, DataType type) { - long length = ArrayUtil.prodLong(shape); + public static DataBuffer createBuffer(@NonNull long[] shape, @NonNull DataType type) { + long length = Shape.lengthOf(shape); switch (type) { case BOOL: @@ -1229,14 +1218,14 @@ public class Nd4j { * @return the created buffer. 
*/ public static DataBuffer createBufferDetached(int[] shape, DataType type) { - return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type); + return createBufferDetachedImpl( Shape.lengthOf(shape), type); } /** * See {@link #createBufferDetached(int[], DataType)} */ public static DataBuffer createBufferDetached(long[] shape, DataType type) { - return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type); + return createBufferDetachedImpl( Shape.lengthOf(shape), type); } // used by createBufferDetached(long[] DataType) and createBufferDetached(int[] , DataType)