diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java index 743e16710..b0cc17376 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java @@ -57,10 +57,8 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.util.ModelSerializer; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; +import org.junit.*; +import org.junit.rules.TemporaryFolder; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.buffer.DataType; @@ -82,6 +80,7 @@ import org.nd4j.common.resources.Resources; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.File; import java.io.IOException; import java.util.*; @@ -91,6 +90,9 @@ import static org.junit.Assert.*; @Slf4j public class TestComputationGraphNetwork extends BaseDL4JTest { + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + private static ComputationGraphConfiguration getIrisGraphConfiguration() { return new NeuralNetConfiguration.Builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() @@ -2177,4 +2179,40 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2); cg.fit(new DataSet(in, label)); } + + @Test + public void testMergeNchw() throws Exception { + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + .convolutionMode(ConvolutionMode.Same) + .graphBuilder() + .addInputs("in") + .layer("l0", new ConvolutionLayer.Builder() + .nOut(16) + .kernelSize(2,2).stride(1,1) + .build(), "in") + .layer("l1", new ConvolutionLayer.Builder() + .nOut(8) + .kernelSize(2,2).stride(1,1) + .build(), "in") + .addVertex("merge", new MergeVertex(), "l0", "l1") + .layer("out", new CnnLossLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge") + .setOutputs("out") + .setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC)) + .build(); + + ComputationGraph cg = new ComputationGraph(conf); + cg.init(); + + INDArray[] in = new INDArray[]{Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3)}; + INDArray out = cg.outputSingle(in); + + File dir = testDir.newFolder(); + File f = new File(dir, "net.zip"); + cg.save(f); + + ComputationGraph c2 = ComputationGraph.load(f, true); + INDArray out2 = c2.outputSingle(in); + + assertEquals(out, out2); + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java index 46a872fd8..acb6afa2c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java @@ -66,8 +66,8 @@ import org.nd4j.shade.jackson.annotation.JsonProperty; * @author Alex Black */ @Data -@JsonIgnoreProperties({"mask", "helper", "helperCountFail"}) -@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail"}) +@JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"}) +@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"}) @Slf4j public class Dropout implements IDropout { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java index 726a68403..c7a4fec63 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.conf.graph; +import lombok.Data; import lombok.val; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.RNNFormat; @@ -38,6 +39,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * -> [numExamples,depth1 + depth2,width,height]}
* @author Alex Black */ +@Data public class MergeVertex extends GraphVertex { protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format