Assorted SameDiff/DL4J fixes (#279)

* #8565 Normalizer toString/hashCode

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8731 ImagePreProcessingScaler labels/segmentation fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8691 Fix SameDiffLayer/Vertex fine-tuning and parameter setting support

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8663 DL4J embedding layer weight init - don't depend on vocab size

Signed-off-by: Alex Black <blacka101@gmail.com>

* EmbeddingLayer test tweak

Signed-off-by: Alex Black <blacka101@gmail.com>
Branch: master
Author: Alex Black, 2020-03-02 16:15:49 +11:00 (committed by GitHub)
Parent: f116f53d61
Commit: 483c3d7b8c
19 changed files with 449 additions and 27 deletions


@@ -47,8 +47,7 @@ import java.util.List;
import java.util.Map;
import java.util.Random;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.*;
public class EmbeddingLayerTest extends BaseDL4JTest {
@@ -725,4 +724,77 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
assertEquals(new ActivationIdentity(), l2.getActivationFn());
}
@Test
public void testEmbeddingWeightInit(){
// https://github.com/eclipse/deeplearning4j/issues/8663
//The embedding layer weight initialization should be independent of the vocabulary size (nIn setting)
for(WeightInit wi : new WeightInit[]{WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL, WeightInit.VAR_SCALING_NORMAL_FAN_OUT}) {
for (boolean seq : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(seq ?
new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() :
new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(seq ?
new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() :
new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build())
.build();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(seq ?
new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build() :
new EmbeddingLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build())
.build();
MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
net3.init();
INDArray p1 = net.params();
INDArray p2 = net2.params();
INDArray p3 = net3.params();
assertEquals(p1, p2);
double m1 = p1.meanNumber().doubleValue();
double s1 = p1.stdNumber().doubleValue();
double m3 = p3.meanNumber().doubleValue();
double s3 = p3.stdNumber().doubleValue();
String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
assertEquals(str, m1, m3, 0.1);
assertEquals(str, s1, s3, 0.1);
double re = relErr(s1, s3);
assertTrue(str + " - " + re, re < 0.05);
}
}
}
public static double relErr(double d1, double d2){
if(d1 == 0.0 && d2 == 0.0)
return 0.0;
return Math.abs(d1 - d2) / (Math.abs(d1) + Math.abs(d2));
}
}
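
Note: the relErr helper above computes a symmetric relative error, restated here for reference (this restatement is not part of the commit):

    relErr(d1, d2) = |d1 - d2| / (|d1| + |d2|),  with relErr(0, 0) = 0

so the final assertion re < 0.05 requires the initialization standard deviations for nIn = 100 and nIn = 100000 to agree to within roughly 10% of either value, regardless of their absolute scale.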


@@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
@@ -136,6 +137,11 @@ public class SameDiffCustomLayerTests extends BaseDL4JTest {
}
private class ValidatingSameDiffVertex extends SameDiffVertex {
@Override
public GraphVertex clone() {
return new ValidatingSameDiffVertex();
}
@Override
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
return vertexInputs[0];


@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.samediff.testlayers;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
@@ -74,4 +75,9 @@ public class SameDiffDenseVertex extends SameDiffVertex {
public char paramReshapeOrder(String paramName){
return 'f'; //To match DL4J DenseLayer - for easy comparison
}
@Override
public GraphVertex clone() {
return new SameDiffDenseVertex(nIn, nOut, activation, weightInit);
}
}


@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.samediff.testlayers;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;


@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint;
import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.graph.AttentionVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
@@ -35,6 +36,7 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
@@ -44,6 +46,9 @@ import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.*;
/**
@@ -565,4 +570,99 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
assertEquals("Incorrect number of inputs!", 5, newGraph.layerInputSize(afterPoolName));
newGraph.output(input);
}
@Test
public void testTransferLearningSameDiffLayersGraph(){
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.graphBuilder()
.addInputs("in")
.layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in")
.layer("l1", new RecurrentAttentionLayer.Builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0")
.layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
.setOutputs("out")
.build();
ComputationGraph cg = new ComputationGraph(conf);
cg.init();
INDArray arr = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray out = cg.output(arr)[0];
ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out")
.fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
.removeVertexAndConnections("out")
.addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
.setOutputs("newOut")
.build();
cg2.output(arr);
Map<String,INDArray> m = new HashMap<>(cg.paramTable());
m.put("newOut_W", m.remove("out_W"));
m.put("newOut_b", m.remove("out_b"));
cg2.setParamTable(m);
Map<String,INDArray> p1 = cg.paramTable();
Map<String,INDArray> p2 = cg2.paramTable();
for(String s : p1.keySet()){
INDArray i1 = p1.get(s);
INDArray i2 = p2.get(s.replaceAll("out", "newOut"));
assertEquals(s, i1, i2);
}
INDArray out2 = cg2.outputSingle(arr);
assertEquals(out, out2);
}
@Test
public void testTransferLearningSameDiffLayersGraphVertex(){
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.graphBuilder()
.addInputs("in")
.layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in")
.addVertex("l1", new AttentionVertex.Builder().nHeads(1).headSize(5).nInKeys(5).nInQueries(5).nInValues(5).nOut(5).build(), "l0", "l0", "l0")
.layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
.setOutputs("out")
.build();
ComputationGraph cg = new ComputationGraph(conf);
cg.init();
INDArray arr = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray out = cg.output(arr)[0];
ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out")
.fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
.removeVertexAndConnections("out")
.addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
.setOutputs("newOut")
.build();
cg2.output(arr);
Map<String,INDArray> m = new HashMap<>(cg.paramTable());
m.put("newOut_W", m.remove("out_W"));
m.put("newOut_b", m.remove("out_b"));
cg2.setParamTable(m);
Map<String,INDArray> p1 = cg.paramTable();
Map<String,INDArray> p2 = cg2.paramTable();
for(String s : p1.keySet()){
INDArray i1 = p1.get(s);
INDArray i2 = p2.get(s.replaceAll("out", "newOut"));
assertEquals(s, i1, i2);
}
INDArray out2 = cg2.outputSingle(arr);
assertEquals(out, out2);
}
}


@@ -41,6 +41,7 @@ import org.deeplearning4j.nn.weights.WeightInitRelu;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
@@ -48,6 +49,8 @@ import org.nd4j.linalg.learning.config.*;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import java.util.Map;
import static org.junit.Assert.*;
/**
@@ -689,4 +692,51 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
assertEquals("Incorrect number of inputs!", 5, newNet.layerInputSize(2));
newNet.output(input);
}
@Test
public void testTransferLearningSameDiffLayers(){
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.activation(Activation.TANH)
.updater(new Adam(0.01))
.weightInit(WeightInit.XAVIER)
.list()
.layer(new LSTM.Builder().nOut(8).build())
.layer( new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build())
.layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
.layer(new OutputLayer.Builder().nOut(2).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(4))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 4, 5);
INDArray out = net.output(in);
MultiLayerNetwork net2 = new TransferLearning.Builder(net)
.fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
.removeLayersFromOutput(1)
.addLayer(new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
net2.setParam("3_W", net.getParam("3_W"));
net2.setParam("3_b", net.getParam("3_b"));
Map<String,INDArray> p1 = net.paramTable();
Map<String,INDArray> p2 = net2.paramTable();
for(String s : p1.keySet()){
INDArray i1 = p1.get(s);
INDArray i2 = p2.get(s);
assertEquals(s, i1, i2);
}
INDArray out2 = net2.output(in);
assertEquals(out, out2);
}
}


@@ -427,7 +427,8 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
if(!disconnected.isEmpty() && !allowNoOutput){ //If allowing no output: by definition we have disconnected vertices
throw new IllegalStateException("Invalid configuration: disconnected vertices found - " + disconnected
+ ". Disconnected vertices are those that do not connect to either another vertex, and are also"
+ " not a network output. To disable this error (i.e., allow network configurations with" +
+ " not a network output. This vertex can be set as an output using setOutputs(String...). "
+ "To disable this error (i.e., allow network configurations with" +
" disconnected vertices) use GraphBuilder.allowDisconnected(true)");
}
}
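
Note: the reworded error message covers the usual cause of this exception, a vertex that feeds nothing and was simply not listed in setOutputs. A minimal sketch of the failure and its two fixes (layer names and sizes are illustrative, not taken from this commit):

    // "l1" feeds no other vertex and is not a declared output, so build()
    // (which runs validation) throws the IllegalStateException quoted above.
    ComputationGraphConfiguration bad = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .layer("l0", new DenseLayer.Builder().nIn(4).nOut(4).build(), "in")
            .layer("l1", new DenseLayer.Builder().nIn(4).nOut(4).build(), "l0")
            .layer("out", new OutputLayer.Builder().nIn(4).nOut(2).build(), "l0")
            .setOutputs("out")
            .build();

    // Fix 1: declare the dangling vertex as an additional output: .setOutputs("out", "l1")
    // Fix 2: explicitly allow disconnected vertices:              .allowDisconnected(true)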


@@ -72,6 +72,20 @@ public class AttentionVertex extends SameDiffVertex {
this.weightInit = builder.weightInit;
}
@Override
public AttentionVertex clone() {
AttentionVertex av = new AttentionVertex();
av.nInKeys = nInKeys;
av.nInValues = nInValues;
av.nInQueries = nInQueries;
av.nOut = nOut;
av.headSize = headSize;
av.nHeads = nHeads;
av.projectInput = projectInput;
av.weightInit = weightInit;
return av;
}
@Override
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
InputType.InputTypeRecurrent queries = (InputType.InputTypeRecurrent) vertexInputs[0];


@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
@@ -79,7 +80,7 @@ public class EmbeddingLayer extends FeedForwardLayer {
@Override
public ParamInitializer initializer() {
return DefaultParamInitializer.getInstance();
return EmbeddingLayerParamInitializer.getInstance();
}
@Override


@@ -24,7 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
@@ -92,7 +92,7 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
@Override
public ParamInitializer initializer() {
return DefaultParamInitializer.getInstance();
return EmbeddingLayerParamInitializer.getInstance();
}
@Override


@@ -16,11 +16,13 @@
package org.deeplearning4j.nn.conf.layers.samediff;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.lang.reflect.InvocationTargetException;
import java.util.*;
@@ -75,6 +77,15 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {
//No op, for lambda vertex
}
@Override
public GraphVertex clone() {
try {
return getClass().getConstructor().newInstance();
} catch (Exception e){
throw new RuntimeException("Unable to create new instance of class " + getClass().getName() + " from no-arg constructor");
}
}
protected VertexInputs getInputs(SameDiff sd) {
if (inputs == null) {
inputs = new VertexInputs(sd);


@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.samediff.SameDiffGraphVertex;
@@ -36,6 +37,7 @@ import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import java.lang.reflect.Field;
import java.util.List;
import java.util.Map;
@@ -99,11 +101,6 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf
return vertexParams;
}
@Override
public GraphVertex clone() {
throw new UnsupportedOperationException("Not yet implemented");
}
@Override
public long numParams(boolean backprop) {
SDLayerParams params = getVertexParams();


@@ -3394,7 +3394,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
@Override
public void setParamTable(@NonNull Map<String, INDArray> paramTable) {
Preconditions.checkArgument(paramTable.keySet().equals(paramTable().keySet()), "Cannot set param table: parameter set keys are not equal");
Map<String,INDArray> m = paramTable();
Preconditions.checkArgument(paramTable.keySet().equals(m.keySet()), "Cannot set param table: parameter set keys are not equal");
Map<String,INDArray> current = paramTable();
//Check shapes before doing partial assignment to avoid leaving net in incorrect state
for(String s : current.keySet()){


@@ -237,9 +237,16 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
@Override
public void setParams(INDArray params) {
if (params != null) {
throw new UnsupportedOperationException("Not supported");
}
if(this.params == null && params == null)
return;
if(this.params == null)
throw new IllegalStateException("Cannot set parameters of length " + params.length() + " to a layer with no parameters");
if(params == null)
throw new IllegalStateException("Cannot set null parameters");
Preconditions.checkState(this.params.length() == params.length(), "Cannot assign parameter vector of length %s to a layer with %s parameters",
params.length(), this.params.length());
this.params.assign(params);
}
protected void setParams(INDArray params, char order) {


@@ -0,0 +1,52 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.params;
import lombok.val;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* Parameter initializer for EmbeddingLayer and EmbeddingSequenceLayer
*
* @author Alex Black
*/
public class EmbeddingLayerParamInitializer extends DefaultParamInitializer {
private static final EmbeddingLayerParamInitializer INSTANCE = new EmbeddingLayerParamInitializer();
public static EmbeddingLayerParamInitializer getInstance() {
return INSTANCE;
}
protected INDArray createWeightMatrix(long nIn, long nOut, IWeightInit weightInit,
INDArray weightParamView, boolean initializeParameters) {
val shape = new long[] {nIn, nOut};
if (initializeParameters) {
INDArray ret = weightInit.init(1, //Fan in - note that fanIn=1 for embedding layer... if we used layer nIn (i.e., vocab size) the init would depend on vocab size (which doesn't make sense)
nOut, //Fan out
shape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, weightParamView);
return ret;
} else {
return WeightInitUtil.reshapeWeights(shape, weightParamView);
}
}
}
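
Note on the fanIn value above: an embedding lookup activates a single row of the weight matrix per example (effectively a one-hot input), so the effective fan-in is 1 rather than the vocabulary size. Assuming a Xavier-style scheme with variance on the order of 2 / (fanIn + fanOut), switching to this initializer changes

    var_old = 2 / (nIn + nOut)   -->   var_new = 2 / (1 + nOut)

so the initial weight scale no longer shrinks as the vocabulary size (nIn) grows, which is exactly what testEmbeddingWeightInit checks by comparing the mean and standard deviation for nIn = 100 versus nIn = 100000.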


@@ -79,7 +79,6 @@ public abstract class AbstractMultiDataSetNormalizer<S extends NormalizerStats>
}
protected List<S> getFeatureStats() {
assertIsFit();
return featureStats;
}
@@ -88,7 +87,6 @@ public abstract class AbstractMultiDataSetNormalizer<S extends NormalizerStats>
}
protected List<S> getLabelStats() {
assertIsFit();
return labelStats;
}


@@ -44,6 +44,7 @@ public class ImagePreProcessingScaler implements DataNormalization {
private double minRange, maxRange;
private double maxPixelVal;
private int maxBits;
private boolean fitLabels = false;
public ImagePreProcessingScaler() {
this(0, 1, 8);
@@ -94,7 +95,10 @@
@Override
public void preProcess(DataSet toPreProcess) {
INDArray features = toPreProcess.getFeatures();
this.preProcess(features);
preProcess(features);
if(fitLabels && toPreProcess.getLabels() != null){
preProcess(toPreProcess.getLabels());
}
}
public void preProcess(INDArray features) {
@@ -139,6 +143,7 @@
@Override
public void revert(DataSet toRevert) {
revertFeatures(toRevert.getFeatures());
revertLabels(toRevert.getLabels());
}
@Override
@@ -177,10 +182,11 @@
@Override
public void fitLabel(boolean fitLabels) {
//No-op
this.fitLabels = fitLabels;
}
@Override
public boolean isFitLabel() {
return false;
return fitLabels;
}
}
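
Note: with fitLabels wired up, the scaler can normalize pixel-valued segmentation masks alongside the features instead of silently ignoring fitLabel(true). A minimal usage sketch (shapes and values are illustrative; the test below exercises the same calls):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
    import org.nd4j.linalg.factory.Nd4j;

    public class SegmentationScalingSketch {
        public static void main(String[] args) {
            // Pixel-valued features (NCHW) and per-pixel label masks, both in [0, 255]
            INDArray features = Nd4j.rand(DataType.FLOAT, 2, 3, 16, 16).muli(255);
            INDArray labels = Nd4j.rand(DataType.FLOAT, 2, 10, 16, 16).muli(255);
            DataSet ds = new DataSet(features, labels);

            ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(); // scales to [0, 1]
            scaler.fitLabel(true); // previously a no-op; now labels are scaled too
            scaler.fit(ds);
            scaler.transform(ds); // both features and labels end up in [0, 1]
        }
    }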


@@ -22,6 +22,7 @@ import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
@@ -156,6 +157,42 @@ public class ImagePreProcessortTest extends BaseNd4jTest {
assertEquals(orig, before);
}
@Test
public void testSegmentation(){
INDArray f = Nd4j.math().floor(Nd4j.rand(DataType.FLOAT, 3, 3, 16, 16).muli(255));
INDArray l = Nd4j.math().floor(Nd4j.rand(DataType.FLOAT, 3, 10, 8, 8).muli(255));
ImagePreProcessingScaler s = new ImagePreProcessingScaler();
s.fitLabel(true);
s.fit(new DataSet(f,l));
INDArray expF = f.div(255);
INDArray expL = l.div(255);
DataSet d = new DataSet(f.dup(), l.dup());
s.transform(d);
assertEquals(expF, d.getFeatures());
assertEquals(expL, d.getLabels());
s.fit(new SingletonDataSetIterator(new DataSet(f, l)));
INDArray f2 = f.dup();
INDArray l2 = l.dup();
s.transform(f2);
s.transformLabel(l2);
assertEquals(expF, f2);
assertEquals(expL, l2);
s.revertFeatures(f2);
s.revertLabels(l2);
assertEquals(f, f2);
assertEquals(l, l2);
}
@Override
public char ordering() {
return 'c';


@@ -22,11 +22,10 @@ import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.TestDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.*;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
@@ -189,11 +188,11 @@ public class NormalizerTests extends BaseNd4jTest {
INDArray zeros = Nd4j.zeros(shouldBe0_1.shape());
for (int j = 0; j < 2; j++) {
System.out.println(ds.getFeatures().get(NDArrayIndex.point(j), NDArrayIndex.all(),
NDArrayIndex.all()));
System.out.println();
}
// for (int j = 0; j < 2; j++) {
// System.out.println(ds.getFeatures().get(NDArrayIndex.point(j), NDArrayIndex.all(),
// NDArrayIndex.all()));
// System.out.println();
// }
assertEquals(zeros, shouldBe0_1);
assertEquals(zeros, shouldBe0_2);
@@ -218,6 +217,69 @@
}
}
@Test
public void testNormalizerToStringHashCode(){
//https://github.com/eclipse/deeplearning4j/issues/8565
testNormalizer(new NormalizerMinMaxScaler());
NormalizerMinMaxScaler n1 = new NormalizerMinMaxScaler();
n1.fitLabel(true);
testNormalizer(n1);
testNormalizer(new NormalizerStandardize());
NormalizerStandardize n2 = new NormalizerStandardize();
n2.fitLabel(true);
testNormalizer(n2);
testNormalizer(new ImagePreProcessingScaler());
ImagePreProcessingScaler n3 = new ImagePreProcessingScaler();
n3.fitLabel(true);
testNormalizer(n3);
testNormalizer(new VGG16ImagePreProcessor());
VGG16ImagePreProcessor n4 = new VGG16ImagePreProcessor();
n4.fitLabel(true);
testNormalizer(n4);
}
private static void testNormalizer(DataNormalization n){
n.toString();
n.hashCode();
n.fit(new IrisDataSetIterator(30, 150));
n.toString();
n.hashCode();
}
@Test
public void testMultiNormalizerToStringHashCode(){
//https://github.com/eclipse/deeplearning4j/issues/8565
testMultiNormalizer(new MultiNormalizerMinMaxScaler());
MultiNormalizerMinMaxScaler n1 = new MultiNormalizerMinMaxScaler();
n1.fitLabel(true);
testMultiNormalizer(n1);
testMultiNormalizer(new MultiNormalizerStandardize());
MultiNormalizerStandardize n2 = new MultiNormalizerStandardize();
n2.fitLabel(true);
testMultiNormalizer(n2);
testMultiNormalizer(new ImageMultiPreProcessingScaler(0));
}
private static void testMultiNormalizer(MultiDataNormalization n){
n.toString();
n.hashCode();
n.fit(new MultiDataSetIteratorAdapter(new IrisDataSetIterator(30, 150)));
n.toString();
n.hashCode();
}
@Override
public char ordering() {
return 'c';