Assorted SameDiff/DL4J fixes (#279)
* #8565 Normalizer toString/hashCode
* #8731 ImagePreProcessingScaler labels/segmentation fix
* #8691 Fix SameDiffLayer/Vertex finetuning and parameter setting support
* #8663 DL4J embedding layer weight init - don't depend on vocab size
* EmbeddingLayer test tweak

Signed-off-by: Alex Black <blacka101@gmail.com>
parent f116f53d61
commit 483c3d7b8c
@@ -47,8 +47,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Random;
 
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.*;
 
 public class EmbeddingLayerTest extends BaseDL4JTest {
 
@@ -725,4 +724,77 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         assertEquals(new ActivationIdentity(), l2.getActivationFn());
 
     }
+
+
+    @Test
+    public void testEmbeddingWeightInit(){
+        // https://github.com/eclipse/deeplearning4j/issues/8663
+        //The embedding layer weight initialization should be independent of the vocabulary size (nIn setting)
+
+        for(WeightInit wi : new WeightInit[]{WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL, WeightInit.VAR_SCALING_NORMAL_FAN_OUT}) {
+            for (boolean seq : new boolean[]{false, true}) {
+
+                Nd4j.getRandom().setSeed(12345);
+                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                        .seed(12345)
+                        .list()
+                        .layer(seq ?
+                                new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() :
+                                new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build())
+                        .build();
+                MultiLayerNetwork net = new MultiLayerNetwork(conf);
+                net.init();
+
+                Nd4j.getRandom().setSeed(12345);
+                MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
+                        .seed(12345)
+                        .list()
+                        .layer(seq ?
+                                new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() :
+                                new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build())
+                        .build();
+                MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
+                net2.init();
+
+                Nd4j.getRandom().setSeed(12345);
+                MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder()
+                        .seed(12345)
+                        .list()
+                        .layer(seq ?
+                                new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build() :
+                                new EmbeddingLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build())
+                        .build();
+                MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
+                net3.init();
+
+                INDArray p1 = net.params();
+                INDArray p2 = net2.params();
+                INDArray p3 = net3.params();
+                assertEquals(p1, p2);
+
+                double m1 = p1.meanNumber().doubleValue();
+                double s1 = p1.stdNumber().doubleValue();
+
+                double m3 = p3.meanNumber().doubleValue();
+                double s3 = p3.stdNumber().doubleValue();
+
+                String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
+
+                assertEquals(str, m1, m3, 0.1);
+                assertEquals(str, s1, s3, 0.1);
+
+                double re = relErr(s1, s3);
+                assertTrue(str + " - " + re, re < 0.05);
+            }
+        }
+
+    }
+
+    public static double relErr(double d1, double d2){
+        if(d1 == 0.0 && d2 == 0.0)
+            return 0.0;
+        return Math.abs(d1 - d2) / (Math.abs(d1) + Math.abs(d2));
+    }
+
 }
@@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;

@@ -136,6 +137,11 @@ public class SameDiffCustomLayerTests extends BaseDL4JTest {
     }
 
     private class ValidatingSameDiffVertex extends SameDiffVertex {
+        @Override
+        public GraphVertex clone() {
+            return new ValidatingSameDiffVertex();
+        }
+
         @Override
         public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
             return vertexInputs[0];
@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.samediff.testlayers;
 
 import lombok.Data;
 import lombok.NoArgsConstructor;
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
 import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
 import org.deeplearning4j.nn.params.DefaultParamInitializer;

@@ -74,4 +75,9 @@ public class SameDiffDenseVertex extends SameDiffVertex {
     public char paramReshapeOrder(String paramName){
         return 'f'; //To match DL4J DenseLayer - for easy comparison
     }
+
+    @Override
+    public GraphVertex clone() {
+        return new SameDiffDenseVertex(nIn, nOut, activation, weightInit);
+    }
 }
@@ -16,6 +16,7 @@
 
 package org.deeplearning4j.nn.layers.samediff.testlayers;
 
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint;
 import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
+import org.deeplearning4j.nn.conf.graph.AttentionVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;

@@ -35,6 +36,7 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution;
 import org.deeplearning4j.nn.weights.WeightInitXavier;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.factory.Nd4j;

@@ -44,6 +46,9 @@ import org.nd4j.linalg.learning.config.RmsProp;
 import org.nd4j.linalg.learning.config.Sgd;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 
+import java.util.HashMap;
+import java.util.Map;
+
 import static org.junit.Assert.*;
 
 /**

@@ -565,4 +570,99 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
         assertEquals("Incorrect number of inputs!", 5, newGraph.layerInputSize(afterPoolName));
         newGraph.output(input);
     }
+
+
+    @Test
+    public void testTransferLearningSameDiffLayersGraph(){
+
+        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+
+                .graphBuilder()
+                .addInputs("in")
+                .layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in")
+                .layer("l1", new RecurrentAttentionLayer.Builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0")
+                .layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
+                .setOutputs("out")
+                .build();
+
+        ComputationGraph cg = new ComputationGraph(conf);
+        cg.init();
+
+        INDArray arr = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
+        INDArray out = cg.output(arr)[0];
+
+
+        ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out")
+                .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
+                .removeVertexAndConnections("out")
+                .addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
+                .setOutputs("newOut")
+                .build();
+
+        cg2.output(arr);
+
+        Map<String,INDArray> m = new HashMap<>(cg.paramTable());
+        m.put("newOut_W", m.remove("out_W"));
+        m.put("newOut_b", m.remove("out_b"));
+        cg2.setParamTable(m);
+
+        Map<String,INDArray> p1 = cg.paramTable();
+        Map<String,INDArray> p2 = cg2.paramTable();
+        for(String s : p1.keySet()){
+            INDArray i1 = p1.get(s);
+            INDArray i2 = p2.get(s.replaceAll("out", "newOut"));
+            assertEquals(s, i1, i2);
+        }
+
+        INDArray out2 = cg2.outputSingle(arr);
+        assertEquals(out, out2);
+    }
+
+    @Test
+    public void testTransferLearningSameDiffLayersGraphVertex(){
+
+        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+
+                .graphBuilder()
+                .addInputs("in")
+                .layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in")
+                .addVertex("l1", new AttentionVertex.Builder().nHeads(1).headSize(5).nInKeys(5).nInQueries(5).nInValues(5).nOut(5).build(), "l0", "l0", "l0")
+                .layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
+                .setOutputs("out")
+                .build();
+
+        ComputationGraph cg = new ComputationGraph(conf);
+        cg.init();
+
+        INDArray arr = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
+        INDArray out = cg.output(arr)[0];
+
+
+        ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out")
+                .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
+                .removeVertexAndConnections("out")
+                .addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
+                .setOutputs("newOut")
+                .build();
+
+        cg2.output(arr);
+
+        Map<String,INDArray> m = new HashMap<>(cg.paramTable());
+        m.put("newOut_W", m.remove("out_W"));
+        m.put("newOut_b", m.remove("out_b"));
+        cg2.setParamTable(m);
+
+        Map<String,INDArray> p1 = cg.paramTable();
+        Map<String,INDArray> p2 = cg2.paramTable();
+        for(String s : p1.keySet()){
+            INDArray i1 = p1.get(s);
+            INDArray i2 = p2.get(s.replaceAll("out", "newOut"));
+            assertEquals(s, i1, i2);
+        }
+
+        INDArray out2 = cg2.outputSingle(arr);
+        assertEquals(out, out2);
+    }
 }
@@ -41,6 +41,7 @@ import org.deeplearning4j.nn.weights.WeightInitRelu;
 import org.deeplearning4j.nn.weights.WeightInitXavier;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.factory.Nd4j;

@@ -48,6 +49,8 @@ import org.nd4j.linalg.learning.config.*;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 import org.nd4j.shade.jackson.core.JsonProcessingException;
 
+import java.util.Map;
+
 import static org.junit.Assert.*;
 
 /**

@@ -689,4 +692,51 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
         assertEquals("Incorrect number of inputs!", 5, newNet.layerInputSize(2));
         newNet.output(input);
     }
+
+
+    @Test
+    public void testTransferLearningSameDiffLayers(){
+
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .dataType(DataType.DOUBLE)
+                .activation(Activation.TANH)
+                .updater(new Adam(0.01))
+                .weightInit(WeightInit.XAVIER)
+                .list()
+                .layer(new LSTM.Builder().nOut(8).build())
+                .layer( new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build())
+                .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
+                .layer(new OutputLayer.Builder().nOut(2).activation(Activation.SOFTMAX)
+                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
+                .setInputType(InputType.recurrent(4))
+                .build();
+
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+
+        INDArray in = Nd4j.rand(DataType.FLOAT, 3, 4, 5);
+        INDArray out = net.output(in);
+
+        MultiLayerNetwork net2 = new TransferLearning.Builder(net)
+                .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
+                .removeLayersFromOutput(1)
+                .addLayer(new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX)
+                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
+                .build();
+
+        net2.setParam("3_W", net.getParam("3_W"));
+        net2.setParam("3_b", net.getParam("3_b"));
+
+        Map<String,INDArray> p1 = net.paramTable();
+        Map<String,INDArray> p2 = net2.paramTable();
+        for(String s : p1.keySet()){
+            INDArray i1 = p1.get(s);
+            INDArray i2 = p2.get(s);
+            assertEquals(s, i1, i2);
+        }
+
+        INDArray out2 = net2.output(in);
+
+        assertEquals(out, out2);
+    }
 }
@@ -427,7 +427,8 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
            if(!disconnected.isEmpty() && !allowNoOutput){ //If allowing no output: by definition we have disconnected vertices
                throw new IllegalStateException("Invalid configuration: disconnected vertices found - " + disconnected
                        + ". Disconnected vertices are those that do not connect to either another vertex, and are also"
-                       + " not a network output. To disable this error (i.e., allow network configurations with" +
+                       + " not a network output. This vertex can be set as an output using setOutputs(String...). "
+                       + "To disable this error (i.e., allow network configurations with" +
                        " disconnected vertices) use GraphBuilder.allowDisconnected(true)");
            }
        }
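
For reference, a minimal sketch of the two remedies the expanded error message names (making the vertex an output via setOutputs, or calling GraphBuilder.allowDisconnected(true)). The class name, layer names, and sizes below are illustrative only and are not part of this commit:

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class AllowDisconnectedSketch {
    public static void main(String[] args) {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .graphBuilder()
                .addInputs("in")
                .addLayer("hidden", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in")
                .addLayer("debug", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in") //feeds no output
                .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(3)
                        .activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "hidden")
                .setOutputs("out")          //option 1: also list "debug" here to make it an output
                .allowDisconnected(true)    //option 2: explicitly allow the disconnected vertex
                .build();
        System.out.println(conf.toJson());
    }
}
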
@@ -72,6 +72,20 @@ public class AttentionVertex extends SameDiffVertex {
         this.weightInit = builder.weightInit;
     }
 
+    @Override
+    public AttentionVertex clone() {
+        AttentionVertex av = new AttentionVertex();
+        av.nInKeys = nInKeys;
+        av.nInValues = nInValues;
+        av.nInQueries = nInQueries;
+        av.nOut = nOut;
+        av.headSize = headSize;
+        av.nHeads = nHeads;
+        av.projectInput = projectInput;
+        av.weightInit = weightInit;
+        return av;
+    }
+
     @Override
     public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
         InputType.InputTypeRecurrent queries = (InputType.InputTypeRecurrent) vertexInputs[0];
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.params.DefaultParamInitializer;
+import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
 import org.deeplearning4j.nn.weights.IWeightInit;
 import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
 import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;

@@ -79,7 +80,7 @@ public class EmbeddingLayer extends FeedForwardLayer {
 
     @Override
     public ParamInitializer initializer() {
-        return DefaultParamInitializer.getInstance();
+        return EmbeddingLayerParamInitializer.getInstance();
     }
 
     @Override
@@ -24,7 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
-import org.deeplearning4j.nn.params.DefaultParamInitializer;
+import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
 import org.deeplearning4j.nn.weights.IWeightInit;
 import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
 import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;

@@ -92,7 +92,7 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
 
     @Override
     public ParamInitializer initializer() {
-        return DefaultParamInitializer.getInstance();
+        return EmbeddingLayerParamInitializer.getInstance();
    }
 
     @Override
@@ -16,11 +16,13 @@
 
 package org.deeplearning4j.nn.conf.layers.samediff;
 
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
+import java.lang.reflect.InvocationTargetException;
 import java.util.*;
 
 

@@ -75,6 +77,15 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {
         //No op, for lambda vertex
     }
 
+    @Override
+    public GraphVertex clone() {
+        try {
+            return getClass().getConstructor().newInstance();
+        } catch (Exception e){
+            throw new RuntimeException("Unable to create new instance of class " + getClass().getName() + " from no-arg constructor");
+        }
+    }
+
     protected VertexInputs getInputs(SameDiff sd) {
         if (inputs == null) {
             inputs = new VertexInputs(sd);
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
+import org.deeplearning4j.nn.conf.layers.Layer;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.layers.samediff.SameDiffGraphVertex;

@@ -36,6 +37,7 @@ import org.nd4j.linalg.learning.regularization.Regularization;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.util.ArrayUtil;
 
+import java.lang.reflect.Field;
 import java.util.List;
 import java.util.Map;
 

@@ -99,11 +101,6 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf
         return vertexParams;
     }
 
-    @Override
-    public GraphVertex clone() {
-        throw new UnsupportedOperationException("Not yet implemented");
-    }
-
     @Override
     public long numParams(boolean backprop) {
         SDLayerParams params = getVertexParams();
@@ -3394,7 +3394,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
 
     @Override
     public void setParamTable(@NonNull Map<String, INDArray> paramTable) {
-        Preconditions.checkArgument(paramTable.keySet().equals(paramTable().keySet()), "Cannot set param table: parameter set keys are not equal");
+        Map<String,INDArray> m = paramTable();
+        Preconditions.checkArgument(paramTable.keySet().equals(m.keySet()), "Cannot set param table: parameter set keys are not equal");
         Map<String,INDArray> current = paramTable();
         //Check shapes before doing partial assigment to avoid leaving net in incorrect state
         for(String s : current.keySet()){
@@ -237,9 +237,16 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
 
     @Override
     public void setParams(INDArray params) {
-        if (params != null) {
-            throw new UnsupportedOperationException("Not supported");
-        }
+        if(this.params == null && params == null)
+            return;
+        if(this.params == null)
+            throw new IllegalStateException("Cannot set parameters of length " + params.length() + " to a layer with no parameters");
+        if(params == null)
+            throw new IllegalStateException("Cannot set null parameters");
+
+        Preconditions.checkState(this.params.length() == params.length(), "Cannot assign parameter vector of length %s to a layer with %s parameters",
+                params.length(), this.params.length());
+        this.params.assign(params);
     }
 
     protected void setParams(INDArray params, char order) {
@@ -0,0 +1,52 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.nn.params;
+
+import lombok.val;
+import org.deeplearning4j.nn.weights.IWeightInit;
+import org.deeplearning4j.nn.weights.WeightInitUtil;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * Parameter initializer for EmbeddingLayer and EmbeddingSequenceLayer
+ *
+ * @author Alex Black
+ */
+public class EmbeddingLayerParamInitializer extends DefaultParamInitializer {
+
+    private static final EmbeddingLayerParamInitializer INSTANCE = new EmbeddingLayerParamInitializer();
+
+    public static EmbeddingLayerParamInitializer getInstance() {
+        return INSTANCE;
+    }
+
+
+
+    protected INDArray createWeightMatrix(long nIn, long nOut, IWeightInit weightInit,
+                                          INDArray weightParamView, boolean initializeParameters) {
+        val shape = new long[] {nIn, nOut};
+
+        if (initializeParameters) {
+            INDArray ret = weightInit.init(1, //Fan in - note that fanIn=1 for embedding layer... if we used layer nIn (i.e., vocab size) the init would depend on vocab size (which doesn't make sense)
+                    nOut, //Fan out
+                    shape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, weightParamView);
+            return ret;
+        } else {
+            return WeightInitUtil.reshapeWeights(shape, weightParamView);
+        }
+    }
+
+}
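
The fanIn=1 comment above is the heart of the #8663 fix. A rough illustrative calculation, assuming the usual Xavier variance of 2/(fanIn + fanOut) (the class below is hypothetical and not part of the commit), shows how the weight scale would shrink with a large vocabulary if nIn were used as the fan-in:

public class XavierFanInSketch {
    public static void main(String[] args) {
        long nOut = 100;
        for (long vocab : new long[]{100, 100_000}) {
            // Old behaviour: fan-in taken from nIn (the vocabulary size)
            double stdVocabFanIn = Math.sqrt(2.0 / (vocab + nOut));
            // New behaviour: fan-in fixed at 1, independent of vocabulary size
            double stdFixedFanIn = Math.sqrt(2.0 / (1 + nOut));
            System.out.printf("vocab=%d -> std(old)=%.5f, std(new)=%.5f%n",
                    vocab, stdVocabFanIn, stdFixedFanIn);
        }
    }
}
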
@@ -79,7 +79,6 @@ public abstract class AbstractMultiDataSetNormalizer<S extends NormalizerStats>
     }
 
     protected List<S> getFeatureStats() {
-        assertIsFit();
         return featureStats;
     }
 

@@ -88,7 +87,6 @@
     }
 
     protected List<S> getLabelStats() {
-        assertIsFit();
         return labelStats;
     }
 
@@ -44,6 +44,7 @@ public class ImagePreProcessingScaler implements DataNormalization {
     private double minRange, maxRange;
     private double maxPixelVal;
     private int maxBits;
+    private boolean fitLabels = false;
 
     public ImagePreProcessingScaler() {
         this(0, 1, 8);

@@ -94,7 +95,10 @@ public class ImagePreProcessingScaler implements DataNormalization {
     @Override
     public void preProcess(DataSet toPreProcess) {
         INDArray features = toPreProcess.getFeatures();
-        this.preProcess(features);
+        preProcess(features);
+        if(fitLabels && toPreProcess.getLabels() != null){
+            preProcess(toPreProcess.getLabels());
+        }
     }
 
     public void preProcess(INDArray features) {

@@ -139,6 +143,7 @@ public class ImagePreProcessingScaler implements DataNormalization {
     @Override
     public void revert(DataSet toRevert) {
         revertFeatures(toRevert.getFeatures());
+        revertLabels(toRevert.getLabels());
     }
 
     @Override

@@ -177,10 +182,11 @@ public class ImagePreProcessingScaler implements DataNormalization {
     @Override
     public void fitLabel(boolean fitLabels) {
         //No-op
+        this.fitLabels = fitLabels;
     }
 
     @Override
     public boolean isFitLabel() {
-        return false;
+        return fitLabels;
     }
 }
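
For context, a minimal usage sketch of the label-scaling support added above, e.g. for segmentation masks stored as 0-255 images; the class name, shapes, and values are illustrative only:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;

public class SegmentationScalingSketch {
    public static void main(String[] args) {
        INDArray images = Nd4j.rand(DataType.FLOAT, 2, 3, 32, 32).muli(255); // pixel values in [0, 255]
        INDArray masks  = Nd4j.rand(DataType.FLOAT, 2, 1, 32, 32).muli(255); // per-pixel labels on the same scale

        ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(); // scales to [0, 1]
        scaler.fitLabel(true); // also scale the image-valued labels, not just the features

        DataSet ds = new DataSet(images, masks);
        scaler.transform(ds);  // features and labels are now both in [0, 1]
    }
}
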
@@ -22,6 +22,7 @@ import org.junit.runners.Parameterized;
 import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
 import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler;
 import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
 import org.nd4j.linalg.factory.Nd4j;

@@ -156,6 +157,42 @@ public class ImagePreProcessortTest extends BaseNd4jTest {
         assertEquals(orig, before);
     }
+
+
+    @Test
+    public void testSegmentation(){
+
+        INDArray f = Nd4j.math().floor(Nd4j.rand(DataType.FLOAT, 3, 3, 16, 16).muli(255));
+        INDArray l = Nd4j.math().floor(Nd4j.rand(DataType.FLOAT, 3, 10, 8, 8).muli(255));
+
+        ImagePreProcessingScaler s = new ImagePreProcessingScaler();
+        s.fitLabel(true);
+
+        s.fit(new DataSet(f,l));
+
+        INDArray expF = f.div(255);
+        INDArray expL = l.div(255);
+
+        DataSet d = new DataSet(f.dup(), l.dup());
+        s.transform(d);
+        assertEquals(expF, d.getFeatures());
+        assertEquals(expL, d.getLabels());
+
+
+        s.fit(new SingletonDataSetIterator(new DataSet(f, l)));
+
+        INDArray f2 = f.dup();
+        INDArray l2 = l.dup();
+        s.transform(f2);
+        s.transformLabel(l2);
+        assertEquals(expF, f2);
+        assertEquals(expL, l2);
+
+        s.revertFeatures(f2);
+        s.revertLabels(l2);
+        assertEquals(f, f2);
+        assertEquals(l, l2);
+    }
 
     @Override
     public char ordering() {
         return 'c';
@@ -22,11 +22,10 @@ import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.TestDataSetIterator;
-import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
-import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
-import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
+import org.nd4j.linalg.dataset.api.preprocessor.*;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.factory.Nd4jBackend;
 import org.nd4j.linalg.indexing.NDArrayIndex;

@@ -189,11 +188,11 @@ public class NormalizerTests extends BaseNd4jTest {
 
             INDArray zeros = Nd4j.zeros(shouldBe0_1.shape());
 
-            for (int j = 0; j < 2; j++) {
-                System.out.println(ds.getFeatures().get(NDArrayIndex.point(j), NDArrayIndex.all(),
-                        NDArrayIndex.all()));
-                System.out.println();
-            }
+//            for (int j = 0; j < 2; j++) {
+//                System.out.println(ds.getFeatures().get(NDArrayIndex.point(j), NDArrayIndex.all(),
+//                        NDArrayIndex.all()));
+//                System.out.println();
+//            }
 
             assertEquals(zeros, shouldBe0_1);
             assertEquals(zeros, shouldBe0_2);

@@ -218,6 +217,69 @@
         }
     }
+
+    @Test
+    public void testNormalizerToStringHashCode(){
+        //https://github.com/eclipse/deeplearning4j/issues/8565
+
+        testNormalizer(new NormalizerMinMaxScaler());
+        NormalizerMinMaxScaler n1 = new NormalizerMinMaxScaler();
+        n1.fitLabel(true);
+        testNormalizer(n1);
+
+        testNormalizer(new NormalizerStandardize());
+        NormalizerStandardize n2 = new NormalizerStandardize();
+        n2.fitLabel(true);
+        testNormalizer(n2);
+
+        testNormalizer(new ImagePreProcessingScaler());
+        ImagePreProcessingScaler n3 = new ImagePreProcessingScaler();
+        n3.fitLabel(true);
+        testNormalizer(n3);
+
+        testNormalizer(new VGG16ImagePreProcessor());
+        VGG16ImagePreProcessor n4 = new VGG16ImagePreProcessor();
+        n4.fitLabel(true);
+        testNormalizer(n4);
+    }
+
+    private static void testNormalizer(DataNormalization n){
+        n.toString();
+        n.hashCode();
+
+        n.fit(new IrisDataSetIterator(30, 150));
+
+        n.toString();
+        n.hashCode();
+    }
+
+    @Test
+    public void testMultiNormalizerToStringHashCode(){
+        //https://github.com/eclipse/deeplearning4j/issues/8565
+
+        testMultiNormalizer(new MultiNormalizerMinMaxScaler());
+        MultiNormalizerMinMaxScaler n1 = new MultiNormalizerMinMaxScaler();
+        n1.fitLabel(true);
+        testMultiNormalizer(n1);
+
+        testMultiNormalizer(new MultiNormalizerStandardize());
+        MultiNormalizerStandardize n2 = new MultiNormalizerStandardize();
+        n2.fitLabel(true);
+        testMultiNormalizer(n2);
+
+        testMultiNormalizer(new ImageMultiPreProcessingScaler(0));
+    }
+
+    private static void testMultiNormalizer(MultiDataNormalization n){
+        n.toString();
+        n.hashCode();
+
+        n.fit(new MultiDataSetIteratorAdapter(new IrisDataSetIterator(30, 150)));
+
+        n.toString();
+        n.hashCode();
+    }
+
 
     @Override
     public char ordering() {
         return 'c';