/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.jupiter.api.Assertions.assertEquals;
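/**
 * Verifies that CacheMode.DEVICE behaves identically to CacheMode.NONE: for otherwise identical
 * networks, outputs and parameters must match both before and after fitting. Covered are a small
 * CNN and LSTM/GravesLSTM models, on both MultiLayerNetwork and ComputationGraph.
 */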
public class CacheModeTest extends BaseDL4JTest {
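    // A small CNN configured with CacheMode.DEVICE should match its CacheMode.NONE counterpart exactly.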
@Test
public void testConvCacheModeSimple(){
MultiLayerConfiguration conf1 = getConf(CacheMode.NONE);
MultiLayerConfiguration conf2 = getConf(CacheMode.DEVICE);
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
INDArray in = Nd4j.rand(3, 28*28);
INDArray labels = TestUtils.randomOneHot(3, 10);
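        // Forward pass should be unaffected by the cache mode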
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
assertEquals(out1, out2);
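        // Parameters start out identical (same seed) and should stay identical after fitting the same data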
assertEquals(net1.params(), net2.params());
net1.fit(in, labels);
net2.fit(in, labels);
assertEquals(net1.params(), net2.params());
}
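    // Two convolution layers plus a softmax output on flattened 28x28 single-channel input, with the requested cache mode.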
private static MultiLayerConfiguration getConf(CacheMode cacheMode){
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.inferenceWorkspaceMode(WorkspaceMode.ENABLED)
.trainingWorkspaceMode(WorkspaceMode.ENABLED)
.seed(12345)
.cacheMode(cacheMode)
.list()
.layer(new ConvolutionLayer.Builder().nOut(3).build())
.layer(new ConvolutionLayer.Builder().nOut(3).build())
.layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build())
.setInputType(InputType.convolutionalFlat(28, 28, 1))
.build();
return conf;
}
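    // The same cache mode comparison for recurrent networks, covering both LSTM and GravesLSTM layers.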
@Test
public void testLSTMCacheModeSimple(){
for(boolean graves : new boolean[]{true, false}) {
MultiLayerConfiguration conf1 = getConfLSTM(CacheMode.NONE, graves);
MultiLayerConfiguration conf2 = getConfLSTM(CacheMode.DEVICE, graves);
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
INDArray in = Nd4j.rand(new int[]{3, 3, 10});
INDArray labels = TestUtils.randomOneHotTimeSeries(3, 10, 10);
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
assertEquals(out1, out2);
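            // Parameters should stay in sync through a fit step on the same time series data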
assertEquals(net1.params(), net2.params());
net1.fit(in, labels);
net2.fit(in, labels);
assertEquals(net1.params(), net2.params());
}
}
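    // Two (Graves)LSTM layers plus an RNN softmax output, with the requested cache mode.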
private static MultiLayerConfiguration getConfLSTM(CacheMode cacheMode, boolean graves){
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.inferenceWorkspaceMode(WorkspaceMode.ENABLED)
.trainingWorkspaceMode(WorkspaceMode.ENABLED)
.seed(12345)
.cacheMode(cacheMode)
.list()
.layer(graves ?
new GravesLSTM.Builder().nIn(3).nOut(3).build() :
new LSTM.Builder().nIn(3).nOut(3).build())
.layer(graves ?
new GravesLSTM.Builder().nIn(3).nOut(3).build() :
new LSTM.Builder().nIn(3).nOut(3).build())
                .layer(new RnnOutputLayer.Builder().nIn(3).nOut(10).activation(Activation.SOFTMAX).build())    // nIn set explicitly: no InputType is provided here, so it cannot be inferred
.build();
return conf;
}
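    // ComputationGraph variant of the CNN cache mode comparison.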
@Test
public void testConvCacheModeSimpleCG(){
ComputationGraphConfiguration conf1 = getConfCG(CacheMode.NONE);
ComputationGraphConfiguration conf2 = getConfCG(CacheMode.DEVICE);
ComputationGraph net1 = new ComputationGraph(conf1);
net1.init();
ComputationGraph net2 = new ComputationGraph(conf2);
net2.init();
INDArray in = Nd4j.rand(3, 28*28);
INDArray labels = TestUtils.randomOneHot(3, 10);
INDArray out1 = net1.outputSingle(in);
INDArray out2 = net2.outputSingle(in);
assertEquals(out1, out2);
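        // Parameters should stay identical through a fit step on the same DataSet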
assertEquals(net1.params(), net2.params());
net1.fit(new DataSet(in, labels));
net2.fit(new DataSet(in, labels));
assertEquals(net1.params(), net2.params());
}
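    // The same CNN architecture as getConf(), expressed as a ComputationGraph configuration.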
private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode){
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.inferenceWorkspaceMode(WorkspaceMode.ENABLED)
.trainingWorkspaceMode(WorkspaceMode.ENABLED)
.seed(12345)
.cacheMode(cacheMode)
.graphBuilder()
.addInputs("in")
.layer("0", new ConvolutionLayer.Builder().nOut(3).build(), "in")
.layer("1", new ConvolutionLayer.Builder().nOut(3).build(), "0")
.layer("2", new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build(), "1")
.setOutputs("2")
.setInputTypes(InputType.convolutionalFlat(28, 28, 1))
.build();
return conf;
}
}