/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.deeplearning4j.nn.layers;

import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import static org.junit.jupiter.api.Assertions.assertEquals;

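/**
 * Verifies that CacheMode.NONE and CacheMode.DEVICE produce identical outputs and identical
 * parameters, both before and after fitting, for small convolutional and LSTM networks,
 * using both MultiLayerNetwork and ComputationGraph.
 */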
public class CacheModeTest extends BaseDL4JTest {
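    // CNN via MultiLayerNetwork: outputs and parameters must match for CacheMode.NONE vs CacheMode.DEVICE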
    @Test
    public void testConvCacheModeSimple(){

        MultiLayerConfiguration conf1 = getConf(CacheMode.NONE);
        MultiLayerConfiguration conf2 = getConf(CacheMode.DEVICE);

        MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
        net1.init();
        MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
        net2.init();

        INDArray in = Nd4j.rand(3, 28*28);
        INDArray labels = TestUtils.randomOneHot(3, 10);

        INDArray out1 = net1.output(in);
        INDArray out2 = net2.output(in);
        assertEquals(out1, out2);

        assertEquals(net1.params(), net2.params());
        net1.fit(in, labels);
        net2.fit(in, labels);
        assertEquals(net1.params(), net2.params());
    }

    private static MultiLayerConfiguration getConf(CacheMode cacheMode){
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .seed(12345)
                .cacheMode(cacheMode)
                .list()
                .layer(new ConvolutionLayer.Builder().nOut(3).build())
                .layer(new ConvolutionLayer.Builder().nOut(3).build())
                .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build())
                .setInputType(InputType.convolutionalFlat(28, 28, 1))
                .build();

        return conf;
    }
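    // LSTM (both GravesLSTM and LSTM) via MultiLayerNetwork: cache mode must not change outputs or parameters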
    @Test
    public void testLSTMCacheModeSimple(){

        for(boolean graves : new boolean[]{true, false}) {

            MultiLayerConfiguration conf1 = getConfLSTM(CacheMode.NONE, graves);
            MultiLayerConfiguration conf2 = getConfLSTM(CacheMode.DEVICE, graves);

            MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
            net1.init();
            MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
            net2.init();

            INDArray in = Nd4j.rand(new int[]{3, 3, 10});
            INDArray labels = TestUtils.randomOneHotTimeSeries(3, 10, 10);

            INDArray out1 = net1.output(in);
            INDArray out2 = net2.output(in);
            assertEquals(out1, out2);

            assertEquals(net1.params(), net2.params());
            net1.fit(in, labels);
            net2.fit(in, labels);
            assertEquals(net1.params(), net2.params());
        }
    }

    private static MultiLayerConfiguration getConfLSTM(CacheMode cacheMode, boolean graves){
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .seed(12345)
                .cacheMode(cacheMode)
                .list()
                .layer(graves ?
                        new GravesLSTM.Builder().nIn(3).nOut(3).build() :
                        new LSTM.Builder().nIn(3).nOut(3).build())
                .layer(graves ?
                        new GravesLSTM.Builder().nIn(3).nOut(3).build() :
                        new LSTM.Builder().nIn(3).nOut(3).build())
                .layer(new RnnOutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build())
                .build();

        return conf;
    }
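    // CNN via ComputationGraph: cache mode must not change outputs or parameters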
    @Test
    public void testConvCacheModeSimpleCG(){

        ComputationGraphConfiguration conf1 = getConfCG(CacheMode.NONE);
        ComputationGraphConfiguration conf2 = getConfCG(CacheMode.DEVICE);

        ComputationGraph net1 = new ComputationGraph(conf1);
        net1.init();
        ComputationGraph net2 = new ComputationGraph(conf2);
        net2.init();

        INDArray in = Nd4j.rand(3, 28*28);
        INDArray labels = TestUtils.randomOneHot(3, 10);

        INDArray out1 = net1.outputSingle(in);
        INDArray out2 = net2.outputSingle(in);
        assertEquals(out1, out2);

        assertEquals(net1.params(), net2.params());
        net1.fit(new DataSet(in, labels));
        net2.fit(new DataSet(in, labels));
        assertEquals(net1.params(), net2.params());
    }

    private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode){
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .seed(12345)
                .cacheMode(cacheMode)
                .graphBuilder()
                .addInputs("in")
                .layer("0", new ConvolutionLayer.Builder().nOut(3).build(), "in")
                .layer("1", new ConvolutionLayer.Builder().nOut(3).build(), "0")
                .layer("2", new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build(), "1")
                .setOutputs("2")
                .setInputTypes(InputType.convolutionalFlat(28, 28, 1))
                .build();

        return conf;
    }
}