DL4J: Fix 2 JSON issues [WIP] (#490)
* Fix MergeVertex serialization for NHWC case Signed-off-by: Alex Black <blacka101@gmail.com> * #8999 Dropout JSON field ignore Signed-off-by: Alex Black <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									f30acad57d
								
							
						
					
					
						commit
						fadc2d8622
					
				| @ -57,10 +57,8 @@ import org.deeplearning4j.nn.weights.WeightInit; | ||||
| import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; | ||||
| import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | ||||
| import org.deeplearning4j.util.ModelSerializer; | ||||
| import org.junit.AfterClass; | ||||
| import org.junit.Before; | ||||
| import org.junit.BeforeClass; | ||||
| import org.junit.Test; | ||||
| import org.junit.*; | ||||
| import org.junit.rules.TemporaryFolder; | ||||
| import org.nd4j.linalg.activations.Activation; | ||||
| import org.nd4j.linalg.activations.impl.ActivationIdentity; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| @ -82,6 +80,7 @@ import org.nd4j.common.resources.Resources; | ||||
| 
 | ||||
| import java.io.ByteArrayInputStream; | ||||
| import java.io.ByteArrayOutputStream; | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.util.*; | ||||
| 
 | ||||
| @ -91,6 +90,9 @@ import static org.junit.Assert.*; | ||||
| @Slf4j | ||||
| public class TestComputationGraphNetwork extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Rule | ||||
|     public TemporaryFolder testDir = new TemporaryFolder(); | ||||
| 
 | ||||
|     private static ComputationGraphConfiguration getIrisGraphConfiguration() { | ||||
|         return new NeuralNetConfiguration.Builder().seed(12345) | ||||
|                 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() | ||||
| @ -2177,4 +2179,40 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { | ||||
|         INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2); | ||||
|         cg.fit(new DataSet(in, label)); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testMergeNchw() throws Exception { | ||||
|         ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() | ||||
|                 .convolutionMode(ConvolutionMode.Same) | ||||
|                 .graphBuilder() | ||||
|                 .addInputs("in") | ||||
|                 .layer("l0", new ConvolutionLayer.Builder() | ||||
|                         .nOut(16) | ||||
|                         .kernelSize(2,2).stride(1,1) | ||||
|                         .build(), "in") | ||||
|                 .layer("l1", new ConvolutionLayer.Builder() | ||||
|                         .nOut(8) | ||||
|                         .kernelSize(2,2).stride(1,1) | ||||
|                         .build(), "in") | ||||
|                 .addVertex("merge", new MergeVertex(), "l0", "l1") | ||||
|                 .layer("out", new CnnLossLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge") | ||||
|                 .setOutputs("out") | ||||
|                 .setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC)) | ||||
|                 .build(); | ||||
| 
 | ||||
|         ComputationGraph cg = new ComputationGraph(conf); | ||||
|         cg.init(); | ||||
| 
 | ||||
|         INDArray[] in = new INDArray[]{Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3)}; | ||||
|         INDArray out = cg.outputSingle(in); | ||||
| 
 | ||||
|         File dir = testDir.newFolder(); | ||||
|         File f = new File(dir, "net.zip"); | ||||
|         cg.save(f); | ||||
| 
 | ||||
|         ComputationGraph c2 = ComputationGraph.load(f, true); | ||||
|         INDArray out2 = c2.outputSingle(in); | ||||
| 
 | ||||
|         assertEquals(out, out2); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -66,8 +66,8 @@ import org.nd4j.shade.jackson.annotation.JsonProperty; | ||||
|  * @author Alex Black | ||||
|  */ | ||||
| @Data | ||||
| @JsonIgnoreProperties({"mask", "helper", "helperCountFail"}) | ||||
| @EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail"}) | ||||
| @JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"}) | ||||
| @EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"}) | ||||
| @Slf4j | ||||
| public class Dropout implements IDropout { | ||||
| 
 | ||||
|  | ||||
| @ -17,6 +17,7 @@ | ||||
| package org.deeplearning4j.nn.conf.graph; | ||||
| 
 | ||||
| 
 | ||||
| import lombok.Data; | ||||
| import lombok.val; | ||||
| import org.deeplearning4j.nn.conf.CNN2DFormat; | ||||
| import org.deeplearning4j.nn.conf.RNNFormat; | ||||
| @ -38,6 +39,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; | ||||
|  *      -> [numExamples,depth1 + depth2,width,height]}<br> | ||||
|  * @author Alex Black | ||||
|  */ | ||||
| @Data | ||||
| public class MergeVertex extends GraphVertex { | ||||
| 
 | ||||
|     protected int mergeAxis = 1;       //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user