Fixes and pre-release QA (#51)
* #8395 Keras import - support scaled identity weight init Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Keras scaled weight init fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8352 Deprecate duplicate SamplingDataSetIterator class Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove /O2 optimization for faster CUDA build Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tweak regression test precision for CUDA Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix edge cases for buffer creation Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update MKLDNN validation tests to new helper enable/disable settings Signed-off-by: AlexDBlack <blacka101@gmail.com> * Delete debugging class Signed-off-by: AlexDBlack <blacka101@gmail.com> * MKLDNN test - add proper skip for CUDA backend Signed-off-by: AlexDBlack <blacka101@gmail.com> * Align WeightInitUtil with weight init classes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for SameDiff test layers weight init when using IWeightInit classes Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
1780dcc883
commit
09a827fb6d
|
@ -35,6 +35,7 @@ import org.nd4j.linalg.indexing.conditions.Conditions;
|
|||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
@ -63,6 +64,30 @@ public class LayerHelperValidationUtil {
|
|||
private DataSetIterator data;
|
||||
}
|
||||
|
||||
public static void disableCppHelpers(){
|
||||
try {
|
||||
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||
Method m = c.getMethod("getInstance");
|
||||
Object instance = m.invoke(null);
|
||||
Method m2 = c.getMethod("allowHelpers", boolean.class);
|
||||
m2.invoke(instance, false);
|
||||
} catch (Throwable t){
|
||||
throw new RuntimeException(t);
|
||||
}
|
||||
}
|
||||
|
||||
public static void enableCppHelpers(){
|
||||
try{
|
||||
Class<?> c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment");
|
||||
Method m = c.getMethod("getInstance");
|
||||
Object instance = m.invoke(null);
|
||||
Method m2 = c.getMethod("allowHelpers", boolean.class);
|
||||
m2.invoke(instance, true);
|
||||
} catch (Throwable t){
|
||||
throw new RuntimeException(t);
|
||||
}
|
||||
}
|
||||
|
||||
public static void validateMLN(MultiLayerNetwork netOrig, TestCase t){
|
||||
assertNotNull(t.getAllowHelpersForClasses());
|
||||
assertFalse(t.getAllowHelpersForClasses().isEmpty());
|
||||
|
@ -95,7 +120,13 @@ public class LayerHelperValidationUtil {
|
|||
for (boolean train : new boolean[]{false, true}) {
|
||||
assertEquals(net1NoHelper.params(), net2With.params());
|
||||
String s = "Feed forward test - " + t.getTestName() + " - " + (train ? "Train: " : "Test: ");
|
||||
List<INDArray> ff1 = net1NoHelper.feedForward(t.getFeatures(), train);
|
||||
List<INDArray> ff1;
|
||||
try {
|
||||
disableCppHelpers();
|
||||
ff1 = net1NoHelper.feedForward(t.getFeatures(), train);
|
||||
} finally {
|
||||
enableCppHelpers();
|
||||
}
|
||||
List<INDArray> ff2 = net2With.feedForward(t.getFeatures(), train);
|
||||
List<String> paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet());
|
||||
Collections.sort(paramKeys);
|
||||
|
@ -131,7 +162,13 @@ public class LayerHelperValidationUtil {
|
|||
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
|
||||
}
|
||||
|
||||
INDArray out1 = net1NoHelper.output(t.getFeatures(), train);
|
||||
INDArray out1;
|
||||
try {
|
||||
disableCppHelpers();
|
||||
out1 = net1NoHelper.output(t.getFeatures(), train);
|
||||
} finally {
|
||||
enableCppHelpers();
|
||||
}
|
||||
INDArray out2 = net2With.output(t.getFeatures(), train);
|
||||
INDArray relError = relError(out1, out2, t.getMinAbsError());
|
||||
double maxRE = relError.maxNumber().doubleValue();
|
||||
|
@ -148,7 +185,13 @@ public class LayerHelperValidationUtil {
|
|||
Preconditions.checkNotNull(t.getLabels(), "Labels are not set (null)");
|
||||
|
||||
log.info("Validation - checking scores");
|
||||
double s1 = net1NoHelper.score(new DataSet(t.getFeatures(), t.getLabels()));
|
||||
double s1;
|
||||
try {
|
||||
disableCppHelpers();
|
||||
s1 = net1NoHelper.score(new DataSet(t.getFeatures(), t.getLabels()));
|
||||
} finally {
|
||||
enableCppHelpers();
|
||||
}
|
||||
double s2 = net2With.score(new DataSet(t.getFeatures(), t.getLabels()));
|
||||
|
||||
double re = relError(s1, s2);
|
||||
|
@ -168,7 +211,12 @@ public class LayerHelperValidationUtil {
|
|||
net2With.setInput(t.getFeatures());
|
||||
net2With.setLabels(t.getLabels());
|
||||
|
||||
try {
|
||||
disableCppHelpers();
|
||||
net1NoHelper.computeGradientAndScore();
|
||||
} finally {
|
||||
enableCppHelpers();
|
||||
}
|
||||
net2With.computeGradientAndScore();
|
||||
|
||||
List<String> paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet());
|
||||
|
|
|
@ -1,107 +0,0 @@
|
|||
package org.deeplearning4j;
|
||||
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
|
||||
import static junit.framework.TestCase.*;
|
||||
|
||||
public class TestBatchNormBp {
|
||||
|
||||
@Test
|
||||
public void test(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
// INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 4, 4);
|
||||
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
|
||||
INDArray mean = in.mean(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
|
||||
INDArray var = in.var(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
|
||||
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
|
||||
// INDArray gamma = Nd4j.ones(DataType.FLOAT, 3);
|
||||
// INDArray beta = Nd4j.zeros(DataType.FLOAT, 3);
|
||||
INDArray gamma = Nd4j.rand(DataType.FLOAT, 3);
|
||||
INDArray beta = Nd4j.rand(DataType.FLOAT, 3);
|
||||
double e = 1e-5;
|
||||
|
||||
INDArray dLdIn = in.ulike();
|
||||
INDArray dLdm = mean.ulike();
|
||||
INDArray dLdv = var.ulike();
|
||||
INDArray dLdg = gamma.ulike();
|
||||
INDArray dLdb = beta.ulike();
|
||||
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
|
||||
.addInputs(in, mean, var, eps, gamma, beta)
|
||||
.addIntegerArguments(
|
||||
1, //Apply scale
|
||||
1, //Apply beta
|
||||
1) //Axis (NCHW)
|
||||
.addFloatingPointArguments(e)
|
||||
.addOutputs(dLdIn, dLdm, dLdv, dLdg, dLdb)
|
||||
.build();
|
||||
|
||||
Nd4j.exec(op);
|
||||
System.out.println(dLdIn);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void compareImpls() throws Exception {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
|
||||
INDArray mean = in.mean(0, 2, 3).reshape(1,3);
|
||||
INDArray var = in.var(0, 2, 3).reshape(1,3);
|
||||
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
|
||||
INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
|
||||
INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
|
||||
double e = 1e-3;
|
||||
|
||||
INDArray dLdIn = in.ulike();
|
||||
INDArray dLdm = mean.ulike();
|
||||
INDArray dLdv = var.ulike();
|
||||
INDArray dLdg = gamma.ulike();
|
||||
INDArray dLdb = beta.ulike();
|
||||
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.inferenceWorkspaceMode(WorkspaceMode.NONE)
|
||||
.trainingWorkspaceMode(WorkspaceMode.NONE)
|
||||
.list()
|
||||
.layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
|
||||
.build();
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
|
||||
assertNotNull(bn.getHelper());
|
||||
Field f = bn.getClass().getDeclaredField("helper");
|
||||
f.setAccessible(true);
|
||||
f.set(bn, null);
|
||||
assertNull(bn.getHelper());
|
||||
|
||||
|
||||
MKLDNNBatchNormHelper h = new MKLDNNBatchNormHelper(DataType.FLOAT);
|
||||
|
||||
net.output(in, true);
|
||||
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
|
||||
Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
|
||||
|
||||
h.preOutput(in, true, new long[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces());
|
||||
Pair<Gradient,INDArray> pmkl = h.backpropGradient(in, eps, new long[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces());
|
||||
|
||||
INDArray dldin_dl4j = p.getSecond();
|
||||
|
||||
System.out.println("dl4j == mkldnn: " + p.getSecond().equals(pmkl.getSecond()));
|
||||
}
|
||||
|
||||
}
|
|
@ -70,9 +70,20 @@ public class MinimalSameDiffDense extends SameDiffLayer {
|
|||
|
||||
@Override
|
||||
public void initializeParameters(Map<String, INDArray> params) {
|
||||
String b = DefaultParamInitializer.BIAS_KEY;
|
||||
if(paramWeightInit != null && paramWeightInit.containsKey(b)){
|
||||
paramWeightInit.get(b).init(nIn, nOut, params.get(b).shape(), 'c', params.get(b));
|
||||
} else {
|
||||
params.get(DefaultParamInitializer.BIAS_KEY).assign(0);
|
||||
}
|
||||
|
||||
String w = DefaultParamInitializer.WEIGHT_KEY;
|
||||
if(paramWeightInit != null && paramWeightInit.containsKey(w)){
|
||||
paramWeightInit.get(w).init(nIn, nOut, params.get(w).shape(), 'c', params.get(w));
|
||||
} else {
|
||||
initWeights(nIn, nOut, weightInit, params.get(DefaultParamInitializer.WEIGHT_KEY));
|
||||
}
|
||||
}
|
||||
|
||||
//OPTIONAL methods:
|
||||
// public void setNIn(InputType inputType, boolean override)
|
||||
|
|
|
@ -109,17 +109,21 @@ public class SameDiffConv extends SameDiffLayer {
|
|||
@Override
|
||||
public void initializeParameters(Map<String, INDArray> params) {
|
||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
double fanIn = nIn * kernel[0] * kernel[1];
|
||||
double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);
|
||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||
if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){
|
||||
paramWeightInit.get(e.getKey()).init(fanIn, fanOut, e.getValue().shape(), 'c', e.getValue());
|
||||
} else {
|
||||
if (ConvolutionParamInitializer.BIAS_KEY.equals(e.getKey())) {
|
||||
e.getValue().assign(0);
|
||||
} else {
|
||||
double fanIn = nIn * kernel[0] * kernel[1];
|
||||
double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);
|
||||
WeightInitUtil.initWeights(fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
|
||||
|
|
|
@ -88,6 +88,9 @@ public class SameDiffDense extends SameDiffLayer {
|
|||
@Override
|
||||
public void initializeParameters(Map<String,INDArray> params){
|
||||
for(Map.Entry<String,INDArray> e : params.entrySet()){
|
||||
if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){
|
||||
paramWeightInit.get(e.getKey()).init(nIn, nOut, e.getValue().shape(), 'c', e.getValue());
|
||||
} else {
|
||||
if(DefaultParamInitializer.BIAS_KEY.equals(e.getKey())){
|
||||
e.getValue().assign(0.0);
|
||||
} else {
|
||||
|
@ -96,6 +99,7 @@ public class SameDiffDense extends SameDiffLayer {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SDVariable defineLayer(SameDiff sd, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
|
||||
|
|
|
@ -50,6 +50,7 @@ import static org.junit.Assume.assumeTrue;
|
|||
|
||||
public class ValidateMKLDNN extends BaseDL4JTest {
|
||||
|
||||
|
||||
@Test
|
||||
public void validateConvSubsampling() throws Exception {
|
||||
//Only run test if using nd4j-native backend
|
||||
|
@ -268,6 +269,7 @@ public class ValidateMKLDNN extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void compareBatchNormBackward() throws Exception {
|
||||
assumeTrue(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native"));
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
|
||||
|
|
|
@ -339,7 +339,13 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
|||
|
||||
INDArray outAct = net.output(in);
|
||||
|
||||
//19 layers - CPU vs. GPU difference accumulates notably, but appears to be correct
|
||||
if(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native")){
|
||||
assertEquals(outExp, outAct);
|
||||
} else {
|
||||
boolean eq = outExp.equalsWithEps(outAct, 0.1);
|
||||
assertTrue(eq);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -24,101 +24,11 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A wrapper for a dataset to sample from.
|
||||
* This will randomly sample from the given dataset.
|
||||
* @author Adam GIbson
|
||||
*/
|
||||
public class SamplingDataSetIterator implements DataSetIterator {
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private static final long serialVersionUID = -2700563801361726914L;
|
||||
private DataSet sampleFrom;
|
||||
private int batchSize;
|
||||
private int totalNumberSamples;
|
||||
private int numTimesSampled;
|
||||
@Getter
|
||||
private DataSetPreProcessor preProcessor;
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sampleFrom the dataset to sample from
|
||||
* @param batchSize the batch size to sample
|
||||
* @param totalNumberSamples the sample size
|
||||
* @deprecated Use {@link org.nd4j.linalg.dataset.api.iterator.SamplingDataSetIterator}
|
||||
*/
|
||||
@Deprecated
|
||||
public class SamplingDataSetIterator extends org.nd4j.linalg.dataset.api.iterator.SamplingDataSetIterator {
|
||||
public SamplingDataSetIterator(DataSet sampleFrom, int batchSize, int totalNumberSamples) {
|
||||
super();
|
||||
this.sampleFrom = sampleFrom;
|
||||
this.batchSize = batchSize;
|
||||
this.totalNumberSamples = totalNumberSamples;
|
||||
super(sampleFrom, batchSize, totalNumberSamples);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return numTimesSampled < totalNumberSamples;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataSet next() {
|
||||
DataSet ret = sampleFrom.sample(batchSize);
|
||||
numTimesSampled += batchSize;
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void remove() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int inputColumns() {
|
||||
return sampleFrom.numInputs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int totalOutcomes() {
|
||||
return sampleFrom.numOutcomes();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean resetSupported() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean asyncSupported() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
numTimesSampled = 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int batch() {
|
||||
return batchSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setPreProcessor(DataSetPreProcessor preProcessor) {
|
||||
this.preProcessor = preProcessor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getLabels() {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public DataSet next(int num) {
|
||||
DataSet ret = sampleFrom.sample(num);
|
||||
numTimesSampled++;
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.bytedeco.hdf5.*;
|
||||
import org.bytedeco.javacpp.BytePointer;
|
||||
import org.bytedeco.javacpp.FloatPointer;
|
||||
import org.bytedeco.javacpp.Loader;
|
||||
|
@ -32,7 +33,6 @@ import java.lang.Exception;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.bytedeco.hdf5.*;
|
||||
import static org.bytedeco.hdf5.global.hdf5.*;
|
||||
|
||||
/**
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||
import org.deeplearning4j.nn.conf.BackpropType;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
|
|
|
@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.PReLULayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
|
@ -27,9 +26,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||
import org.deeplearning4j.nn.params.PReLUParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
@ -79,14 +77,12 @@ public class KerasPReLU extends KerasLayer {
|
|||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, ALPHA_CONSTRAINT, conf, kerasMajorVersion);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, ALPHA_INIT,
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, ALPHA_INIT,
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
long[] axes = getSharedAxes(layerConfig);
|
||||
|
||||
PReLULayer.Builder builder = new PReLULayer.Builder().sharedAxes(axes)
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution)).name(layerName);
|
||||
.weightInit(init).name(layerName);
|
||||
if (weightConstraint != null){
|
||||
builder.constrainWeights(weightConstraint);
|
||||
}
|
||||
|
|
|
@ -17,14 +17,12 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
|
||||
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -83,15 +81,13 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
|
|||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.dilation(getDilationRate(layerConfig, 1, conf, true)[0])
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
|
|
|
@ -17,14 +17,12 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
|
||||
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -84,14 +82,13 @@ public class KerasAtrousConvolution2D extends KerasConvolution {
|
|||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
|
||||
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction())
|
||||
.weightInit(init)
|
||||
.dilation(getDilationRate(layerConfig, 2, conf, true))
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||
|
@ -30,7 +29,6 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.removeDefaultWeights;
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||
|
@ -30,10 +29,9 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -94,15 +92,13 @@ public class KerasConvolution1D extends KerasConvolution {
|
|||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
|
||||
|
|
|
@ -21,14 +21,12 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -87,10 +85,8 @@ public class KerasConvolution2D extends KerasConvolution {
|
|||
numTrainableParams = hasBias ? 2 : 1;
|
||||
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
@ -100,7 +96,7 @@ public class KerasConvolution2D extends KerasConvolution {
|
|||
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
|
||||
|
|
|
@ -21,15 +21,13 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.Convolution3D;
|
||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -88,10 +86,8 @@ public class KerasConvolution3D extends KerasConvolution {
|
|||
numTrainableParams = hasBias ? 2 : 1;
|
||||
int[] dilationRate = getDilationRate(layerConfig, 3, conf, false);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
@ -101,7 +97,7 @@ public class KerasConvolution3D extends KerasConvolution {
|
|||
Convolution3D.Builder builder = new Convolution3D.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 3, conf, kerasMajorVersion))
|
||||
|
|
|
@ -20,14 +20,12 @@ import lombok.Data;
|
|||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -86,10 +84,8 @@ public class KerasDeconvolution2D extends KerasConvolution {
|
|||
numTrainableParams = hasBias ? 2 : 1;
|
||||
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
@ -99,7 +95,7 @@ public class KerasDeconvolution2D extends KerasConvolution {
|
|||
Deconvolution2D.Builder builder = new Deconvolution2D.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
|
@ -30,9 +29,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils;
|
||||
import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
|
@ -126,10 +124,8 @@ public class KerasDepthwiseConvolution2D extends KerasConvolution {
|
|||
numTrainableParams = hasBias ? 2 : 1;
|
||||
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
|
||||
|
||||
Pair<WeightInit, Distribution> depthWiseInit = getWeightInitFromConfig(layerConfig,
|
||||
IWeightInit depthWiseInit = getWeightInitFromConfig(layerConfig,
|
||||
conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit depthWeightInit = depthWiseInit.getFirst();
|
||||
Distribution depthDistribution = depthWiseInit.getSecond();
|
||||
|
||||
val nIn = getNInFromConfig(previousLayers);
|
||||
|
||||
|
@ -152,7 +148,7 @@ public class KerasDepthwiseConvolution2D extends KerasConvolution {
|
|||
.nIn(nIn)
|
||||
.nOut(nIn * depthMultiplier)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(depthWeightInit.getWeightInitFunction(depthDistribution))
|
||||
.weightInit(depthWiseInit)
|
||||
.depthMultiplier(depthMultiplier)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
|
|
|
@ -20,7 +20,6 @@ import lombok.Data;
|
|||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
|
@ -28,9 +27,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils;
|
||||
import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -93,17 +91,13 @@ public class KerasSeparableConvolution2D extends KerasConvolution {
|
|||
|
||||
int depthMultiplier = getDepthMultiplier(layerConfig, conf);
|
||||
|
||||
Pair<WeightInit, Distribution> depthWiseInit = getWeightInitFromConfig(layerConfig,
|
||||
IWeightInit depthWiseInit = getWeightInitFromConfig(layerConfig,
|
||||
conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit depthWeightInit = depthWiseInit.getFirst();
|
||||
Distribution depthDistribution = depthWiseInit.getSecond();
|
||||
|
||||
Pair<WeightInit, Distribution> pointWiseInit = getWeightInitFromConfig(layerConfig,
|
||||
IWeightInit pointWiseInit = getWeightInitFromConfig(layerConfig,
|
||||
conf.getLAYER_FIELD_POINT_WISE_INIT(), enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit pointWeightInit = pointWiseInit.getFirst();
|
||||
Distribution pointDistribution = pointWiseInit.getSecond();
|
||||
|
||||
if (depthWeightInit != pointWeightInit || depthDistribution != pointDistribution)
|
||||
if ( !depthWiseInit.getClass().equals(pointWiseInit.getClass()) )
|
||||
if (enforceTrainingConfig)
|
||||
throw new UnsupportedKerasConfigurationException(
|
||||
"Specifying different initialization for depth- and point-wise weights not supported.");
|
||||
|
@ -126,7 +120,7 @@ public class KerasSeparableConvolution2D extends KerasConvolution {
|
|||
SeparableConvolution2D.Builder builder = new SeparableConvolution2D.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(depthWeightInit.getWeightInitFunction(depthDistribution))
|
||||
.weightInit(depthWiseInit)
|
||||
.depthMultiplier(depthMultiplier)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
|
||||
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling3D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
|
@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.params.DefaultParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -95,15 +93,13 @@ public class KerasDense extends KerasLayer {
|
|||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
DenseLayer.Builder builder = new DenseLayer.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf))
|
||||
.dropOut(this.dropout).activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.biasInit(0.0)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.hasBias(hasBias);
|
||||
|
|
|
@ -22,7 +22,6 @@ import org.deeplearning4j.nn.conf.InputPreProcessor;
|
|||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
|
|
|
@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
|
|
|
@ -18,7 +18,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
|||
|
||||
|
||||
import lombok.val;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
|
@ -26,7 +25,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
|
@ -30,11 +29,10 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||
import org.deeplearning4j.nn.params.DefaultParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -106,10 +104,8 @@ public class KerasEmbedding extends KerasLayer {
|
|||
"in DL4J, apply masking as a pre-processing step to your input." +
|
||||
"See http://deeplearning4j.org/docs/latest/deeplearning4j-nn-recurrent#masking for more on this.");
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
LayerConstraint embeddingConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_EMBEDDINGS_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
@ -121,7 +117,7 @@ public class KerasEmbedding extends KerasLayer {
|
|||
.inferInputLength(inferInputLength)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf))
|
||||
.dropOut(this.dropout).activation(Activation.IDENTITY)
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInit(init)
|
||||
.biasInit(0.0)
|
||||
.l1(this.weightL1Regularization)
|
||||
.l2(this.weightL2Regularization)
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.LocallyConnected1D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
|
@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -90,11 +88,8 @@ public class KerasLocallyConnected1D extends KerasConvolution {
|
|||
numTrainableParams = hasBias ? 2 : 1;
|
||||
int[] dilationRate = getDilationRate(layerConfig, 1, conf, false);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
// TODO: take care of distribution and bias init
|
||||
//Distribution distribution = init.getSecond();
|
||||
|
||||
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
@ -104,7 +99,7 @@ public class KerasLocallyConnected1D extends KerasConvolution {
|
|||
LocallyConnected1D.Builder builder = new LocallyConnected1D.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit)
|
||||
.weightInit(conf.getKERAS_PARAM_NAME_W(), init)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
|
@ -29,9 +28,8 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -39,9 +37,7 @@ import java.util.Map;
|
|||
import static org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolutionUtils.*;
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasActivationUtils.getActivationFromConfig;
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils.getWeightInitFromConfig;
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getHasBiasFromConfig;
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getNOutFromConfig;
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.removeDefaultWeights;
|
||||
import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.*;
|
||||
|
||||
|
||||
/**
|
||||
|
@ -92,11 +88,9 @@ public class KerasLocallyConnected2D extends KerasConvolution {
|
|||
numTrainableParams = hasBias ? 2 : 1;
|
||||
int[] dilationRate = getDilationRate(layerConfig, 2, conf, false);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
// TODO: take care of distribution and bias init
|
||||
//Distribution distribution = init.getSecond();
|
||||
// TODO: take care of bias init
|
||||
|
||||
LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||
layerConfig, conf.getLAYER_FIELD_B_CONSTRAINT(), conf, kerasMajorVersion);
|
||||
|
@ -106,7 +100,7 @@ public class KerasLocallyConnected2D extends KerasConvolution {
|
|||
LocallyConnected2D.Builder builder = new LocallyConnected2D.Builder().name(this.layerName)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit)
|
||||
.weightInit(conf.getKERAS_PARAM_NAME_W(), init)
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
|
||||
|
|
|
@ -31,7 +31,6 @@ import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
|
|
@ -22,7 +22,6 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import lombok.val;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||
|
@ -35,7 +34,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||
import org.deeplearning4j.nn.params.LSTMParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -151,15 +150,11 @@ public class KerasLSTM extends KerasLayer {
|
|||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
super(layerConfig, enforceTrainingConfig);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
Pair<WeightInit, Distribution> recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
|
||||
IWeightInit recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit recurrentWeightInit = recurrentInit.getFirst();
|
||||
Distribution recurrentDistribution = recurrentInit.getSecond();
|
||||
|
||||
boolean hasBias = getHasBiasFromConfig(layerConfig, conf);
|
||||
|
||||
|
@ -186,8 +181,8 @@ public class KerasLSTM extends KerasLayer {
|
|||
.nOut(getNOutFromConfig(layerConfig, conf))
|
||||
.dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInitRecurrent(recurrentWeightInit.getWeightInitFunction(recurrentDistribution))
|
||||
.weightInit(init)
|
||||
.weightInitRecurrent(recurrentInit)
|
||||
.biasInit(0.0) // TODO: this is incorrect
|
||||
.l1(this.weightL1Regularization)
|
||||
.l2(this.weightL2Regularization);
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||
|
@ -34,7 +33,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
|
@ -124,15 +123,11 @@ public class KerasSimpleRnn extends KerasLayer {
|
|||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
super(layerConfig, enforceTrainingConfig);
|
||||
|
||||
Pair<WeightInit, Distribution> init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit weightInit = init.getFirst();
|
||||
Distribution distribution = init.getSecond();
|
||||
|
||||
Pair<WeightInit, Distribution> recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
|
||||
IWeightInit recurrentInit = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INNER_INIT(),
|
||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||
WeightInit recurrentWeightInit = recurrentInit.getFirst();
|
||||
Distribution recurrentDistribution = recurrentInit.getSecond();
|
||||
|
||||
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
||||
this.returnSequences = (Boolean) innerConfig.get(conf.getLAYER_FIELD_RETURN_SEQUENCES());
|
||||
|
@ -154,8 +149,8 @@ public class KerasSimpleRnn extends KerasLayer {
|
|||
.nOut(getNOutFromConfig(layerConfig, conf))
|
||||
.dropOut(this.dropout)
|
||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(weightInit.getWeightInitFunction(distribution))
|
||||
.weightInitRecurrent(recurrentWeightInit.getWeightInitFunction(recurrentDistribution))
|
||||
.weightInit(init)
|
||||
.weightInitRecurrent(recurrentInit)
|
||||
.biasInit(0.0)
|
||||
.l1(this.weightL1Regularization)
|
||||
.l2(this.weightL2Regularization);
|
||||
|
|
|
@ -20,9 +20,7 @@ import com.google.gson.Gson;
|
|||
import com.google.gson.reflect.TypeToken;
|
||||
import lombok.Data;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
|
@ -31,7 +29,6 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
|
|
@ -22,9 +22,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
|
|||
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
/**
|
||||
|
|
|
@ -19,17 +19,15 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessors;
|
|||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import lombok.val;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
|
|
|
@ -20,9 +20,9 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import lombok.val;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
|
|
|
@ -1,28 +1,15 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.utils;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.nd4j.validation.Nd4jCommonValidator;
|
||||
import org.nd4j.validation.ValidationResult;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipFile;
|
||||
|
||||
/**
|
||||
* A utility for validating serialized Keras sequential and functional models for import into DL4J
|
||||
|
|
|
@ -21,7 +21,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.activations.impl.*;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
|
|
@ -21,8 +21,7 @@ import org.deeplearning4j.nn.conf.distribution.*;
|
|||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.nn.weights.*;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -42,7 +41,7 @@ public class KerasInitilizationUtils {
|
|||
* @return DL4J weight initialization enum
|
||||
* @see WeightInit
|
||||
*/
|
||||
public static Pair<WeightInit, Distribution> mapWeightInitialization(String kerasInit,
|
||||
public static IWeightInit mapWeightInitialization(String kerasInit,
|
||||
KerasLayerConfiguration conf,
|
||||
Map<String, Object> initConfig,
|
||||
int kerasMajorVersion)
|
||||
|
@ -50,68 +49,63 @@ public class KerasInitilizationUtils {
|
|||
|
||||
|
||||
// TODO: Identity and VarianceScaling need "scale" factor
|
||||
WeightInit init = null;
|
||||
Distribution dist = null;
|
||||
if (kerasInit != null) {
|
||||
if (kerasInit.equals(conf.getINIT_GLOROT_NORMAL()) ||
|
||||
kerasInit.equals(conf.getINIT_GLOROT_NORMAL_ALIAS())) {
|
||||
init = WeightInit.XAVIER;
|
||||
return WeightInit.XAVIER.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_GLOROT_UNIFORM()) ||
|
||||
kerasInit.equals(conf.getINIT_GLOROT_UNIFORM_ALIAS())) {
|
||||
init = WeightInit.XAVIER_UNIFORM;
|
||||
return WeightInit.XAVIER_UNIFORM.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_LECUN_NORMAL()) ||
|
||||
kerasInit.equals(conf.getINIT_LECUN_NORMAL_ALIAS())) {
|
||||
init = WeightInit.LECUN_NORMAL;
|
||||
return WeightInit.LECUN_NORMAL.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_LECUN_UNIFORM()) ||
|
||||
kerasInit.equals(conf.getINIT_LECUN_UNIFORM_ALIAS())) {
|
||||
init = WeightInit.LECUN_UNIFORM;
|
||||
return WeightInit.LECUN_UNIFORM.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_HE_NORMAL()) ||
|
||||
kerasInit.equals(conf.getINIT_HE_NORMAL_ALIAS())) {
|
||||
init = WeightInit.RELU;
|
||||
return WeightInit.RELU.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_HE_UNIFORM()) ||
|
||||
kerasInit.equals(conf.getINIT_HE_UNIFORM_ALIAS())) {
|
||||
init = WeightInit.RELU_UNIFORM;
|
||||
return WeightInit.RELU_UNIFORM.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_ONE()) ||
|
||||
kerasInit.equals(conf.getINIT_ONES()) ||
|
||||
kerasInit.equals(conf.getINIT_ONES_ALIAS())) {
|
||||
init = WeightInit.ONES;
|
||||
return WeightInit.ONES.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_ZERO()) ||
|
||||
kerasInit.equals(conf.getINIT_ZEROS()) ||
|
||||
kerasInit.equals(conf.getINIT_ZEROS_ALIAS())) {
|
||||
init = WeightInit.ZERO;
|
||||
return WeightInit.ZERO.getWeightInitFunction();
|
||||
} else if (kerasInit.equals(conf.getINIT_UNIFORM()) ||
|
||||
kerasInit.equals(conf.getINIT_RANDOM_UNIFORM()) ||
|
||||
kerasInit.equals(conf.getINIT_RANDOM_UNIFORM_ALIAS())) {
|
||||
if (kerasMajorVersion == 2) {
|
||||
double minVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MINVAL());
|
||||
double maxVal = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MAXVAL());
|
||||
dist = new UniformDistribution(minVal, maxVal);
|
||||
return new WeightInitDistribution(new UniformDistribution(minVal, maxVal));
|
||||
} else {
|
||||
double scale = 0.05;
|
||||
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
|
||||
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
|
||||
dist = new UniformDistribution(-scale, scale);
|
||||
return new WeightInitDistribution(new UniformDistribution(-scale, scale));
|
||||
}
|
||||
init = WeightInit.DISTRIBUTION;
|
||||
} else if (kerasInit.equals(conf.getINIT_NORMAL()) ||
|
||||
kerasInit.equals(conf.getINIT_RANDOM_NORMAL()) ||
|
||||
kerasInit.equals(conf.getINIT_RANDOM_NORMAL_ALIAS())) {
|
||||
if (kerasMajorVersion == 2) {
|
||||
double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
|
||||
double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
|
||||
dist = new NormalDistribution(mean, stdDev);
|
||||
return new WeightInitDistribution(new NormalDistribution(mean, stdDev));
|
||||
} else {
|
||||
double scale = 0.05;
|
||||
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
|
||||
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
|
||||
dist = new NormalDistribution(0, scale);
|
||||
return new WeightInitDistribution(new NormalDistribution(0, scale));
|
||||
}
|
||||
init = WeightInit.DISTRIBUTION;
|
||||
} else if (kerasInit.equals(conf.getINIT_CONSTANT()) ||
|
||||
kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) {
|
||||
double value = (double) initConfig.get(conf.getLAYER_FIELD_INIT_VALUE());
|
||||
dist = new ConstantDistribution(value);
|
||||
init = WeightInit.DISTRIBUTION;
|
||||
return new WeightInitDistribution(new ConstantDistribution(value));
|
||||
} else if (kerasInit.equals(conf.getINIT_ORTHOGONAL()) ||
|
||||
kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) {
|
||||
if (kerasMajorVersion == 2) {
|
||||
|
@ -121,34 +115,38 @@ public class KerasInitilizationUtils {
|
|||
} catch (Exception e) {
|
||||
gain = (int) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
|
||||
}
|
||||
dist = new OrthogonalDistribution(gain);
|
||||
return new WeightInitDistribution(new OrthogonalDistribution(gain));
|
||||
} else {
|
||||
double scale = 1.1;
|
||||
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
|
||||
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
|
||||
dist = new OrthogonalDistribution(scale);
|
||||
return new WeightInitDistribution(new OrthogonalDistribution(scale));
|
||||
}
|
||||
init = WeightInit.DISTRIBUTION;
|
||||
} else if (kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL()) ||
|
||||
kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL_ALIAS())) {
|
||||
double mean = (double) initConfig.get(conf.getLAYER_FIELD_INIT_MEAN());
|
||||
double stdDev = (double) initConfig.get(conf.getLAYER_FIELD_INIT_STDDEV());
|
||||
dist = new TruncatedNormalDistribution(mean, stdDev);
|
||||
init = WeightInit.DISTRIBUTION;
|
||||
return new WeightInitDistribution(new TruncatedNormalDistribution(mean, stdDev));
|
||||
} else if (kerasInit.equals(conf.getINIT_IDENTITY()) ||
|
||||
kerasInit.equals(conf.getINIT_IDENTITY_ALIAS())) {
|
||||
if (kerasMajorVersion == 2) {
|
||||
double gain = (double) initConfig.get(conf.getLAYER_FIELD_INIT_GAIN());
|
||||
if (gain != 1.)
|
||||
log.warn("Scaled identity weight init not supported, setting gain=1");
|
||||
if (gain != 1.0)
|
||||
if (gain != 1.0) {
|
||||
return new WeightInitIdentity(gain);
|
||||
} else {
|
||||
return new WeightInitIdentity();
|
||||
}
|
||||
} else {
|
||||
double scale = 1.;
|
||||
if (initConfig.containsKey(conf.getLAYER_FIELD_INIT_SCALE()))
|
||||
scale = (double) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
|
||||
if (scale != 1.)
|
||||
log.warn("Scaled identity weight init not supported, setting scale=1");
|
||||
if (scale != 1.0) {
|
||||
return new WeightInitIdentity(scale);
|
||||
} else {
|
||||
return new WeightInitIdentity();
|
||||
}
|
||||
}
|
||||
init = WeightInit.IDENTITY;
|
||||
} else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) {
|
||||
double scale;
|
||||
try {
|
||||
|
@ -156,32 +154,27 @@ public class KerasInitilizationUtils {
|
|||
} catch (Exception e) {
|
||||
scale = (int) initConfig.get(conf.getLAYER_FIELD_INIT_SCALE());
|
||||
}
|
||||
if (scale != 1.)
|
||||
log.warn("Scaled identity weight init not supported, setting scale=1");
|
||||
String mode = (String) initConfig.get(conf.getLAYER_FIELD_INIT_MODE());
|
||||
String distribution = (String) initConfig.get(conf.getLAYER_FIELD_INIT_DISTRIBUTION());
|
||||
switch (mode) {
|
||||
case "fan_in":
|
||||
if (distribution.equals("normal")) {
|
||||
init = WeightInit.VAR_SCALING_NORMAL_FAN_IN;
|
||||
return new WeightInitVarScalingNormalFanIn(scale);
|
||||
} else {
|
||||
init = WeightInit.VAR_SCALING_UNIFORM_FAN_IN;
|
||||
return new WeightInitVarScalingUniformFanIn(scale);
|
||||
}
|
||||
break;
|
||||
case "fan_out":
|
||||
if (distribution.equals("normal")) {
|
||||
init = WeightInit.VAR_SCALING_NORMAL_FAN_OUT;
|
||||
return new WeightInitVarScalingNormalFanOut(scale);
|
||||
} else {
|
||||
init = WeightInit.VAR_SCALING_UNIFORM_FAN_OUT;
|
||||
return new WeightInitVarScalingUniformFanOut(scale);
|
||||
}
|
||||
break;
|
||||
case "fan_avg":
|
||||
if (distribution.equals("normal")) {
|
||||
init = WeightInit.VAR_SCALING_NORMAL_FAN_AVG;
|
||||
return new WeightInitVarScalingNormalFanAvg(scale);
|
||||
} else {
|
||||
init = WeightInit.VAR_SCALING_UNIFORM_FAN_AVG;
|
||||
return new WeightInitVarScalingUniformFanAvg(scale);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new InvalidKerasConfigurationException("Initialization argument 'mode' has to be either " +
|
||||
"fan_in, fan_out or fan_avg");
|
||||
|
@ -190,7 +183,7 @@ public class KerasInitilizationUtils {
|
|||
throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit);
|
||||
}
|
||||
}
|
||||
return new Pair<>(init, dist);
|
||||
throw new IllegalStateException("Error getting Keras weight initialization");
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -202,7 +195,7 @@ public class KerasInitilizationUtils {
|
|||
* @throws InvalidKerasConfigurationException Invalid Keras config
|
||||
* @throws UnsupportedKerasConfigurationException Unsupported Keras config
|
||||
*/
|
||||
public static Pair<WeightInit, Distribution> getWeightInitFromConfig(Map<String, Object> layerConfig, String initField,
|
||||
public static IWeightInit getWeightInitFromConfig(Map<String, Object> layerConfig, String initField,
|
||||
boolean enforceTrainingConfig,
|
||||
KerasLayerConfiguration conf,
|
||||
int kerasMajorVersion)
|
||||
|
@ -225,14 +218,14 @@ public class KerasInitilizationUtils {
|
|||
throw new UnsupportedKerasConfigurationException("Incomplete initialization class");
|
||||
}
|
||||
}
|
||||
Pair<WeightInit, Distribution> init;
|
||||
IWeightInit init;
|
||||
try {
|
||||
init = mapWeightInitialization(kerasInit, conf, initMap, kerasMajorVersion);
|
||||
} catch (UnsupportedKerasConfigurationException e) {
|
||||
if (enforceTrainingConfig)
|
||||
throw e;
|
||||
else {
|
||||
init = new Pair<>(WeightInit.XAVIER, null);
|
||||
init = new WeightInitXavier();
|
||||
log.warn("Unknown weight initializer " + kerasInit + " (Using XAVIER instead).");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
package org.deeplearning4j.nn.modelimport.keras;
|
||||
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.conf.layers.BaseLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
|
@ -25,7 +24,6 @@ import org.nd4j.linalg.learning.regularization.Regularization;
|
|||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
public class KerasTestUtils {
|
||||
|
|
|
@ -22,8 +22,6 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.linalg.util.Nd4jValidator;
|
||||
import org.nd4j.resources.Resources;
|
||||
import org.nd4j.validation.ValidationResult;
|
||||
|
||||
|
|
|
@ -21,7 +21,6 @@ import org.datavec.api.records.reader.SequenceRecordReader;
|
|||
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
||||
import org.datavec.api.split.NumberedFileInputSplit;
|
||||
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
|
||||
|
||||
import org.deeplearning4j.nn.layers.recurrent.LSTM;
|
||||
import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
|
@ -30,7 +29,6 @@ import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
|
|
|
@ -24,7 +24,6 @@ import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
|||
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.InputStream;
|
||||
|
|
|
@ -30,11 +30,9 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Arrays;
|
||||
|
|
|
@ -25,6 +25,8 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
|||
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDense;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInitIdentity;
|
||||
import org.deeplearning4j.nn.weights.WeightInitVarScalingNormalFanIn;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
@ -94,11 +96,11 @@ public class KerasInitilizationTest extends BaseDL4JTest {
|
|||
WeightInit.RELU_UNIFORM.getWeightInitFunction(),
|
||||
WeightInit.ONES.getWeightInitFunction(),
|
||||
WeightInit.ZERO.getWeightInitFunction(),
|
||||
WeightInit.IDENTITY.getWeightInitFunction(),
|
||||
new WeightInitIdentity(0.2),
|
||||
WeightInit.DISTRIBUTION.getWeightInitFunction(new NormalDistribution(mean, stdDev)),
|
||||
WeightInit.DISTRIBUTION.getWeightInitFunction(new OrthogonalDistribution(gain)),
|
||||
WeightInit.DISTRIBUTION.getWeightInitFunction(new ConstantDistribution(value)),
|
||||
WeightInit.VAR_SCALING_NORMAL_FAN_IN.getWeightInitFunction()};
|
||||
new WeightInitVarScalingNormalFanIn(0.2)};
|
||||
}
|
||||
|
||||
private Distribution[] dl4jDistributions() {
|
||||
|
|
|
@ -17,22 +17,16 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.configurations;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
|
|
|
@ -31,7 +31,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
|||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
|
|
|
@ -24,22 +24,19 @@ import org.deeplearning4j.eval.ROCMultiClass;
|
|||
import org.deeplearning4j.gradientcheck.GradientCheckUtil;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.layers.recurrent.LSTM;
|
||||
import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
|
||||
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.*;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
|
||||
import org.deeplearning4j.nn.transferlearning.TransferLearning;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
|
@ -47,27 +44,25 @@ import org.junit.rules.TemporaryFolder;
|
|||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.activations.impl.*;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.linalg.learning.config.NoOp;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.URL;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.*;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Unit tests for end-to-end Keras model import.
|
||||
|
|
|
@ -21,7 +21,6 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
|
|||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth;
|
||||
import org.deeplearning4j.nn.transferlearning.TransferLearning;
|
||||
|
@ -31,11 +30,8 @@ import org.junit.Test;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
|
||||
/**
|
||||
* Import previously stored YOLO9000 Keras net from https://github.com/allanzelener/YAD2K.
|
||||
|
|
|
@ -26,7 +26,6 @@ import org.junit.Ignore;
|
|||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
|
|
|
@ -27,16 +27,11 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasAtrousC
|
|||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
* @author Max Pumperla
|
||||
|
|
|
@ -28,9 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolu
|
|||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
|
@ -39,7 +36,6 @@ import java.util.Map;
|
|||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
* @author Max Pumperla
|
||||
|
|
|
@ -24,7 +24,6 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
|||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping1D;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
|
|
@ -16,13 +16,11 @@
|
|||
|
||||
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||
|
||||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
|
||||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping3D;
|
||||
import org.junit.Test;
|
||||
|
||||
|
|
|
@ -30,15 +30,11 @@ import org.deeplearning4j.nn.weights.IWeightInit;
|
|||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
* @author Max Pumperla
|
||||
|
|
|
@ -17,18 +17,14 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling1D;
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling1D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
|
|
@ -17,13 +17,11 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||
|
||||
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPadding1DLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding1D;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
|
|
@ -17,12 +17,10 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding3D;
|
||||
import org.junit.Test;
|
||||
|
||||
|
|
|
@ -26,16 +26,11 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
|||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
* @author Max Pumperla
|
||||
|
|
|
@ -24,10 +24,12 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
|
|
|
@ -24,11 +24,11 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
|
|
|
@ -26,11 +26,7 @@ import org.junit.Test;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ import org.deeplearning4j.nn.conf.ConvolutionMode;
|
|||
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.LocallyConnected1D;
|
||||
import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||
|
@ -31,10 +30,8 @@ import org.junit.Test;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
/**
|
||||
|
|
|
@ -27,15 +27,14 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
|||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
* @author Max Pumperla
|
||||
|
|
|
@ -19,7 +19,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers.pooling;
|
|||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.layers.PoolingType;
|
||||
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
||||
|
|
|
@ -33,14 +33,13 @@ import org.deeplearning4j.nn.weights.IWeightInit;
|
|||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
/**
|
||||
* @author Max Pumperla
|
||||
|
|
|
@ -16,15 +16,12 @@
|
|||
|
||||
package org.deeplearning4j.nn.modelimport.keras.optimizers;
|
||||
|
||||
import org.deeplearning4j.config.DL4JSystemProperties;
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.e2e.KerasModelEndToEndTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
|
||||
import org.deeplearning4j.util.DL4JFileUtils;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -32,8 +29,6 @@ import java.io.InputStream;
|
|||
import java.nio.file.Files;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
|
||||
import static java.io.File.createTempFile;
|
||||
|
||||
public class OptimizerImport extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -18,9 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.sequence;
|
|||
|
||||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.IOException;
|
||||
|
|
|
@ -19,15 +19,11 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.text;
|
|||
import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Import Keras Tokenizer
|
||||
|
|
|
@ -20,7 +20,6 @@ import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
|
|||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
|
|
@ -29,7 +29,6 @@ import org.junit.Test;
|
|||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
|
|
|
@ -16,13 +16,11 @@
|
|||
|
||||
package org.deeplearning4j.nn.conf.layers.samediff;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.*;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.MaskState;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -32,6 +30,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
|
@ -58,10 +57,12 @@ import java.util.Map;
|
|||
public abstract class SameDiffLayer extends AbstractSameDiffLayer {
|
||||
|
||||
protected WeightInit weightInit;
|
||||
protected Map<String,IWeightInit> paramWeightInit;
|
||||
|
||||
protected SameDiffLayer(Builder builder) {
|
||||
super(builder);
|
||||
this.weightInit = builder.weightInit;
|
||||
this.paramWeightInit = builder.paramWeightInit;
|
||||
}
|
||||
|
||||
protected SameDiffLayer() {
|
||||
|
@ -115,6 +116,7 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer {
|
|||
public static abstract class Builder<T extends Builder<T>> extends AbstractSameDiffLayer.Builder<T> {
|
||||
|
||||
protected WeightInit weightInit = WeightInit.XAVIER;
|
||||
protected Map<String,IWeightInit> paramWeightInit;
|
||||
|
||||
/**
|
||||
* @param weightInit Weight initialization to use for the layer
|
||||
|
@ -123,5 +125,12 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer {
|
|||
this.setWeightInit(weightInit);
|
||||
return (T) this;
|
||||
}
|
||||
|
||||
public T weightInit(@NonNull String param, @NonNull IWeightInit weightInit){
|
||||
if(paramWeightInit == null)
|
||||
paramWeightInit = new HashMap<>();
|
||||
paramWeightInit.put(param, weightInit);
|
||||
return (T) this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,11 +16,14 @@
|
|||
|
||||
package org.deeplearning4j.nn.weights;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
|
@ -32,9 +35,17 @@ import java.util.Arrays;
|
|||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class WeightInitIdentity implements IWeightInit {
|
||||
|
||||
private Double scale;
|
||||
|
||||
public WeightInitIdentity(@JsonProperty("scale") Double scale){
|
||||
this.scale = scale;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
|
||||
if (shape[0] != shape[1]) {
|
||||
|
@ -59,6 +70,11 @@ public class WeightInitIdentity implements IWeightInit {
|
|||
} else {
|
||||
ret = Nd4j.createUninitialized(shape, order).assign(Nd4j.eye(shape[0]));
|
||||
}
|
||||
|
||||
if(scale != null){
|
||||
ret.muli(scale);
|
||||
}
|
||||
|
||||
INDArray flat = Nd4j.toFlattened(order, ret);
|
||||
paramView.assign(flat);
|
||||
return paramView.reshape(order, shape);
|
||||
|
@ -82,13 +98,16 @@ public class WeightInitIdentity implements IWeightInit {
|
|||
indArrayIndices[i] = NDArrayIndex.point(shape[i] / 2);
|
||||
}
|
||||
|
||||
paramView.assign(Nd4j.zeros(paramView.shape()));
|
||||
paramView.assign(0);
|
||||
final INDArray params =paramView.reshape(order, shape);
|
||||
for (int i = 0; i < shape[0]; i++) {
|
||||
indArrayIndices[0] = NDArrayIndex.point(i);
|
||||
indArrayIndices[1] = NDArrayIndex.point(i);
|
||||
params.put(indArrayIndices, Nd4j.ones(1));
|
||||
}
|
||||
if(scale != null){
|
||||
params.muli(scale);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.nn.weights;
|
|||
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
|
||||
import org.nd4j.linalg.api.rng.distribution.Distribution;
|
||||
import org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -146,14 +147,13 @@ public class WeightInitUtil {
|
|||
paramView.assign(flat);
|
||||
break;
|
||||
case VAR_SCALING_NORMAL_FAN_IN:
|
||||
// TODO: needs to be truncated normal to match keras.
|
||||
Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn));
|
||||
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(1.0 / fanIn)));
|
||||
break;
|
||||
case VAR_SCALING_NORMAL_FAN_OUT:
|
||||
Nd4j.randn(paramView).divi(FastMath.sqrt(fanOut));
|
||||
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(1.0 / fanOut)));
|
||||
break;
|
||||
case VAR_SCALING_NORMAL_FAN_AVG:
|
||||
Nd4j.randn(paramView).divi(FastMath.sqrt((fanIn + fanOut) / 2));
|
||||
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(2.0 / (fanIn + fanOut))));
|
||||
break;
|
||||
case VAR_SCALING_UNIFORM_FAN_IN:
|
||||
double scalingFanIn = 3.0 / Math.sqrt(fanIn);
|
||||
|
|
|
@ -16,22 +16,39 @@
|
|||
|
||||
package org.deeplearning4j.nn.weights;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* Gaussian distribution with mean 0, variance 1.0/((fanIn + fanOut)/2)
|
||||
* Truncated aussian distribution with mean 0, variance 1.0/((fanIn + fanOut)/2)
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class WeightInitVarScalingNormalFanAvg implements IWeightInit {
|
||||
|
||||
private Double scale;
|
||||
|
||||
public WeightInitVarScalingNormalFanAvg(Double scale){
|
||||
this.scale = scale;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
|
||||
Nd4j.randn(paramView).divi(FastMath.sqrt((fanIn + fanOut) / 2));
|
||||
double std;
|
||||
if(scale == null){
|
||||
std = Math.sqrt(2.0 / (fanIn + fanOut));
|
||||
} else {
|
||||
std = Math.sqrt(2.0 * scale / (fanIn + fanOut));
|
||||
}
|
||||
|
||||
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,23 +16,38 @@
|
|||
|
||||
package org.deeplearning4j.nn.weights;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* Gaussian distribution with mean 0, variance 1.0/(fanIn)
|
||||
* Gaussian distribution with mean 0, variance {@code 1.0/(fanIn)}<br>
|
||||
* If a scale is provided, use variance {@code scale/(fanIn)} instead
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class WeightInitVarScalingNormalFanIn implements IWeightInit {
|
||||
|
||||
private Double scale;
|
||||
|
||||
public WeightInitVarScalingNormalFanIn(Double scale){
|
||||
this.scale = scale;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
|
||||
// TODO: needs to be truncated normal to match keras.
|
||||
Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn));
|
||||
double std;
|
||||
if(scale == null){
|
||||
std = Math.sqrt(1.0 / fanIn);
|
||||
} else {
|
||||
std = Math.sqrt(scale / fanIn);
|
||||
}
|
||||
|
||||
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,22 +16,40 @@
|
|||
|
||||
package org.deeplearning4j.nn.weights;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* Gaussian distribution with mean 0, variance 1.0/(fanOut)
|
||||
* Truncated normal distribution with mean 0, variance 1.0/(fanOut)<br>
|
||||
* If a scale is provided, variance is scale / fanOut
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class WeightInitVarScalingNormalFanOut implements IWeightInit {
|
||||
|
||||
private Double scale;
|
||||
|
||||
public WeightInitVarScalingNormalFanOut(Double scale){
|
||||
this.scale = scale;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
|
||||
Nd4j.randn(paramView).divi(FastMath.sqrt(fanOut));
|
||||
double std;
|
||||
if(scale == null){
|
||||
std = Math.sqrt(1.0 / fanOut);
|
||||
} else {
|
||||
std = Math.sqrt(scale / fanOut);
|
||||
}
|
||||
|
||||
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,7 +16,9 @@
|
|||
|
||||
package org.deeplearning4j.nn.weights;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
@ -25,12 +27,22 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class WeightInitVarScalingUniformFanAvg implements IWeightInit {
|
||||
|
||||
private Double scale;
|
||||
|
||||
public WeightInitVarScalingUniformFanAvg(Double scale){
|
||||
this.scale = scale;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
|
||||
double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2);
|
||||
if(scale != null)
|
||||
scalingFanAvg *= scale;
|
||||
|
||||
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
|
|
@ -16,21 +16,34 @@
|
|||
|
||||
package org.deeplearning4j.nn.weights;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* Uniform U[-a,a] with a=3.0/(fanIn)
|
||||
* Uniform U[-a,a] with a=3.0/(fanIn)<br>
|
||||
* If a scale is provided, a = 3.0 * scale / (fanIn)
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
public class WeightInitVarScalingUniformFanIn implements IWeightInit {
|
||||
|
||||
private Double scale;
|
||||
|
||||
public WeightInitVarScalingUniformFanIn(Double scale){
|
||||
this.scale = scale;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
|
||||
double scalingFanIn = 3.0 / Math.sqrt(fanIn);
|
||||
if(scale != null)
|
||||
scalingFanIn *= scale;
|
||||
|
||||
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
|
|
@ -16,21 +16,33 @@
|
|||
|
||||
package org.deeplearning4j.nn.weights;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* Uniform U[-a,a] with a=3.0/(fanOut)
|
||||
* Uniform U[-a,a] with a=3.0/(fanOut)<br>
|
||||
* If a scale is provided, a = 3.0 * scale / fanOut
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class WeightInitVarScalingUniformFanOut implements IWeightInit {
|
||||
|
||||
private Double scale;
|
||||
|
||||
public WeightInitVarScalingUniformFanOut(Double scale){
|
||||
this.scale = scale;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
|
||||
double scalingFanOut = 3.0 / Math.sqrt(fanOut);
|
||||
if(scale != null)
|
||||
scalingFanOut *= scale;
|
||||
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ elseif (APPLE)
|
|||
elseif(WIN32)
|
||||
set(X86_BUILD true)
|
||||
if (CUDA_BLAS)
|
||||
set(CMAKE_CXX_FLAGS_RELEASE " /O2 -D_RELEASE=true /wd4804")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true /wd4804")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305")
|
||||
else()
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")
|
||||
|
|
|
@ -3607,6 +3607,13 @@ public class Shape {
|
|||
return ArrayUtil.prodLong(shape);
|
||||
}
|
||||
|
||||
public static long lengthOf(int[] shape) {
|
||||
if (shape.length == 0)
|
||||
return 1L;
|
||||
else
|
||||
return ArrayUtil.prodLong(shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the length of the buffer required to store the given shape with the given strides
|
||||
*
|
||||
|
|
|
@ -28,11 +28,6 @@ import java.util.List;
|
|||
* @author Adam Gibson
|
||||
*/
|
||||
public class SamplingDataSetIterator implements DataSetIterator {
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private static final long serialVersionUID = -2700563801361726914L;
|
||||
private DataSet sampleFrom;
|
||||
private int batchSize;
|
||||
private int totalNumberSamples;
|
||||
|
@ -145,6 +140,4 @@ public class SamplingDataSetIterator implements DataSetIterator {
|
|||
numTimesSampled++;
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -1164,26 +1164,15 @@ public class Nd4j {
|
|||
* @param type the opType to create
|
||||
* @return the created buffer
|
||||
*/
|
||||
public static DataBuffer createBuffer(int[] shape, DataType type) {
|
||||
long length = ArrayUtil.prodLong(shape);
|
||||
|
||||
if (type == DataType.INT)
|
||||
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createInt(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createInt(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (type == DataType.LONG)
|
||||
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createLong(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createLong(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (type == DataType.HALF)
|
||||
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createHalf(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createHalf(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (type == DataType.DOUBLE)
|
||||
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createDouble(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createDouble(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else
|
||||
return Nd4j.getMemoryManager().getCurrentWorkspace() == null ? DATA_BUFFER_FACTORY_INSTANCE.createFloat(length, true) : DATA_BUFFER_FACTORY_INSTANCE.createFloat(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
public static DataBuffer createBuffer(@NonNull int[] shape, @NonNull DataType type) {
|
||||
return createBuffer(ArrayUtil.toLongArray(shape), type);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #createBuffer(int[], DataType)}
|
||||
*/
|
||||
public static DataBuffer createBuffer(long[] shape, DataType type) {
|
||||
long length = ArrayUtil.prodLong(shape);
|
||||
public static DataBuffer createBuffer(@NonNull long[] shape, @NonNull DataType type) {
|
||||
long length = Shape.lengthOf(shape);
|
||||
|
||||
switch (type) {
|
||||
case BOOL:
|
||||
|
@ -1229,14 +1218,14 @@ public class Nd4j {
|
|||
* @return the created buffer.
|
||||
*/
|
||||
public static DataBuffer createBufferDetached(int[] shape, DataType type) {
|
||||
return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type);
|
||||
return createBufferDetachedImpl( Shape.lengthOf(shape), type);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #createBufferDetached(int[], DataType)}
|
||||
*/
|
||||
public static DataBuffer createBufferDetached(long[] shape, DataType type) {
|
||||
return createBufferDetachedImpl( ArrayUtil.prodLong(shape), type);
|
||||
return createBufferDetachedImpl( Shape.lengthOf(shape), type);
|
||||
}
|
||||
|
||||
// used by createBufferDetached(long[] DataType) and createBufferDetached(int[] , DataType)
|
||||
|
|
Loading…
Reference in New Issue