DL4J: Fix 2 JSON issues [WIP] (#490)

* Fix MergeVertex serialization for NHWC case

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8999 Dropout JSON field ignore

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-06-11 12:37:38 +10:00 committed by GitHub
parent f30acad57d
commit fadc2d8622
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 6 deletions

View File

@ -57,10 +57,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.junit.AfterClass; import org.junit.*;
import org.junit.Before; import org.junit.rules.TemporaryFolder;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -82,6 +80,7 @@ import org.nd4j.common.resources.Resources;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.*; import java.util.*;
@ -91,6 +90,9 @@ import static org.junit.Assert.*;
@Slf4j @Slf4j
public class TestComputationGraphNetwork extends BaseDL4JTest { public class TestComputationGraphNetwork extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
private static ComputationGraphConfiguration getIrisGraphConfiguration() { private static ComputationGraphConfiguration getIrisGraphConfiguration() {
return new NeuralNetConfiguration.Builder().seed(12345) return new NeuralNetConfiguration.Builder().seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
@ -2177,4 +2179,40 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2); INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2);
cg.fit(new DataSet(in, label)); cg.fit(new DataSet(in, label));
} }
@Test
public void testMergeNchw() throws Exception {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.convolutionMode(ConvolutionMode.Same)
.graphBuilder()
.addInputs("in")
.layer("l0", new ConvolutionLayer.Builder()
.nOut(16)
.kernelSize(2,2).stride(1,1)
.build(), "in")
.layer("l1", new ConvolutionLayer.Builder()
.nOut(8)
.kernelSize(2,2).stride(1,1)
.build(), "in")
.addVertex("merge", new MergeVertex(), "l0", "l1")
.layer("out", new CnnLossLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge")
.setOutputs("out")
.setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC))
.build();
ComputationGraph cg = new ComputationGraph(conf);
cg.init();
INDArray[] in = new INDArray[]{Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3)};
INDArray out = cg.outputSingle(in);
File dir = testDir.newFolder();
File f = new File(dir, "net.zip");
cg.save(f);
ComputationGraph c2 = ComputationGraph.load(f, true);
INDArray out2 = c2.outputSingle(in);
assertEquals(out, out2);
}
} }

View File

@ -66,8 +66,8 @@ import org.nd4j.shade.jackson.annotation.JsonProperty;
* @author Alex Black * @author Alex Black
*/ */
@Data @Data
@JsonIgnoreProperties({"mask", "helper", "helperCountFail"}) @JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"})
@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail"}) @EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"})
@Slf4j @Slf4j
public class Dropout implements IDropout { public class Dropout implements IDropout {

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.conf.graph; package org.deeplearning4j.nn.conf.graph;
import lombok.Data;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.RNNFormat;
@ -38,6 +39,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
* -> [numExamples,depth1 + depth2,width,height]}<br> * -> [numExamples,depth1 + depth2,width,height]}<br>
* @author Alex Black * @author Alex Black
*/ */
@Data
public class MergeVertex extends GraphVertex { public class MergeVertex extends GraphVertex {
protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format