Assorted SameDiff/DL4J fixes (#279)
* #8565 Normalizer toString/hashCode
* #8731 ImagePreProcessingScaler labels/segmentation fix
* #8691 Fix SameDiffLayer/Vertex finetuning and parameter setting support
* #8663 DL4J embedding layer weight init - don't depend on vocab size
* EmbeddingLayer test tweak

Signed-off-by: Alex Black <blacka101@gmail.com>
parent f116f53d61
commit 483c3d7b8c
@@ -47,8 +47,7 @@ import java.util.List;
import java.util.Map;
import java.util.Random;

-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.*;

public class EmbeddingLayerTest extends BaseDL4JTest {
@@ -725,4 +724,77 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
        assertEquals(new ActivationIdentity(), l2.getActivationFn());
    }

    @Test
    public void testEmbeddingWeightInit() {
        // https://github.com/eclipse/deeplearning4j/issues/8663
        //The embedding layer weight initialization should be independent of the vocabulary size (nIn setting)

        for (WeightInit wi : new WeightInit[]{WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL, WeightInit.VAR_SCALING_NORMAL_FAN_OUT}) {
            for (boolean seq : new boolean[]{false, true}) {

                Nd4j.getRandom().setSeed(12345);
                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                        .seed(12345)
                        .list()
                        .layer(seq ?
                                new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() :
                                new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build())
                        .build();
                MultiLayerNetwork net = new MultiLayerNetwork(conf);
                net.init();

                Nd4j.getRandom().setSeed(12345);
                MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
                        .seed(12345)
                        .list()
                        .layer(seq ?
                                new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() :
                                new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build())
                        .build();
                MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
                net2.init();

                Nd4j.getRandom().setSeed(12345);
                MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder()
                        .seed(12345)
                        .list()
                        .layer(seq ?
                                new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build() :
                                new EmbeddingLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build())
                        .build();
                MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
                net3.init();

                INDArray p1 = net.params();
                INDArray p2 = net2.params();
                INDArray p3 = net3.params();
                assertEquals(p1, p2);

                double m1 = p1.meanNumber().doubleValue();
                double s1 = p1.stdNumber().doubleValue();

                double m3 = p3.meanNumber().doubleValue();
                double s3 = p3.stdNumber().doubleValue();

                String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;

                assertEquals(str, m1, m3, 0.1);
                assertEquals(str, s1, s3, 0.1);

                double re = relErr(s1, s3);
                assertTrue(str + " - " + re, re < 0.05);
            }
        }
    }

    public static double relErr(double d1, double d2) {
        if (d1 == 0.0 && d2 == 0.0)
            return 0.0;
        return Math.abs(d1 - d2) / (Math.abs(d1) + Math.abs(d2));
    }
}
@@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
@@ -136,6 +137,11 @@ public class SameDiffCustomLayerTests extends BaseDL4JTest {
    }

    private class ValidatingSameDiffVertex extends SameDiffVertex {
        @Override
        public GraphVertex clone() {
            return new ValidatingSameDiffVertex();
        }

        @Override
        public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
            return vertexInputs[0];
@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.samediff.testlayers;

import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
@@ -74,4 +75,9 @@ public class SameDiffDenseVertex extends SameDiffVertex {
    public char paramReshapeOrder(String paramName) {
        return 'f'; //To match DL4J DenseLayer - for easy comparison
    }

    @Override
    public GraphVertex clone() {
        return new SameDiffDenseVertex(nIn, nOut, activation, weightInit);
    }
}
@@ -16,6 +16,7 @@

package org.deeplearning4j.nn.layers.samediff.testlayers;

import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint;
import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.graph.AttentionVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;

@@ -35,6 +36,7 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

@@ -44,6 +46,9 @@ import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.*;

/**
@@ -565,4 +570,99 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
        assertEquals("Incorrect number of inputs!", 5, newGraph.layerInputSize(afterPoolName));
        newGraph.output(input);
    }

    @Test
    public void testTransferLearningSameDiffLayersGraph() {

        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .graphBuilder()
                .addInputs("in")
                .layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in")
                .layer("l1", new RecurrentAttentionLayer.Builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0")
                .layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
                .setOutputs("out")
                .build();

        ComputationGraph cg = new ComputationGraph(conf);
        cg.init();

        INDArray arr = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
        INDArray out = cg.output(arr)[0];

        ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out")
                .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
                .removeVertexAndConnections("out")
                .addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
                .setOutputs("newOut")
                .build();

        cg2.output(arr);

        Map<String,INDArray> m = new HashMap<>(cg.paramTable());
        m.put("newOut_W", m.remove("out_W"));
        m.put("newOut_b", m.remove("out_b"));
        cg2.setParamTable(m);

        Map<String,INDArray> p1 = cg.paramTable();
        Map<String,INDArray> p2 = cg2.paramTable();
        for (String s : p1.keySet()) {
            INDArray i1 = p1.get(s);
            INDArray i2 = p2.get(s.replaceAll("out", "newOut"));
            assertEquals(s, i1, i2);
        }

        INDArray out2 = cg2.outputSingle(arr);
        assertEquals(out, out2);
    }

    @Test
    public void testTransferLearningSameDiffLayersGraphVertex() {

        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .graphBuilder()
                .addInputs("in")
                .layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in")
                .addVertex("l1", new AttentionVertex.Builder().nHeads(1).headSize(5).nInKeys(5).nInQueries(5).nInValues(5).nOut(5).build(), "l0", "l0", "l0")
                .layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
                .setOutputs("out")
                .build();

        ComputationGraph cg = new ComputationGraph(conf);
        cg.init();

        INDArray arr = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
        INDArray out = cg.output(arr)[0];

        ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out")
                .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
                .removeVertexAndConnections("out")
                .addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
                .setOutputs("newOut")
                .build();

        cg2.output(arr);

        Map<String,INDArray> m = new HashMap<>(cg.paramTable());
        m.put("newOut_W", m.remove("out_W"));
        m.put("newOut_b", m.remove("out_b"));
        cg2.setParamTable(m);

        Map<String,INDArray> p1 = cg.paramTable();
        Map<String,INDArray> p2 = cg2.paramTable();
        for (String s : p1.keySet()) {
            INDArray i1 = p1.get(s);
            INDArray i2 = p2.get(s.replaceAll("out", "newOut"));
            assertEquals(s, i1, i2);
        }

        INDArray out2 = cg2.outputSingle(arr);
        assertEquals(out, out2);
    }
}
@@ -41,6 +41,7 @@ import org.deeplearning4j.nn.weights.WeightInitRelu;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

@@ -48,6 +49,8 @@ import org.nd4j.linalg.learning.config.*;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.shade.jackson.core.JsonProcessingException;

import java.util.Map;

import static org.junit.Assert.*;

/**
@@ -689,4 +692,51 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
        assertEquals("Incorrect number of inputs!", 5, newNet.layerInputSize(2));
        newNet.output(input);
    }

    @Test
    public void testTransferLearningSameDiffLayers() {

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .dataType(DataType.DOUBLE)
                .activation(Activation.TANH)
                .updater(new Adam(0.01))
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new LSTM.Builder().nOut(8).build())
                .layer(new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build())
                .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
                .layer(new OutputLayer.Builder().nOut(2).activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                .setInputType(InputType.recurrent(4))
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        INDArray in = Nd4j.rand(DataType.FLOAT, 3, 4, 5);
        INDArray out = net.output(in);

        MultiLayerNetwork net2 = new TransferLearning.Builder(net)
                .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
                .removeLayersFromOutput(1)
                .addLayer(new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                .build();

        net2.setParam("3_W", net.getParam("3_W"));
        net2.setParam("3_b", net.getParam("3_b"));

        Map<String,INDArray> p1 = net.paramTable();
        Map<String,INDArray> p2 = net2.paramTable();
        for (String s : p1.keySet()) {
            INDArray i1 = p1.get(s);
            INDArray i2 = p2.get(s);
            assertEquals(s, i1, i2);
        }

        INDArray out2 = net2.output(in);

        assertEquals(out, out2);
    }
}
@@ -427,7 +427,8 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
        if (!disconnected.isEmpty() && !allowNoOutput) { //If allowing no output: by definition we have disconnected vertices
            throw new IllegalStateException("Invalid configuration: disconnected vertices found - " + disconnected
                    + ". Disconnected vertices are those that do not connect to either another vertex, and are also"
-                    + " not a network output. To disable this error (i.e., allow network configurations with" +
+                    + " not a network output. This vertex can be set as an output using setOutputs(String...). "
+                    + "To disable this error (i.e., allow network configurations with" +
                    " disconnected vertices) use GraphBuilder.allowDisconnected(true)");
        }
    }
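
The reworded message points at the fix it suggests. A rough sketch of the two options it names follows; the builder below and the "dangling" vertex name are illustrative, not taken from this commit:

    // Hypothetical graph with a vertex that feeds nothing downstream
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addLayer("dangling", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in")
            .setOutputs("dangling")          // option 1: declare it as an output, as the new message suggests
            //.allowDisconnected(true)       // option 2: keep it disconnected and opt out of the check
            .build();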
@@ -72,6 +72,20 @@ public class AttentionVertex extends SameDiffVertex {
        this.weightInit = builder.weightInit;
    }

    @Override
    public AttentionVertex clone() {
        AttentionVertex av = new AttentionVertex();
        av.nInKeys = nInKeys;
        av.nInValues = nInValues;
        av.nInQueries = nInQueries;
        av.nOut = nOut;
        av.headSize = headSize;
        av.nHeads = nHeads;
        av.projectInput = projectInput;
        av.weightInit = weightInit;
        return av;
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
        InputType.InputTypeRecurrent queries = (InputType.InputTypeRecurrent) vertexInputs[0];
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
@@ -79,7 +80,7 @@ public class EmbeddingLayer extends FeedForwardLayer {

    @Override
    public ParamInitializer initializer() {
-        return DefaultParamInitializer.getInstance();
+        return EmbeddingLayerParamInitializer.getInstance();
    }

    @Override
@@ -24,7 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
-import org.deeplearning4j.nn.params.DefaultParamInitializer;
+import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
@@ -92,7 +92,7 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {

    @Override
    public ParamInitializer initializer() {
-        return DefaultParamInitializer.getInstance();
+        return EmbeddingLayerParamInitializer.getInstance();
    }

    @Override
@@ -16,11 +16,13 @@

package org.deeplearning4j.nn.conf.layers.samediff;

import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.lang.reflect.InvocationTargetException;
import java.util.*;
@@ -75,6 +77,15 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {
        //No op, for lambda vertex
    }

    @Override
    public GraphVertex clone() {
        try {
            return getClass().getConstructor().newInstance();
        } catch (Exception e) {
            throw new RuntimeException("Unable to create new instance of class " + getClass().getName() + " from no-arg constructor");
        }
    }

    protected VertexInputs getInputs(SameDiff sd) {
        if (inputs == null) {
            inputs = new VertexInputs(sd);
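
With SameDiffLambdaVertex now cloning itself through its no-arg constructor, a lambda vertex only has to define its function; the copy needed by transfer learning and graph cloning comes for free. A minimal sketch, assuming the defineVertex(SameDiff, VertexInputs) signature used by DL4J lambda vertices (ScaleVertex is a hypothetical name, not part of this commit):

    // Stateless lambda vertex: inherits the reflection-based clone() added above
    public class ScaleVertex extends SameDiffLambdaVertex {
        @Override
        public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) {
            return inputs.getInput(0).mul(2.0);   // simple element-wise transform of the single input
        }
    }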
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.samediff.SameDiffGraphVertex;

@@ -36,6 +37,7 @@ import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

import java.lang.reflect.Field;
import java.util.List;
import java.util.Map;
@@ -99,11 +101,6 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf
        return vertexParams;
    }

-    @Override
-    public GraphVertex clone() {
-        throw new UnsupportedOperationException("Not yet implemented");
-    }
-
    @Override
    public long numParams(boolean backprop) {
        SDLayerParams params = getVertexParams();
@@ -3394,7 +3394,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {

    @Override
    public void setParamTable(@NonNull Map<String, INDArray> paramTable) {
-        Preconditions.checkArgument(paramTable.keySet().equals(paramTable().keySet()), "Cannot set param table: parameter set keys are not equal");
+        Map<String,INDArray> m = paramTable();
+        Preconditions.checkArgument(paramTable.keySet().equals(m.keySet()), "Cannot set param table: parameter set keys are not equal");
        Map<String,INDArray> current = paramTable();
        //Check shapes before doing partial assigment to avoid leaving net in incorrect state
        for(String s : current.keySet()){
@@ -237,9 +237,16 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {

    @Override
    public void setParams(INDArray params) {
-        if (params != null) {
-            throw new UnsupportedOperationException("Not supported");
-        }
+        if(this.params == null && params == null)
+            return;
+        if(this.params == null)
+            throw new IllegalStateException("Cannot set parameters of length " + params.length() + " to a layer with no parameters");
+        if(params == null)
+            throw new IllegalStateException("Cannot set null parameters");
+
+        Preconditions.checkState(this.params.length() == params.length(), "Cannot assign parameter vector of length %s to a layer with %s parameters",
+                params.length(), this.params.length());
+        this.params.assign(params);
    }

    protected void setParams(INDArray params, char order) {
@@ -0,0 +1,52 @@
/* ******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.deeplearning4j.nn.params;

import lombok.val;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Parameter initializer for EmbeddingLayer and EmbeddingSequenceLayer
 *
 * @author Alex Black
 */
public class EmbeddingLayerParamInitializer extends DefaultParamInitializer {

    private static final EmbeddingLayerParamInitializer INSTANCE = new EmbeddingLayerParamInitializer();

    public static EmbeddingLayerParamInitializer getInstance() {
        return INSTANCE;
    }


    protected INDArray createWeightMatrix(long nIn, long nOut, IWeightInit weightInit,
                                          INDArray weightParamView, boolean initializeParameters) {
        val shape = new long[]{nIn, nOut};

        if (initializeParameters) {
            INDArray ret = weightInit.init(1, //Fan in - note that fanIn=1 for embedding layer... if we used layer nIn (i.e., vocab size) the init would depend on vocab size (which doesn't make sense)
                    nOut, //Fan out
                    shape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, weightParamView);
            return ret;
        } else {
            return WeightInitUtil.reshapeWeights(shape, weightParamView);
        }
    }

}
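
The fanIn=1 comment above is the heart of #8663: fan-in based schemes scale the initial weights by the vocabulary size, so a large vocabulary would shrink the embedding init towards zero. A back-of-the-envelope check, not part of the commit, using the Glorot/Xavier standard deviation sqrt(2 / (fanIn + fanOut)):

    // Illustrative only: compare the old (fanIn = vocab size) and new (fanIn = 1) Xavier std-dev
    public class EmbeddingInitScale {
        public static void main(String[] args) {
            long nOut = 100;
            for (long vocab : new long[]{100, 100_000}) {
                double stdOld = Math.sqrt(2.0 / (vocab + nOut));   // old: depends on vocab size
                double stdNew = Math.sqrt(2.0 / (1 + nOut));       // new: constant regardless of vocab
                System.out.printf("vocab=%d  old std=%.5f  new std=%.5f%n", vocab, stdOld, stdNew);
            }
        }
    }

This is what testEmbeddingWeightInit above asserts: the mean and standard deviation of the initialized parameters stay approximately the same whether nIn is 100 or 100000.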
@@ -79,7 +79,6 @@ public abstract class AbstractMultiDataSetNormalizer<S extends NormalizerStats>
    }

    protected List<S> getFeatureStats() {
-        assertIsFit();
        return featureStats;
    }

@@ -88,7 +87,6 @@ public abstract class AbstractMultiDataSetNormalizer<S extends NormalizerStats>
    }

    protected List<S> getLabelStats() {
-        assertIsFit();
        return labelStats;
    }
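
Dropping assertIsFit() from these getters is what addresses #8565: the generated toString()/hashCode() of the normalizers appear to route through these getters, so calling either on a normalizer that had not been fit yet threw. A minimal sketch of the previously failing call pattern, under that assumption:

    // Previously triggered assertIsFit() before fit() was called; now safe
    MultiNormalizerStandardize n = new MultiNormalizerStandardize();
    n.toString();
    n.hashCode();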
@@ -44,6 +44,7 @@ public class ImagePreProcessingScaler implements DataNormalization {
    private double minRange, maxRange;
    private double maxPixelVal;
    private int maxBits;
    private boolean fitLabels = false;

    public ImagePreProcessingScaler() {
        this(0, 1, 8);

@@ -94,7 +95,10 @@ public class ImagePreProcessingScaler implements DataNormalization {
    @Override
    public void preProcess(DataSet toPreProcess) {
        INDArray features = toPreProcess.getFeatures();
-        this.preProcess(features);
+        preProcess(features);
+        if(fitLabels && toPreProcess.getLabels() != null){
+            preProcess(toPreProcess.getLabels());
+        }
    }

    public void preProcess(INDArray features) {

@@ -139,6 +143,7 @@ public class ImagePreProcessingScaler implements DataNormalization {
    @Override
    public void revert(DataSet toRevert) {
        revertFeatures(toRevert.getFeatures());
        revertLabels(toRevert.getLabels());
    }

    @Override

@@ -177,10 +182,11 @@ public class ImagePreProcessingScaler implements DataNormalization {
    @Override
    public void fitLabel(boolean fitLabels) {
-        //No-op
+        this.fitLabels = fitLabels;
    }

    @Override
    public boolean isFitLabel() {
-        return false;
+        return fitLabels;
    }
}
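
With fitLabels wired through, the scaler can normalize segmentation masks the same way it normalizes input images. A minimal usage sketch, with illustrative array shapes (NCHW images, per-pixel class maps as labels):

    // Scale both image features and per-pixel label masks to [0, 1]
    ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);
    scaler.fitLabel(true);   // previously a no-op; now enables label scaling

    INDArray images = Nd4j.rand(DataType.FLOAT, 2, 3, 16, 16).muli(255);
    INDArray masks = Nd4j.rand(DataType.FLOAT, 2, 10, 16, 16).muli(255);
    DataSet ds = new DataSet(images, masks);

    scaler.fit(ds);
    scaler.transform(ds);    // features and labels now lie in [0, 1]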
@@ -22,6 +22,7 @@ import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
@@ -156,6 +157,42 @@ public class ImagePreProcessortTest extends BaseNd4jTest {
        assertEquals(orig, before);
    }

    @Test
    public void testSegmentation() {

        INDArray f = Nd4j.math().floor(Nd4j.rand(DataType.FLOAT, 3, 3, 16, 16).muli(255));
        INDArray l = Nd4j.math().floor(Nd4j.rand(DataType.FLOAT, 3, 10, 8, 8).muli(255));

        ImagePreProcessingScaler s = new ImagePreProcessingScaler();
        s.fitLabel(true);

        s.fit(new DataSet(f, l));

        INDArray expF = f.div(255);
        INDArray expL = l.div(255);

        DataSet d = new DataSet(f.dup(), l.dup());
        s.transform(d);
        assertEquals(expF, d.getFeatures());
        assertEquals(expL, d.getLabels());


        s.fit(new SingletonDataSetIterator(new DataSet(f, l)));

        INDArray f2 = f.dup();
        INDArray l2 = l.dup();
        s.transform(f2);
        s.transformLabel(l2);
        assertEquals(expF, f2);
        assertEquals(expL, l2);

        s.revertFeatures(f2);
        s.revertLabels(l2);
        assertEquals(f, f2);
        assertEquals(l, l2);
    }

    @Override
    public char ordering() {
        return 'c';
@@ -22,11 +22,10 @@ import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.TestDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.*;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
@@ -189,11 +188,11 @@ public class NormalizerTests extends BaseNd4jTest {

        INDArray zeros = Nd4j.zeros(shouldBe0_1.shape());

-        for (int j = 0; j < 2; j++) {
-            System.out.println(ds.getFeatures().get(NDArrayIndex.point(j), NDArrayIndex.all(),
-                    NDArrayIndex.all()));
-            System.out.println();
-        }
+//        for (int j = 0; j < 2; j++) {
+//            System.out.println(ds.getFeatures().get(NDArrayIndex.point(j), NDArrayIndex.all(),
+//                    NDArrayIndex.all()));
+//            System.out.println();
+//        }

        assertEquals(zeros, shouldBe0_1);
        assertEquals(zeros, shouldBe0_2);
@@ -218,6 +217,69 @@ public class NormalizerTests extends BaseNd4jTest {
        }
    }

    @Test
    public void testNormalizerToStringHashCode() {
        //https://github.com/eclipse/deeplearning4j/issues/8565

        testNormalizer(new NormalizerMinMaxScaler());
        NormalizerMinMaxScaler n1 = new NormalizerMinMaxScaler();
        n1.fitLabel(true);
        testNormalizer(n1);

        testNormalizer(new NormalizerStandardize());
        NormalizerStandardize n2 = new NormalizerStandardize();
        n2.fitLabel(true);
        testNormalizer(n2);

        testNormalizer(new ImagePreProcessingScaler());
        ImagePreProcessingScaler n3 = new ImagePreProcessingScaler();
        n3.fitLabel(true);
        testNormalizer(n3);

        testNormalizer(new VGG16ImagePreProcessor());
        VGG16ImagePreProcessor n4 = new VGG16ImagePreProcessor();
        n4.fitLabel(true);
        testNormalizer(n4);
    }

    private static void testNormalizer(DataNormalization n) {
        n.toString();
        n.hashCode();

        n.fit(new IrisDataSetIterator(30, 150));

        n.toString();
        n.hashCode();
    }

    @Test
    public void testMultiNormalizerToStringHashCode() {
        //https://github.com/eclipse/deeplearning4j/issues/8565

        testMultiNormalizer(new MultiNormalizerMinMaxScaler());
        MultiNormalizerMinMaxScaler n1 = new MultiNormalizerMinMaxScaler();
        n1.fitLabel(true);
        testMultiNormalizer(n1);

        testMultiNormalizer(new MultiNormalizerStandardize());
        MultiNormalizerStandardize n2 = new MultiNormalizerStandardize();
        n2.fitLabel(true);
        testMultiNormalizer(n2);

        testMultiNormalizer(new ImageMultiPreProcessingScaler(0));
    }

    private static void testMultiNormalizer(MultiDataNormalization n) {
        n.toString();
        n.hashCode();

        n.fit(new MultiDataSetIteratorAdapter(new IrisDataSetIterator(30, 150)));

        n.toString();
        n.hashCode();
    }

    @Override
    public char ordering() {
        return 'c';