182 lines
7.9 KiB
Java
182 lines
7.9 KiB
Java
/*******************************************************************************
|
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
*
|
|
* This program and the accompanying materials are made available under the
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* License for the specific language governing permissions and limitations
|
|
* under the License.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
******************************************************************************/
|
|
|
|
package org.deeplearning4j.nn.graph;
|
|
|
|
import org.deeplearning4j.BaseDL4JTest;
|
|
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
|
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
|
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
|
|
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
|
|
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
import org.deeplearning4j.nn.weights.WeightInit;
|
|
import org.junit.Test;
|
|
import org.nd4j.linalg.activations.Activation;
|
|
import org.nd4j.linalg.api.buffer.DataType;
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
import org.nd4j.linalg.factory.Nd4j;
|
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
|
import org.nd4j.linalg.learning.config.Adam;
|
|
|
|
import java.util.Arrays;
|
|
import java.util.HashMap;
|
|
import java.util.Map;
|
|
|
|
import static org.junit.Assert.assertEquals;
|
|
import static org.junit.Assert.assertNotEquals;
|
|
|
|
public class TestCompGraphUnsupervised extends BaseDL4JTest {
|
|
|
|
@Override
|
|
public DataType getDataType() {
|
|
return DataType.FLOAT;
|
|
}
|
|
|
|
@Test
|
|
public void testVAE() throws Exception {
|
|
|
|
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
|
|
|
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
|
.seed(12345)
|
|
.updater(new Adam(1e-3))
|
|
.weightInit(WeightInit.XAVIER)
|
|
.inferenceWorkspaceMode(wsm)
|
|
.trainingWorkspaceMode(wsm)
|
|
.graphBuilder()
|
|
.addInputs("in")
|
|
.addLayer("vae1", new VariationalAutoencoder.Builder()
|
|
.nIn(784)
|
|
.nOut(32)
|
|
.encoderLayerSizes(16)
|
|
.decoderLayerSizes(16)
|
|
.activation(Activation.TANH)
|
|
.pzxActivationFunction(Activation.SIGMOID)
|
|
.reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID))
|
|
.build(), "in")
|
|
.addLayer("vae2", new VariationalAutoencoder.Builder()
|
|
.nIn(32)
|
|
.nOut(8)
|
|
.encoderLayerSizes(16)
|
|
.decoderLayerSizes(16)
|
|
.activation(Activation.TANH)
|
|
.pzxActivationFunction(Activation.SIGMOID)
|
|
.reconstructionDistribution(new GaussianReconstructionDistribution(Activation.TANH))
|
|
.build(), "vae1")
|
|
.setOutputs("vae2")
|
|
.build();
|
|
|
|
ComputationGraph cg = new ComputationGraph(conf);
|
|
cg.init();
|
|
|
|
DataSetIterator ds = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(8, true, 12345), 3);
|
|
|
|
Map<String,INDArray> paramsBefore = new HashMap<>();
|
|
|
|
//Pretrain first layer
|
|
for(Map.Entry<String,INDArray> e : cg.paramTable().entrySet()){
|
|
paramsBefore.put(e.getKey(), e.getValue().dup());
|
|
}
|
|
cg.pretrainLayer("vae1", ds);
|
|
for(Map.Entry<String,INDArray> e : cg.paramTable().entrySet()){
|
|
if(e.getKey().startsWith("vae1")){
|
|
assertNotEquals(paramsBefore.get(e.getKey()), e.getValue());
|
|
} else {
|
|
assertEquals(paramsBefore.get(e.getKey()), e.getValue());
|
|
}
|
|
}
|
|
|
|
int count = Nd4j.getExecutioner().exec(new MatchCondition(cg.params(), Conditions.isNan())).getInt(0);
|
|
assertEquals(0, count);
|
|
|
|
|
|
//Pretrain second layer
|
|
for(Map.Entry<String,INDArray> e : cg.paramTable().entrySet()){
|
|
paramsBefore.put(e.getKey(), e.getValue().dup());
|
|
}
|
|
cg.pretrainLayer("vae2", ds);
|
|
for(Map.Entry<String,INDArray> e : cg.paramTable().entrySet()){
|
|
if(e.getKey().startsWith("vae2")){
|
|
assertNotEquals(paramsBefore.get(e.getKey()), e.getValue());
|
|
} else {
|
|
assertEquals(paramsBefore.get(e.getKey()), e.getValue());
|
|
}
|
|
}
|
|
|
|
count = Nd4j.getExecutioner().exec(new MatchCondition(cg.params(), Conditions.isNan())).getInt(0);
|
|
assertEquals(0, count);
|
|
}
|
|
}
|
|
|
|
@Test
|
|
public void compareImplementations() throws Exception {
|
|
|
|
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
|
|
|
|
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
|
|
.seed(12345)
|
|
.updater(new Adam(1e-3))
|
|
.weightInit(WeightInit.XAVIER)
|
|
.inferenceWorkspaceMode(wsm)
|
|
.trainingWorkspaceMode(wsm)
|
|
.list()
|
|
.layer(new VariationalAutoencoder.Builder()
|
|
.nIn(784)
|
|
.nOut(32)
|
|
.encoderLayerSizes(16)
|
|
.decoderLayerSizes(16)
|
|
.activation(Activation.TANH)
|
|
.pzxActivationFunction(Activation.SIGMOID)
|
|
.reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID))
|
|
.build())
|
|
.layer(new VariationalAutoencoder.Builder()
|
|
.nIn(32)
|
|
.nOut(8)
|
|
.encoderLayerSizes(16)
|
|
.decoderLayerSizes(16)
|
|
.activation(Activation.TANH)
|
|
.pzxActivationFunction(Activation.SIGMOID)
|
|
.reconstructionDistribution(new GaussianReconstructionDistribution(Activation.TANH))
|
|
.build())
|
|
.build();
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf2);
|
|
net.init();
|
|
|
|
ComputationGraph cg = net.toComputationGraph();
|
|
cg.getConfiguration().setInferenceWorkspaceMode(wsm);
|
|
cg.getConfiguration().setTrainingWorkspaceMode(wsm);
|
|
DataSetIterator ds = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(1, true, 12345), 1);
|
|
Nd4j.getRandom().setSeed(12345);
|
|
net.pretrainLayer(0, ds);
|
|
|
|
ds = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(1, true, 12345), 1);
|
|
Nd4j.getRandom().setSeed(12345);
|
|
cg.pretrainLayer("0", ds);
|
|
|
|
assertEquals(net.params(), cg.params());
|
|
}
|
|
}
|
|
|
|
}
|