/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nn.layers;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;

@Slf4j
public class FrozenLayerWithBackpropTest extends BaseDL4JTest {

    @Test
    public void testFrozenWithBackpropLayerInstantiation() {
        // We need to be able to instantiate frozen layers from JSON etc., and have them be
        // the same as if they were initialized via the builder
        MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list()
                .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
                        .weightInit(WeightInit.XAVIER).build())
                .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
                        .weightInit(WeightInit.XAVIER).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(10).nOut(10).build())
                .build();

        MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list()
                .layer(0, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
                                .weightInit(WeightInit.XAVIER).build()))
                .layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
                                .weightInit(WeightInit.XAVIER).build()))
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(10).nOut(10).build())
                .build();

        MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
        net1.init();
        MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
        net2.init();
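
        // Wrapping layers in FrozenLayerWithBackprop must not change how their parameters are initialized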
        assertEquals(net1.params(), net2.params());
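
        // The frozen configuration must survive a JSON round trip unchanged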
        String json = conf2.toJson();
        MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json);

        assertEquals(conf2, fromJson);

        MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson);
        net3.init();

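        // A network built from the deserialized configuration should produce identical output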
        INDArray input = Nd4j.rand(10, 10);

        INDArray out2 = net2.output(input);
        INDArray out3 = net3.output(input);

        assertEquals(out2, out3);
    }

    @Test
    public void testFrozenLayerInstantiationCompGraph() {
        // We need to be able to instantiate frozen layers from JSON etc., and have them be
        // the same as if they were initialized via the builder
        ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
                .addInputs("in")
                .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
                        .weightInit(WeightInit.XAVIER).build(), "in")
                .addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
                        .weightInit(WeightInit.XAVIER).build(), "0")
                .addLayer("2", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1")
                .setOutputs("2").build();

        ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
                .addInputs("in")
                .addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
                                .weightInit(WeightInit.XAVIER).build()), "in")
                .addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
                                .weightInit(WeightInit.XAVIER).build()), "0")
                .addLayer("2", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1")
                .setOutputs("2").build();

        ComputationGraph net1 = new ComputationGraph(conf1);
        net1.init();
        ComputationGraph net2 = new ComputationGraph(conf2);
        net2.init();
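
        // As in the MultiLayerNetwork case: initialization must be unaffected by the frozen
        // wrapper, and the configuration must survive a JSON round trip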
        assertEquals(net1.params(), net2.params());

        String json = conf2.toJson();
        ComputationGraphConfiguration fromJson = ComputationGraphConfiguration.fromJson(json);

        assertEquals(conf2, fromJson);

        ComputationGraph net3 = new ComputationGraph(fromJson);
        net3.init();

        INDArray input = Nd4j.rand(10, 10);

        INDArray out2 = net2.outputSingle(input);
        INDArray out3 = net3.outputSingle(input);

        assertEquals(out2, out3);
    }

    @Test
    public void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() {
        Nd4j.getRandom().setSeed(12345);
        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));

        MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(2))
                .list()
                .layer(new DenseLayer.Builder().nIn(4).nOut(3).build())
                .layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(3).nOut(4).build()))
                .layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(4).nOut(2).build()))
                .layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                                .activation(Activation.TANH).nIn(2).nOut(1).build()))
                .build();

        MultiLayerNetwork network = new MultiLayerNetwork(conf1);
        network.init();
        INDArray unfrozenLayerParams = network.getLayer(0).params().dup();
        INDArray frozenLayerParams1 = network.getLayer(1).params().dup();
        INDArray frozenLayerParams2 = network.getLayer(2).params().dup();
        INDArray frozenOutputLayerParams = network.getLayer(3).params().dup();
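
        // After training, only the unfrozen first layer's parameters should have changed;
        // every FrozenLayerWithBackprop-wrapped layer keeps its initial parameters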
        for (int i = 0; i < 100; i++) {
            network.fit(randomData);
        }

        assertNotEquals(unfrozenLayerParams, network.getLayer(0).params());
        assertEquals(frozenLayerParams1, network.getLayer(1).params());
        assertEquals(frozenLayerParams2, network.getLayer(2).params());
        assertEquals(frozenOutputLayerParams, network.getLayer(3).params());
    }

    @Test
    public void testComputationGraphFrozenLayerParamsAfterBackprop() {
        Nd4j.getRandom().setSeed(12345);

        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
        String frozenBranchName = "B1-";
        String unfrozenBranchName = "B2-";

        String initialLayer = "initial";

        String frozenBranchUnfrozenLayer0 = frozenBranchName + "0";
        String frozenBranchFrozenLayer1 = frozenBranchName + "1";
        String frozenBranchFrozenLayer2 = frozenBranchName + "2";
        String frozenBranchOutput = frozenBranchName + "Output";

        String unfrozenLayer0 = unfrozenBranchName + "0";
        String unfrozenLayer1 = unfrozenBranchName + "1";
        String unfrozenBranch2 = unfrozenBranchName + "Output";

        ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder()
                .updater(new Sgd(2.0))
                .seed(12345)
                .graphBuilder()
                .addInputs("input")
                .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input")
                .addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer)
                .addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(3).nOut(4).build()), frozenBranchUnfrozenLayer0)
                .addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(4).nOut(2).build()), frozenBranchFrozenLayer1)
                .addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer)
                .addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0)
                .addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1)
                .addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
                .addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                                .activation(Activation.TANH).nIn(3).nOut(1).build()), "merge")
                .setOutputs(frozenBranchOutput)
                .build();

        ComputationGraph computationGraph = new ComputationGraph(computationGraphConf);
        computationGraph.init();
        INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
        INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
        INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
        INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).params().dup();
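
        // The frozen layers sit between the unfrozen layer "B1-0" and the output, so gradients
        // must still flow backward through them for "B1-0" to learn: its parameters should
        // change while the frozen layers' parameters stay fixed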
        for (int i = 0; i < 100; i++) {
            computationGraph.fit(randomData);
        }

        assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).params());
        assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).params());
        assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).params());
        assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).params());
    }

    /**
     * A frozen layer should produce the same results as the same layer trained with an
     * Sgd updater whose learning rate is set to 0.0
     */
    @Test
    public void testFrozenLayerVsSgd() {
        Nd4j.getRandom().setSeed(12345);
        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));

        MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(2))
                .list()
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
                .layer(1, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0))
                        .nIn(3).nOut(4).build())
                .layer(2, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0))
                        .nIn(4).nOut(2).build())
                .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0))
                        .activation(Activation.TANH).nIn(2).nOut(1).build())
                .build();

        MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(2))
                .list()
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
                .layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(3).nOut(4).build()))
                .layer(2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(4).nOut(2).build()))
                .layer(3, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                                .activation(Activation.TANH).nIn(2).nOut(1).build()))
                .build();

        MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen);
        frozenNetwork.init();
        INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).params().dup();
        INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).params().dup();
        INDArray frozenLayerParams2 = frozenNetwork.getLayer(2).params().dup();
        INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).params().dup();

        MultiLayerNetwork sgdNetwork = new MultiLayerNetwork(confSgd);
        sgdNetwork.init();
        INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).params().dup();
        INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).params().dup();
        INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).params().dup();
        INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).params().dup();
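
        // Train both networks on the same data; frozen layers and zero-learning-rate layers
        // should end up with exactly the same parameters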
        for (int i = 0; i < 100; i++) {
            frozenNetwork.fit(randomData);
        }
        for (int i = 0; i < 100; i++) {
            sgdNetwork.fit(randomData);
        }

        assertEquals(frozenNetwork.getLayer(0).params(), sgdNetwork.getLayer(0).params());
        assertEquals(frozenNetwork.getLayer(1).params(), sgdNetwork.getLayer(1).params());
        assertEquals(frozenNetwork.getLayer(2).params(), sgdNetwork.getLayer(2).params());
        assertEquals(frozenNetwork.getLayer(3).params(), sgdNetwork.getLayer(3).params());
    }
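
    /**
     * Same check as {@link #testFrozenLayerVsSgd()}, but for a ComputationGraph: frozen layers
     * should match layers trained with an Sgd updater whose learning rate is set to 0.0
     */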
    @Test
    public void testComputationGraphVsSgd() {
        Nd4j.getRandom().setSeed(12345);
        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
        String frozenBranchName = "B1-";
        String unfrozenBranchName = "B2-";

        String initialLayer = "initial";

        String frozenBranchUnfrozenLayer0 = frozenBranchName + "0";
        String frozenBranchFrozenLayer1 = frozenBranchName + "1";
        String frozenBranchFrozenLayer2 = frozenBranchName + "2";
        String frozenBranchOutput = frozenBranchName + "Output";

        String unfrozenLayer0 = unfrozenBranchName + "0";
        String unfrozenLayer1 = unfrozenBranchName + "1";
        String unfrozenBranch2 = unfrozenBranchName + "Output";

        ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder()
                .updater(new Sgd(2.0))
                .seed(12345)
                .graphBuilder()
                .addInputs("input")
                .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input")
                .addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer)
                .addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(3).nOut(4).build()), frozenBranchUnfrozenLayer0)
                .addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(4).nOut(2).build()), frozenBranchFrozenLayer1)
                .addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer)
                .addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0)
                .addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1)
                .addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
                .addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
                        new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                                .activation(Activation.TANH).nIn(3).nOut(1).build()), "merge")
                .setOutputs(frozenBranchOutput)
                .build();

        ComputationGraphConfiguration computationSgdGraphConf = new NeuralNetConfiguration.Builder()
                .updater(new Sgd(2.0))
                .seed(12345)
                .graphBuilder()
                .addInputs("input")
                .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input")
                .addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer)
                .addLayer(frozenBranchFrozenLayer1, new DenseLayer.Builder().updater(new Sgd(0.0))
                        .biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(), frozenBranchUnfrozenLayer0)
                .addLayer(frozenBranchFrozenLayer2, new DenseLayer.Builder().updater(new Sgd(0.0))
                        .biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(), frozenBranchFrozenLayer1)
                .addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer)
                .addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0)
                .addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1)
                .addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
                .addLayer(frozenBranchOutput, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0))
                        .activation(Activation.TANH).nIn(3).nOut(1).build(), "merge")
                .setOutputs(frozenBranchOutput)
                .build();

        ComputationGraph frozenComputationGraph = new ComputationGraph(computationGraphConf);
        frozenComputationGraph.init();
        INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
        INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
        INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
        INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).params().dup();

        ComputationGraph sgdComputationGraph = new ComputationGraph(computationSgdGraphConf);
        sgdComputationGraph.init();
        INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
        INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
        INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
        INDArray frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).params().dup();
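
        // Train both graphs identically; corresponding layers should end up with exactly
        // the same parameters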
        for (int i = 0; i < 100; i++) {
            frozenComputationGraph.fit(randomData);
        }
        for (int i = 0; i < 100; i++) {
            sgdComputationGraph.fit(randomData);
        }

        assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params());
        assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params());
        assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params());
        assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).params(), sgdComputationGraph.getLayer(frozenBranchOutput).params());
    }
}