146 lines
6.0 KiB
Java
146 lines
6.0 KiB
Java
/*
 *
 *    ******************************************************************************
 *    *
 *    * This program and the accompanying materials are made available under the
 *    * terms of the Apache License, Version 2.0 which is available at
 *    * https://www.apache.org/licenses/LICENSE-2.0.
 *    *
 *    * See the NOTICE file distributed with this work for additional
 *    * information regarding copyright ownership.
 *    * Unless required by applicable law or agreed to in writing, software
 *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *    * License for the specific language governing permissions and limitations
 *    * under the License.
 *    *
 *    * SPDX-License-Identifier: Apache-2.0
 *    *****************************************************************************
 *
 */
|
|
|
|
package net.brutex.gan;
|
|
|
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
import org.deeplearning4j.nn.weights.WeightInit;
|
|
import org.nd4j.linalg.activations.Activation;
|
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
import org.nd4j.linalg.factory.Nd4j;
|
|
import org.nd4j.linalg.learning.config.Adam;
|
|
import org.nd4j.linalg.learning.config.IUpdater;
|
|
import org.nd4j.linalg.learning.config.Sgd;
|
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|
|
|
import javax.swing.*;
|
|
|
|
|
|
/**
|
|
* Relatively small GAN example using only Dense layers with dropout to generate handwritten
|
|
* digits from MNIST data.
|
|
*/
|
|
public class MnistSimpleGAN {
|
|
|
|
private static final int LATENT_DIM = 100;
|
|
|
|
private static final double LEARNING_RATE = 0.0002;
|
|
private static final IUpdater UPDATER_ZERO = Sgd.builder().learningRate(0.0).build();
|
|
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
|
|
|
|
|
|
public static MultiLayerNetwork getGenerator() {
|
|
MultiLayerConfiguration genConf = new NeuralNetConfiguration.Builder()
|
|
.weightInit(WeightInit.XAVIER)
|
|
.activation(Activation.IDENTITY)
|
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
|
.gradientNormalizationThreshold(100)
|
|
.list()
|
|
.layer(new DenseLayer.Builder().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build())
|
|
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
|
|
.layer(new DenseLayer.Builder().nIn(256).nOut(512).build())
|
|
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
|
|
.layer(new DenseLayer.Builder().nIn(512).nOut(1024).build())
|
|
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
|
|
.layer(new DenseLayer.Builder().nIn(1024).nOut(784).activation(Activation.TANH).build())
|
|
.build();
|
|
return new MultiLayerNetwork(genConf);
|
|
}
|
|
|
|
|
|
public static MultiLayerNetwork getDiscriminator(IUpdater updater) {
|
|
MultiLayerConfiguration discConf = new NeuralNetConfiguration.Builder()
|
|
.seed(42)
|
|
.updater(updater)
|
|
.weightInit(WeightInit.XAVIER)
|
|
.activation(Activation.IDENTITY)
|
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
|
.gradientNormalizationThreshold(100)
|
|
.list()
|
|
.layer(new DenseLayer.Builder().nIn(784).nOut(1024).updater(updater).build())
|
|
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
|
|
.layer(new DropoutLayer.Builder(1 - 0.5).build())
|
|
.layer(new DenseLayer.Builder().nIn(1024).nOut(512).updater(updater).build())
|
|
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
|
|
.layer(new DropoutLayer.Builder(1 - 0.5).build())
|
|
.layer(new DenseLayer.Builder().nIn(512).nOut(256).updater(updater).build())
|
|
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
|
|
.layer(new DropoutLayer.Builder(1 - 0.5).build())
|
|
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1)
|
|
.activation(Activation.SIGMOID).updater(updater).build())
|
|
.build();
|
|
|
|
return new MultiLayerNetwork(discConf);
|
|
}
|
|
|
|
public static void main(String[] args) throws Exception {
|
|
GAN gan = new GAN.Builder()
|
|
.generator(MnistSimpleGAN::getGenerator)
|
|
.discriminator(MnistSimpleGAN::getDiscriminator)
|
|
.latentDimension(LATENT_DIM)
|
|
.seed(42)
|
|
.updater(UPDATER)
|
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
|
.gradientNormalizationThreshold(100)
|
|
.build();
|
|
|
|
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
|
|
|
int batchSize = 128;
|
|
MnistDataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 42);
|
|
|
|
|
|
// Sample from latent space once to visualize progress on image generation.
|
|
int numSamples = 9;
|
|
JFrame frame = GANVisualizationUtils.initFrame();
|
|
JPanel panel = GANVisualizationUtils.initPanel(frame, numSamples);
|
|
|
|
for (int i = 0; i < 100; i++) {
|
|
trainData.reset();
|
|
int j = 0;
|
|
while (trainData.hasNext()) {
|
|
gan.fit(trainData.next());
|
|
//gan.fit(trainData, 1);
|
|
|
|
if (j % 10 == 0) {
|
|
INDArray fakeIn = Nd4j.rand(new int[]{batchSize, LATENT_DIM});
|
|
System.out.println("Epoch " + (i + 1) + " Iteration " + j + " Visualizing...");
|
|
INDArray[] samples = new INDArray[numSamples];
|
|
for (int k = 0; k < numSamples; k++) {
|
|
INDArray input = fakeIn.getRow(k);
|
|
samples[k] = gan.getGenerator().output(input, false);
|
|
}
|
|
GANVisualizationUtils.visualize(samples, frame, panel);
|
|
}
|
|
j++;
|
|
}
|
|
}
|
|
}
|
|
} |