DL4J: Fix 2 JSON issues [WIP] (#490)
* Fix MergeVertex serialization for NHWC case
  Signed-off-by: Alex Black <blacka101@gmail.com>
* #8999 Dropout JSON field ignore
  Signed-off-by: Alex Black <blacka101@gmail.com>
master
parent
f30acad57d
commit
fadc2d8622
|
@ -57,10 +57,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
import org.junit.AfterClass;
|
import org.junit.*;
|
||||||
import org.junit.Before;
|
import org.junit.rules.TemporaryFolder;
|
||||||
import org.junit.BeforeClass;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationIdentity;
|
import org.nd4j.linalg.activations.impl.ActivationIdentity;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -82,6 +80,7 @@ import org.nd4j.common.resources.Resources;
|
||||||
|
|
||||||
import java.io.ByteArrayInputStream;
|
import java.io.ByteArrayInputStream;
|
||||||
import java.io.ByteArrayOutputStream;
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
@ -91,6 +90,9 @@ import static org.junit.Assert.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TestComputationGraphNetwork extends BaseDL4JTest {
|
public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Rule
|
||||||
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
private static ComputationGraphConfiguration getIrisGraphConfiguration() {
|
private static ComputationGraphConfiguration getIrisGraphConfiguration() {
|
||||||
return new NeuralNetConfiguration.Builder().seed(12345)
|
return new NeuralNetConfiguration.Builder().seed(12345)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
|
||||||
|
@ -2177,4 +2179,40 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2);
|
INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2);
|
||||||
cg.fit(new DataSet(in, label));
|
cg.fit(new DataSet(in, label));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMergeNchw() throws Exception {
|
||||||
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
|
.graphBuilder()
|
||||||
|
.addInputs("in")
|
||||||
|
.layer("l0", new ConvolutionLayer.Builder()
|
||||||
|
.nOut(16)
|
||||||
|
.kernelSize(2,2).stride(1,1)
|
||||||
|
.build(), "in")
|
||||||
|
.layer("l1", new ConvolutionLayer.Builder()
|
||||||
|
.nOut(8)
|
||||||
|
.kernelSize(2,2).stride(1,1)
|
||||||
|
.build(), "in")
|
||||||
|
.addVertex("merge", new MergeVertex(), "l0", "l1")
|
||||||
|
.layer("out", new CnnLossLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge")
|
||||||
|
.setOutputs("out")
|
||||||
|
.setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ComputationGraph cg = new ComputationGraph(conf);
|
||||||
|
cg.init();
|
||||||
|
|
||||||
|
INDArray[] in = new INDArray[]{Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3)};
|
||||||
|
INDArray out = cg.outputSingle(in);
|
||||||
|
|
||||||
|
File dir = testDir.newFolder();
|
||||||
|
File f = new File(dir, "net.zip");
|
||||||
|
cg.save(f);
|
||||||
|
|
||||||
|
ComputationGraph c2 = ComputationGraph.load(f, true);
|
||||||
|
INDArray out2 = c2.outputSingle(in);
|
||||||
|
|
||||||
|
assertEquals(out, out2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,8 +66,8 @@ import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||||
* @author Alex Black
|
* @author Alex Black
|
||||||
*/
|
*/
|
||||||
@Data
|
@Data
|
||||||
@JsonIgnoreProperties({"mask", "helper", "helperCountFail"})
|
@JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"})
|
||||||
@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail"})
|
@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"})
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class Dropout implements IDropout {
|
public class Dropout implements IDropout {
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.nn.conf.graph;
|
package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
|
@ -38,6 +39,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
* -> [numExamples,depth1 + depth2,width,height]}<br>
|
* -> [numExamples,depth1 + depth2,width,height]}<br>
|
||||||
* @author Alex Black
|
* @author Alex Black
|
||||||
*/
|
*/
|
||||||
|
@Data
|
||||||
public class MergeVertex extends GraphVertex {
|
public class MergeVertex extends GraphVertex {
|
||||||
|
|
||||||
protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format
|
protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format
|
||||||
|
|
Loading…
Reference in New Issue