gan example

Signed-off-by: brian <brian@brutex.de>
enhance-build-infrastructure
Brian Rosenberger 2023-08-07 10:32:39 +02:00
parent 1b3338f809
commit dd151aec3f
20 changed files with 1267 additions and 838 deletions

View File

@ -1,473 +1,261 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan; package net.brutex.gan;
import static net.brutex.ai.dnn.api.NN.dense;
import java.awt.*; import java.awt.*;
import java.awt.image.BufferedImage; import java.awt.image.BufferedImage;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Random; import javax.swing.*;
import java.util.UUID;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.WindowConstants;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.datavec.image.transform.ShowImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.deeplearning4j.optimize.listeners.PerformanceListener; import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions;
@Slf4j
public class App { public class App {
private static final double LEARNING_RATE = 0.000002; private static final double LEARNING_RATE = 0.002;
private static final double GRADIENT_THRESHOLD = 100.0; private static final double GRADIENT_THRESHOLD = 100.0;
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
private static final int BATCHSIZE = 128;
private static JFrame frame;
private static JPanel panel;
private static final int X_DIM = 20 ; private static LayerConfiguration[] genLayers() {
private static final int Y_DIM = 20; return new LayerConfiguration[] {
private static final int CHANNELS = 1; dense().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
private static final int batchSize = 1; ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
private static final int INPUT = 10; dense().nIn(256).nOut(512).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(512).nOut(1024).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(1024).nOut(784).activation(Activation.TANH).build()
};
}
private static final int OUTPUT_PER_PANEL = 16; /**
* Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
*
* @return config
*/
private static NeuralNetConfiguration generator() {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.layersFromArray(genLayers())
.name("generator")
.build();
private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS; return conf;
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build(); }
private static JFrame frame; private static LayerConfiguration[] disLayers() {
private static JFrame frame2; return new LayerConfiguration[]{
private static JPanel panel; dense().nIn(784).nOut(1024).build(),
private static JPanel panel2; ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
dense().nIn(1024).nOut(512).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
dense().nIn(512).nOut(256).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
};
}
private static final String OUTPUT_DIR = "C:/temp/output/"; private static NeuralNetConfiguration discriminator() {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.layersFromArray(disLayers())
.name("discriminator")
.build();
private static LayerConfiguration[] genLayers() { return conf;
return new LayerConfiguration[] { }
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
ActivationLayer.builder(Activation.LEAKYRELU).build(),
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), private static NeuralNetConfiguration gan() {
ActivationLayer.builder(new ActivationLReLU(0.2)).build(), LayerConfiguration[] genLayers = genLayers();
DropoutLayer.builder(1 - 0.5).build(), LayerConfiguration[] disLayers = discriminator().getFlattenedLayerConfigurations().stream()
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(), .map((layer) -> {
ActivationLayer.builder(new ActivationLReLU(0.2)).build(), if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
return FrozenLayerWithBackprop.builder(layer).build();
} else {
return layer;
}
}).toArray(LayerConfiguration[]::new);
LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers);
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build() NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
}; .seed(42)
} .updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.layersFromArray(layers)
.name("GAN")
.build();
/** return conf;
* Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image. }
*
* @return config
*/
private static NeuralNetConfiguration generator() {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
//.weightInit(WeightInit.XAVIER)
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.layersFromArray(genLayers())
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
// .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
.build();
((NeuralNetConfiguration) conf).init();
return conf; @Test
} public void runTest() throws Exception {
App.main(null);
}
public static void main(String... args) throws Exception {
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
private static LayerConfiguration[] disLayers() { MnistDataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
return new LayerConfiguration[]{
DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
OutputLayer.builder().name("dis-output").lossFunction(LossFunction.MCXENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build() MultiLayerNetwork gen = new MultiLayerNetwork(generator());
}; MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
} MultiLayerNetwork gan = new MultiLayerNetwork(gan());
gen.init();
dis.init();
gan.init();
private static NeuralNetConfiguration discriminator() { copyParams(gen, dis, gan);
NeuralNetConfiguration conf = gen.addTrainingListeners(new PerformanceListener(10, true));
NeuralNetConfiguration.builder() dis.addTrainingListeners(new PerformanceListener(10, true));
.seed(42) gan.addTrainingListeners(new PerformanceListener(10, true));
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
.weightInit(WeightInit.XAVIER)
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightNoise(null)
// .weightInitFn(new WeightInitXavier())
// .activationFn(new ActivationIdentity())
.activation(Activation.IDENTITY)
.layersFromArray(disLayers())
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
.build();
((NeuralNetConfiguration) conf).init();
return conf; trainData.reset();
}
private static NeuralNetConfiguration gan() { int j = 0;
LayerConfiguration[] genLayers = genLayers(); for (int i = 0; i < 50; i++) {
LayerConfiguration[] disLayers = Arrays.stream(disLayers()) while (trainData.hasNext()) {
.map((layer) -> { j++;
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
} else {
return layer;
}
}).toArray(LayerConfiguration[]::new);
LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers);
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() // generate data
.seed(42) INDArray real = trainData.next().getFeatures().muli(2).subi(1);
.updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() ) int batchSize = (int) real.shape()[0];
.gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold( 100 ) INDArray fakeIn = Nd4j.rand(batchSize, 100);
//.weightInitFn( new WeightInitXavier() ) //this is internal INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightInit( WeightInit.XAVIER) DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
//.activationFn( new ActivationIdentity()) //this is internal DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
.activation( Activation.IDENTITY )
.layersFromArray( layers ) DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
.inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
.dataType(DataType.FLOAT) dis.fit(data);
.build(); dis.fit(data);
((NeuralNetConfiguration) conf).init();
return conf; // Update the discriminator in the GAN network
} updateGan(gen, dis, gan);
gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
@Test if (j % 10 == 1) {
public void runTest() throws Exception { System.out.println("Epoch " + i +" Iteration " + j + " Visualizing...");
if(! log.isDebugEnabled()) { INDArray[] samples = new INDArray[9];
log.info("Logging is not set to DEBUG"); DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
}
else {
log.info("Logging is set to DEBUG");
}
main();
}
public static void main(String... args) throws Exception { for (int k = 0; k < 9; k++) {
INDArray input = fakeSet2.get(k).getFeatures();
//samples[k] = gen.output(input, false);
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m "); }
Nd4j.getMemoryManager().setAutoGcWindow(500); visualize(samples);
}
//MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45); }
//FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS()); trainData.reset();
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS()); // Copy the GANs generator to gen.
//updateGen(gen, gan);
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);
ImageTransform tr = new PipelineImageTransform.Builder()
//.addImageTransform(transform) //convert to GREY SCALE
.addImageTransform(transform3)
//.addImageTransform(transform2)
.build();
ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
imageRecordReader.initialize(fileSplit, tr);
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
gen.init(); log.debug("Generator network: {}", gen);
dis.init(); log.debug("Discriminator network: {}", dis);
gan.init(); log.info("Complete GAN network: {}", gan);
copyParams(gen, dis, gan);
//gen.addTrainingListeners(new PerformanceListener(15, true, "GEN"));
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
//gan.addTrainingListeners(new ScoreToChartListener("gan"));
//dis.setListeners(new ScoreToChartListener("dis"));
//System.out.println(gan.toString());
//gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
//gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
//trainData.reset();
int j = 0;
for (int i = 0; i < 51; i++) { //epoch
while (trainData.hasNext()) {
j++;
DataSet next = trainData.next();
// generate data
INDArray real = next.getFeatures();//.div(255f);
//start next round if there are not enough images left to have a full batchsize dataset
if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
log.warn("Your total number of input images is not a multiple of {}, "
+ "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
break;
} }
//if(i%20 == 0) {
frame2 = visualize(new INDArray[]{real}, batchSize,
frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
//}
real.divi(255f);
// int batchSize = (int) real.shape()[0];
INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
//log.info("real has {} items.", real.length());
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
dis.fit(data);
//dis.fit(data);
// Update the discriminator in the GAN network
updateGan(gen, dis, gan);
//gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1)));
//Visualize and reporting
if (j % 10 == 1) {
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
for (int k = 0; k < samples.length; k++) {
//INDArray input = fakeSet2.get(k).getFeatures();
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
INDArray input = fakeSet2.get(k).getFeatures();
input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here
//samples[k] = gen.output(input, false);
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
//samples[k] =
samples[k].addi(1f).divi(2f).muli(255f);
}
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
}
}
if (trainData.resetSupported()) {
trainData.reset();
} else {
log.error("Trainingdata {} does not support reset.", trainData.toString());
}
// Copy the GANs generator to gen. // Copy the GANs generator to gen.
updateGen(gen, gan); updateGen(gen, gan);
gen.save(new File("mnist-mlp-generator.dlj")); gen.save(new File("mnist-mlp-generator.dlj"));
} }
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length;
for (int i = 0; i < gan.getLayers().length; i++) {
} if (i < genLayerCount) {
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) { } else {
int genLayerCount = gen.getLayers().length; dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
for (int i = 0; i < gan.getLayers().length; i++) { }
if (i < genLayerCount) {
if(gan.getLayer(i).getParams() != null)
gan.getLayer(i).setParams(gen.getLayer(i).getParams());
} else {
if(gan.getLayer(i).getParams() != null)
gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams());
}
}
}
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
for (int i = 0; i < gen.getLayers().length; i++) {
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
}
}
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length;
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
}
}
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
if (isOrig) {
frame.setTitle("Viz Original");
} else {
frame.setTitle("Generated");
}
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout());
JPanel panelx = new JPanel();
panelx.setLayout(new GridLayout(4, 4, 8, 8));
for (INDArray sample : samples) {
for(int i = 0; i<batchElements; i++) {
panelx.add(getImage(sample, i, isOrig));
}
}
frame.add(panelx, BorderLayout.CENTER);
frame.setVisible(true);
frame.revalidate();
frame.setMinimumSize(new Dimension(300, 20));
frame.pack();
return frame;
}
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
final BufferedImage bi;
if(CHANNELS>1) {
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
} else {
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
}
final int imageSize = X_DIM * Y_DIM;
final int offset = batchElement * imageSize;
int pxl = offset * CHANNELS; //where to start in the INDArray
//Image in NCHW - channels first format
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
for (int y = 0; y < Y_DIM; y++) { // step through the columns x
for (int x = 0; x < X_DIM; x++) { //step through the rows y
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
pxl++; //next item in INDArray
} }
}
} }
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
ImageIcon scaled = new ImageIcon(imageScaled);
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
return new JLabel(scaled);
} private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
for (int i = 0; i < gen.getLayers().length; i++) {
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
}
}
private static void saveImage(Image image, int batchElement, boolean isOrig) { private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved int genLayerCount = gen.getLayers().length;
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
}
}
try { private static void visualize(INDArray[] samples) {
// Save the images to disk if (frame == null) {
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png"); frame = new JFrame();
frame.setTitle("Viz");
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout());
log.debug("Images saved successfully."); panel = new JPanel();
} catch (IOException e) {
log.error("Error saving the images: {}", e.getMessage()); panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
frame.add(panel, BorderLayout.CENTER);
frame.setVisible(true);
}
panel.removeAll();
for (INDArray sample : samples) {
panel.add(getImage(sample));
}
frame.revalidate();
frame.pack();
}
private static JLabel getImage(INDArray tensor) {
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
for (int i = 0; i < 784; i++) {
int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
ImageIcon scaled = new ImageIcon(imageScaled);
return new JLabel(scaled);
} }
}
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
File directory = new File(outputDirectory);
if (!directory.exists()) {
directory.mkdir();
}
File outputFile = new File(directory, fileName);
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
}
public static BufferedImage imageToBufferedImage(Image image) {
if (image instanceof BufferedImage) {
return (BufferedImage) image;
}
// Create a buffered image with the same dimensions and transparency as the original image
BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
// Draw the original image onto the buffered image
Graphics2D g2d = bufferedImage.createGraphics();
g2d.drawImage(image, 0, 0, null);
g2d.dispose();
return bufferedImage;
}
} }

View File

@ -0,0 +1,343 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.List;
import javax.imageio.ImageIO;
import javax.swing.*;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.*;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
@Slf4j
public class App2 {
final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
static final float COLORSPACE = 255f;
static final int DIMENSIONS = 28;
static final int CHANNELS = 1;
final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
final int OUTPUT_PER_PANEL = 10;
final boolean BIAS = true;
static final int BATCHSIZE=128;
private JFrame frame2, frame;
static final String OUTPUT_DIR = "d:/out/";
final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);
@Test
void runTest() throws IOException {
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
ImageTransform tr = new PipelineImageTransform.Builder()
.addImageTransform(transform) //convert to GREY SCALE
.addImageTransform(transform3)
//.addImageTransform(transform2)
.build();
ImageRecordReader imageRecordReader = new ImageRecordReader(DIMENSIONS, DIMENSIONS, CHANNELS);
imageRecordReader.initialize(fileSplit, tr);
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, BATCHSIZE );
trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());
LayerConfiguration[] disLayers = App2Config.discriminator().getFlattenedLayerConfigurations().stream()
.map((layer) -> {
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
} else {
return layer;
}
}).toArray(LayerConfiguration[]::new);
NeuralNetConfiguration netConfiguration =
NeuralNetConfiguration.builder()
.name("GAN")
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.updater(App2Config.UPDATER)
.innerConfigurations(new ArrayList<>(List.of(App2Config.generator())))
.layersFromList(new ArrayList<>(Arrays.asList(disLayers)))
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
// .inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
// .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(2, new CnnToFeedForwardPreProcessor())
//.dataType(DataType.FLOAT)
.build();
MultiLayerNetwork gan = new MultiLayerNetwork(netConfiguration );
dis.init(); log.debug("Discriminator network: {}", dis);
gen.init(); log.debug("Generator network: {}", gen);
gan.init(); log.debug("GAN network: {}", gan);
log.info("Generator Summary:\n{}", gen.summary());
log.info("GAN Summary:\n{}", gan.summary());
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
int j = 0;
for (int i = 0; i < 51; i++) { //epoch
while (trainData.hasNext()) {
j++;
DataSet next = trainData.next();
// generate data
INDArray real = next.getFeatures(); //.muli(2).subi(1);;//.div(255f);
//start next round if there are not enough images left to have a full batchsize dataset
if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
log.warn("Your total number of input images is not a multiple of {}, "
+ "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
break;
}
//if(i%20 == 0) {
// frame2 = visualize(new INDArray[]{real}, BATCHSIZE,
// frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
//}
//real.divi(255f);
// int batchSize = (int) real.shape()[0];
//INDArray fakeIn = Nd4j.rand(BATCHSIZE, CHANNELS, DIMENSIONS, DIMENSIONS);
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
// when generator has TANH as activation - value range is -1 to 1
// when generator has SIGMOID, then range is 0 to 1
fake.addi(1f).divi(2f);
DataSet realSet = new DataSet(real, label_real);
DataSet fakeSet = new DataSet(fake, label_fake);
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
dis.fit(data);
dis.fit(data);
// Update the discriminator in the GAN network
updateGan(gen, dis, gan);
gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_fake));
//Visualize and reporting
if (j % 10 == 1) {
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
INDArray[] samples = BATCHSIZE > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[BATCHSIZE];
for (int k = 0; k < samples.length; k++) {
DataSet fakeSet2 = new DataSet(fakeIn, label_fake);
INDArray input = fakeSet2.get(k).getFeatures();
//input = input.reshape(1,CHANNELS, DIMENSIONS, DIMENSIONS); //batch size will be 1 here for images
input = input.reshape(1, App2Config.INPUT);
//samples[k] = gen.output(input, false);
samples[k] = gen.activateSelectedLayers(0, gen.getLayers().length - 1, input);
samples[k] = samples[k].reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS);
//samples[k] =
//samples[k].muli(255f);
}
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
}
}
if (trainData.resetSupported()) {
trainData.reset();
} else {
log.error("Trainingdata {} does not support reset.", trainData.toString());
}
// Copy the GANs generator to gen.
updateGen(gen, gan);
log.info("Updated GAN's generator from gen.");
gen.save(new File("mnist-mlp-generator.dlj"));
}
}
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
if (isOrig) {
frame.setTitle("Viz Original");
} else {
frame.setTitle("Generated");
}
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout());
JPanel panelx = new JPanel();
panelx.setLayout(new GridLayout(4, 4, 8, 8));
for (INDArray sample : samples) {
for(int i = 0; i<batchElements; i++) {
panelx.add(getImage(sample, i, isOrig));
}
}
frame.add(panelx, BorderLayout.CENTER);
frame.setVisible(true);
frame.revalidate();
frame.setMinimumSize(new Dimension(300, 20));
frame.pack();
return frame;
}
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
final BufferedImage bi;
if(CHANNELS >1) {
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
} else {
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
}
final int imageSize = DIMENSIONS * DIMENSIONS;
final int offset = batchElement * imageSize;
int pxl = offset * CHANNELS; //where to start in the INDArray
//Image in NCHW - channels first format
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x
for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y
float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl);
bi.getRaster().setSample(x, y, c, f_pxl);
pxl++; //next item in INDArray
}
}
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT);
ImageIcon scaled = new ImageIcon(imageScaled);
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
return new JLabel(scaled);
}
private static void saveImage(Image image, int batchElement, boolean isOrig) {
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
try {
// Save the images to disk
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
log.debug("Images saved successfully.");
} catch (IOException e) {
log.error("Error saving the images: {}", e.getMessage());
}
}
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
File directory = new File(outputDirectory);
if (!directory.exists()) {
directory.mkdir();
}
File outputFile = new File(directory, fileName);
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
}
public static BufferedImage imageToBufferedImage(Image image) {
if (image instanceof BufferedImage) {
return (BufferedImage) image;
}
// Create a buffered image with the same dimensions and transparency as the original image
BufferedImage bufferedImage;
if (CHANNELS > 1) {
bufferedImage =
new BufferedImage(
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
} else {
bufferedImage =
new BufferedImage(
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
}
// Draw the original image onto the buffered image
Graphics2D g2d = bufferedImage.createGraphics();
g2d.drawImage(image, 0, 0, null);
g2d.dispose();
return bufferedImage;
}
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
for (int i = 0; i < gen.getLayers().length; i++) {
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
}
}
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length;
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
}
}
}

View File

@ -0,0 +1,176 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import static net.brutex.ai.dnn.api.NN.*;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class App2Config {
public static final int INPUT = 100;
public static final int X_DIM = 28;
public static final int y_DIM = 28;
public static final int CHANNELS = 1;
public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
static LayerConfiguration[] genLayerConfig() {
return new LayerConfiguration[] {
/*
DenseLayer.builder().name("L-0").nIn(INPUT).nOut(INPUT + (INPUT / 2)).activation(Activation.RELU).build(),
ActivationLayer.builder().activation(Activation.RELU).build(), /*
Deconvolution2D.builder().name("L-Deconv-01").nIn(CHANNELS).nOut(CHANNELS)
.kernelSize(2,2)
.stride(1,1)
.padding(0,0)
.convolutionMode(ConvolutionMode.Truncate)
.activation(Activation.RELU)
.hasBias(BIAS).build(),
//BatchNormalization.builder().nOut(CHANNELS).build(),
Deconvolution2D.builder().name("L-Deconv-02").nIn(CHANNELS).nOut(CHANNELS)
.kernelSize(2,2)
.stride(2,2)
.padding(0,0)
.convolutionMode(ConvolutionMode.Truncate)
.activation(Activation.RELU)
.hasBias(BIAS).build(),
//BatchNormalization.builder().name("L-batch").nOut(CHANNELS).build(),
DenseLayer.builder().name("L-x").nIn(INPUT + (INPUT / 2)).nOut(2 * INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
DenseLayer.builder().name("L-x").nIn(2 * INPUT).nOut(3 * INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
DenseLayer.builder().name("L-x").nIn(3 * INPUT).nOut(2 * INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
// DropoutLayer.builder(0.001).build(),
DenseLayer.builder().nIn(2 * INPUT).nOut(INPUT).activation(Activation.TANH).build() */
dense().nIn(INPUT).nOut(256).weightInit(WeightInit.NORMAL).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(256).nOut(512).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(512).nOut(1024).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(1024).nOut(784).activation(Activation.TANH).build(),
};
}
static LayerConfiguration[] disLayerConfig() {
return new LayerConfiguration[] {/*
Convolution2D.builder().nIn(CHANNELS).kernelSize(2,2).padding(1,1).stride(1,1).nOut(CHANNELS)
.build(),
Convolution2D.builder().nIn(CHANNELS).kernelSize(3,3).padding(1,1).stride(2,2).nOut(CHANNELS)
.build(),
ActivationLayer.builder().activation(Activation.LEAKYRELU).build(),
BatchNormalization.builder().build(),
OutputLayer.builder().nOut(1).lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SIGMOID)
.build()
dense().name("L-dense").nIn(INPUT).nOut(INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).build(),
DropoutLayer.builder(0.5).build(),
DenseLayer.builder().nIn(INPUT).nOut(INPUT/2).build(),
ActivationLayer.builder().activation(Activation.RELU).build(),
DropoutLayer.builder(0.5).build(),
DenseLayer.builder().nIn(INPUT/2).nOut(INPUT/4).build(),
ActivationLayer.builder().activation(Activation.RELU).build(),
DropoutLayer.builder(0.5).build(),
OutputLayer.builder().nIn(INPUT/4).nOut(1).lossFunction(LossFunctions.LossFunction.XENT)
.activation(Activation.SIGMOID)
.build() */
dense().nIn(784).nOut(1024).hasBias(true).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
dense().nIn(1024).nOut(512).hasBias(true).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
dense().nIn(512).nOut(256).hasBias(true).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
};
}
static NeuralNetConfiguration generator() {
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.name("generator")
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.seed(42)
.updater(UPDATER)
.weightInit(WeightInit.XAVIER)
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightNoise(null)
// .weightInitFn(new WeightInitXavier())
// .activationFn(new ActivationIdentity())
.activation(Activation.IDENTITY)
.layersFromArray(App2Config.genLayerConfig())
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
//.inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
.build();
conf.init();
return conf;
}
static NeuralNetConfiguration discriminator() {
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.name("discriminator")
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.seed(42)
.updater(UPDATER)
.weightInit(WeightInit.XAVIER)
// .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightNoise(null)
// .weightInitFn(new WeightInitXavier())
// .activationFn(new ActivationIdentity())
.activation(Activation.IDENTITY)
.layersFromArray(disLayerConfig())
//.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
//.dataType(DataType.FLOAT)
.build();
conf.init();
return conf;
}
}

View File

@ -43,223 +43,255 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils.impo
@Slf4j @Slf4j
public class KerasSequentialModel extends KerasModel { public class KerasSequentialModel extends KerasModel {
/** /**
* (Recommended) Builder-pattern constructor for Sequential model. * (Recommended) Builder-pattern constructor for Sequential model.
* *
* @param modelBuilder builder object * @param modelBuilder builder object
* @throws IOException I/O exception * @throws IOException I/O exception
* @throws InvalidKerasConfigurationException Invalid Keras configuration * @throws InvalidKerasConfigurationException Invalid Keras configuration
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration * @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
*/ */
public KerasSequentialModel(KerasModelBuilder modelBuilder) public KerasSequentialModel(KerasModelBuilder modelBuilder)
throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException { throws UnsupportedKerasConfigurationException,
this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(), IOException,
modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(), InvalidKerasConfigurationException {
modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape()); this(
modelBuilder.getModelJson(),
modelBuilder.getModelYaml(),
modelBuilder.getWeightsArchive(),
modelBuilder.getWeightsRoot(),
modelBuilder.getTrainingJson(),
modelBuilder.getTrainingArchive(),
modelBuilder.isEnforceTrainingConfig(),
modelBuilder.getInputShape());
}
/**
* (Not recommended) Constructor for Sequential model from model configuration (JSON or YAML),
* training configuration (JSON), weights, and "training mode" boolean indicator. When built in
* training mode, certain unsupported configurations (e.g., unknown regularizers) will throw
* Exceptions. When enforceTrainingConfig=false, these will generate warnings but will be
* otherwise ignored.
*
* @param modelJson model configuration JSON string
* @param modelYaml model configuration YAML string
* @param trainingJson training configuration JSON string
* @throws IOException I/O exception
*/
public KerasSequentialModel(
String modelJson,
String modelYaml,
Hdf5Archive weightsArchive,
String weightsRoot,
String trainingJson,
Hdf5Archive trainingArchive,
boolean enforceTrainingConfig,
int[] inputShape)
throws IOException,
InvalidKerasConfigurationException,
UnsupportedKerasConfigurationException {
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
this.enforceTrainingConfig = enforceTrainingConfig;
/* Determine model configuration type. */
if (!modelConfig.containsKey(config.getFieldClassName()))
throw new InvalidKerasConfigurationException(
"Could not determine Keras model class (no "
+ config.getFieldClassName()
+ " field found)");
this.className = (String) modelConfig.get(config.getFieldClassName());
if (!this.className.equals(config.getFieldClassNameSequential()))
throw new InvalidKerasConfigurationException(
"Model class name must be "
+ config.getFieldClassNameSequential()
+ " (found "
+ this.className
+ ")");
/* Process layer configurations. */
if (!modelConfig.containsKey(config.getModelFieldConfig()))
throw new InvalidKerasConfigurationException(
"Could not find layer configurations (no "
+ config.getModelFieldConfig()
+ " field found)");
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations.
// For consistency
// "config" is now an object containing a "name" and "layers", the latter contain the same data
// as before.
// This change only affects Sequential models.
List<Object> layerList;
try {
layerList = (List<Object>) modelConfig.get(config.getModelFieldConfig());
} catch (Exception e) {
HashMap layerMap = (HashMap<String, Object>) modelConfig.get(config.getModelFieldConfig());
layerList = (List<Object>) layerMap.get("layers");
} }
/** Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = prepareLayers(layerList);
* (Not recommended) Constructor for Sequential model from model configuration this.layers = layerPair.getFirst();
* (JSON or YAML), training configuration (JSON), weights, and "training mode" this.layersOrdered = layerPair.getSecond();
* boolean indicator. When built in training mode, certain unsupported configurations
* (e.g., unknown regularizers) will throw Exceptions. When enforceTrainingConfig=false, these
* will generate warnings but will be otherwise ignored.
*
* @param modelJson model configuration JSON string
* @param modelYaml model configuration YAML string
* @param trainingJson training configuration JSON string
* @throws IOException I/O exception
*/
public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig,
int[] inputShape)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml); KerasLayer inputLayer;
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config); if (this.layersOrdered.get(0) instanceof KerasInput) {
this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config); inputLayer = this.layersOrdered.get(0);
this.enforceTrainingConfig = enforceTrainingConfig; } else {
/* Add placeholder input layer and update lists of input and output layers. */
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
Preconditions.checkState(
ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!");
inputLayer = new KerasInput("input1", firstLayerInputShape);
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
this.layers.put(inputLayer.getName(), inputLayer);
this.layersOrdered.add(0, inputLayer);
}
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
this.outputLayerNames =
new ArrayList<>(
Collections.singletonList(
this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
/* Determine model configuration type. */ /* Update each layer's inbound layer list to include (only) previous layer. */
if (!modelConfig.containsKey(config.getFieldClassName())) KerasLayer prevLayer = null;
throw new InvalidKerasConfigurationException( for (KerasLayer layer : this.layersOrdered) {
"Could not determine Keras model class (no " + config.getFieldClassName() + " field found)"); if (prevLayer != null)
this.className = (String) modelConfig.get(config.getFieldClassName()); layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName()));
if (!this.className.equals(config.getFieldClassNameSequential())) prevLayer = layer;
throw new InvalidKerasConfigurationException("Model class name must be " + config.getFieldClassNameSequential()
+ " (found " + this.className + ")");
/* Process layer configurations. */
if (!modelConfig.containsKey(config.getModelFieldConfig()))
throw new InvalidKerasConfigurationException(
"Could not find layer configurations (no " + config.getModelFieldConfig() + " field found)");
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations. For consistency
// "config" is now an object containing a "name" and "layers", the latter contain the same data as before.
// This change only affects Sequential models.
List<Object> layerList;
try {
layerList = (List<Object>) modelConfig.get(config.getModelFieldConfig());
} catch (Exception e) {
HashMap layerMap = (HashMap<String, Object>) modelConfig.get(config.getModelFieldConfig());
layerList = (List<Object>) layerMap.get("layers");
}
Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair =
prepareLayers(layerList);
this.layers = layerPair.getFirst();
this.layersOrdered = layerPair.getSecond();
KerasLayer inputLayer;
if (this.layersOrdered.get(0) instanceof KerasInput) {
inputLayer = this.layersOrdered.get(0);
} else {
/* Add placeholder input layer and update lists of input and output layers. */
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
Preconditions.checkState(ArrayUtil.prod(firstLayerInputShape) > 0,"Input shape must not be zero!");
inputLayer = new KerasInput("input1", firstLayerInputShape);
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
this.layers.put(inputLayer.getName(), inputLayer);
this.layersOrdered.add(0, inputLayer);
}
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
this.outputLayerNames = new ArrayList<>(
Collections.singletonList(this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
/* Update each layer's inbound layer list to include (only) previous layer. */
KerasLayer prevLayer = null;
for (KerasLayer layer : this.layersOrdered) {
if (prevLayer != null)
layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName()));
prevLayer = layer;
}
/* Import training configuration. */
if (enforceTrainingConfig) {
if (trainingJson != null)
importTrainingConfiguration(trainingJson);
else log.warn("If enforceTrainingConfig is true, a training " +
"configuration object has to be provided. Usually the only practical way to do this is to store" +
" your keras model with `model.save('model_path.h5'. If you store model config and weights" +
" separately no training configuration is attached.");
}
this.outputTypes = inferOutputTypes(inputShape);
if (weightsArchive != null)
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
} }
/** /* Import training configuration. */
* Default constructor if (enforceTrainingConfig) {
*/ if (trainingJson != null) importTrainingConfiguration(trainingJson);
public KerasSequentialModel() { else
super(); log.warn(
"If enforceTrainingConfig is true, a training "
+ "configuration object has to be provided. Usually the only practical way to do this is to store"
+ " your keras model with `model.save('model_path.h5'. If you store model config and weights"
+ " separately no training configuration is attached.");
} }
/** this.outputTypes = inferOutputTypes(inputShape);
* Configure a NeuralNetConfiguration from this Keras Sequential model configuration.
*
* @return NeuralNetConfiguration
*/
public NeuralNetConfiguration getNeuralNetConfiguration()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
if (!this.className.equals(config.getFieldClassNameSequential()))
throw new InvalidKerasConfigurationException(
"Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
if (this.inputLayerNames.size() != 1)
throw new InvalidKerasConfigurationException(
"MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
if (this.outputLayerNames.size() != 1)
throw new InvalidKerasConfigurationException(
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder(); if (weightsArchive != null)
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
}
if (optimizer != null) { /** Default constructor */
modelBuilder.updater(optimizer); public KerasSequentialModel() {
super();
}
/**
* Configure a NeuralNetConfiguration from this Keras Sequential model configuration.
*
* @return NeuralNetConfiguration
*/
public NeuralNetConfiguration getNeuralNetConfiguration()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
if (!this.className.equals(config.getFieldClassNameSequential()))
throw new InvalidKerasConfigurationException(
"Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
if (this.inputLayerNames.size() != 1)
throw new InvalidKerasConfigurationException(
"MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
if (this.outputLayerNames.size() != 1)
throw new InvalidKerasConfigurationException(
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder =
NeuralNetConfiguration.builder();
if (optimizer != null) {
modelBuilder.updater(optimizer);
}
// don't forcibly override for keras import
modelBuilder.overrideNinUponBuild(false);
/* Add layers one at a time. */
KerasLayer prevLayer = null;
int layerIndex = 0;
for (KerasLayer layer : this.layersOrdered) {
if (layer.isLayer()) {
int nbInbound = layer.getInboundLayerNames().size();
if (nbInbound != 1)
throw new InvalidKerasConfigurationException(
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
+ nbInbound
+ " for layer "
+ layer.getName()
+ ")");
if (prevLayer != null) {
InputType[] inputTypes = new InputType[1];
InputPreProcessor preprocessor;
if (prevLayer.isInputPreProcessor()) {
inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
preprocessor = prevLayer.getInputPreprocessor(inputTypes);
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
} else {
inputTypes[0] = this.outputTypes.get(prevLayer.getName());
preprocessor = layer.getInputPreprocessor(inputTypes);
if (preprocessor != null) {
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
} else layer.getLayer().setNIn(inputTypes[0], modelBuilder.isOverrideNinUponBuild());
}
if (preprocessor != null) {
Map<Integer, InputPreProcessor> map = new HashMap<>();
map.put(layerIndex, preprocessor);
modelBuilder.inputPreProcessors(map);
}
} }
modelBuilder.layer(layerIndex++, layer.getLayer());
//don't forcibly override for keras import } else if (layer.getVertex() != null)
modelBuilder.overrideNinUponBuild(false); throw new InvalidKerasConfigurationException(
/* Add layers one at a time. */ "Cannot add vertex to NeuralNetConfiguration (class name "
KerasLayer prevLayer = null; + layer.getClassName()
int layerIndex = 0; + ", layer name "
for (KerasLayer layer : this.layersOrdered) { + layer.getName()
if (layer.isLayer()) { + ")");
int nbInbound = layer.getInboundLayerNames().size(); prevLayer = layer;
if (nbInbound != 1)
throw new InvalidKerasConfigurationException(
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
+ nbInbound + " for layer " + layer.getName() + ")");
if (prevLayer != null) {
InputType[] inputTypes = new InputType[1];
InputPreProcessor preprocessor;
if (prevLayer.isInputPreProcessor()) {
inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
preprocessor = prevLayer.getInputPreprocessor(inputTypes);
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
} else {
inputTypes[0] = this.outputTypes.get(prevLayer.getName());
preprocessor = layer.getInputPreprocessor(inputTypes);
if(preprocessor != null) {
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
}
else
layer.getLayer().setNIn(inputTypes[0],modelBuilder.isOverrideNinUponBuild());
}
if (preprocessor != null)
modelBuilder.inputPreProcessor(layerIndex, preprocessor);
}
modelBuilder.layer(layerIndex++, layer.getLayer());
} else if (layer.getVertex() != null)
throw new InvalidKerasConfigurationException("Cannot add vertex to NeuralNetConfiguration (class name "
+ layer.getClassName() + ", layer name " + layer.getName() + ")");
prevLayer = layer;
}
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */
if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
modelBuilder.backpropType(BackpropType.TruncatedBPTT)
.tbpttFwdLength(truncatedBPTT)
.tbpttBackLength(truncatedBPTT);
else
modelBuilder.backpropType(BackpropType.Standard);
NeuralNetConfiguration build = modelBuilder.build();
return build;
} }
/** /* Whether to use standard backprop (or BPTT) or truncated BPTT. */
* Build a MultiLayerNetwork from this Keras Sequential model configuration. if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
* modelBuilder
* @return MultiLayerNetwork .backpropType(BackpropType.TruncatedBPTT)
*/ .tbpttFwdLength(truncatedBPTT)
public MultiLayerNetwork getMultiLayerNetwork() .tbpttBackLength(truncatedBPTT);
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { else modelBuilder.backpropType(BackpropType.Standard);
return getMultiLayerNetwork(true);
}
/** NeuralNetConfiguration build = modelBuilder.build();
* Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights.
* return build;
* @return MultiLayerNetwork }
*/
public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights) /**
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { * Build a MultiLayerNetwork from this Keras Sequential model configuration.
MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration()); *
model.init(); * @return MultiLayerNetwork
if (importWeights) */
model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers); public MultiLayerNetwork getMultiLayerNetwork()
return model; throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
} return getMultiLayerNetwork(true);
}
/**
* Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights.
*
* @return MultiLayerNetwork
*/
public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration());
model.init();
if (importWeights)
model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers);
return model;
}
} }

View File

@ -23,6 +23,7 @@ package net.brutex.ai.dnn.api;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
/** /**
* A fluent API to configure and create artificial neural networks * A fluent API to configure and create artificial neural networks
@ -30,9 +31,11 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationB
public class NN { public class NN {
public static NeuralNetConfigurationBuilder<?, ?> net() { public static NeuralNetConfigurationBuilder<?, ?> nn() {
return NeuralNetConfiguration.builder(); return NeuralNetConfiguration.builder();
} }
public static DenseLayer.DenseLayerBuilder<?,?> dense() { return DenseLayer.builder(); }
} }

View File

@ -152,7 +152,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
@Getter @Setter @NonNull @lombok.Builder.Default @Getter @Setter @NonNull @lombok.Builder.Default
protected BackpropType backpropType = BackpropType.Standard; protected BackpropType backpropType = BackpropType.Standard;
@Getter @lombok.Builder.Default @Getter @Setter @Singular
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>(); protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
/** /**
* When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated) * When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated)
@ -524,12 +524,11 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
* @param processor what to use to preProcess the data. * @param processor what to use to preProcess the data.
* @return builder pattern * @return builder pattern
*/ */
public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) { //public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
if(inputPreProcessors$value==null) inputPreProcessors$value=new LinkedHashMap<>(); // inputPreProcessors$value.put(layer, processor);
inputPreProcessors$value.put(layer, processor); // inputPreProcessors$set = true;
inputPreProcessors$set = true; // return self();
return self(); // }
}
/** /**
* Set layer at index * Set layer at index

View File

@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.*; import com.fasterxml.jackson.databind.*;
import java.util.*; import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
@ -317,6 +318,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
@NonNull @NonNull
InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType); InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType);
if (inputPreProcessor != null) { if (inputPreProcessor != null) {
inputPreProcessors = new HashMap<>(inputPreProcessors);
inputPreProcessors.put(i, inputPreProcessor); inputPreProcessors.put(i, inputPreProcessor);
} }
} }
@ -538,6 +540,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
obj.getClass().getSimpleName()); obj.getClass().getSimpleName());
} }
}); });
// make sure the indexes are sequenced properly
AtomicInteger i = new AtomicInteger();
ret.forEach(obj -> {
obj.setIndex(i.getAndIncrement());
});
return ret; return ret;
} }

View File

@ -219,7 +219,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
throw new IllegalStateException( throw new IllegalStateException(
"Invalid input for Convolution layer (layer name=\"" "Invalid input for Convolution layer (layer name=\""
+ getName() + getName()
+ "\"): Expected CNN input, got " + "\" at index '"+getIndex()+"') : Expected CNN input, got "
+ inputType); + inputType);
} }
@ -372,7 +372,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
* @param kernelSize kernel size * @param kernelSize kernel size
*/ */
public B kernelSize(int... kernelSize) { public B kernelSize(int... kernelSize) {
this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize"); //this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize");
this.kernelSize$value = kernelSize;
this.kernelSize$set = true; this.kernelSize$set = true;
return self(); return self();
} }
@ -383,7 +384,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
* @param stride kernel size * @param stride kernel size
*/ */
public B stride(int... stride) { public B stride(int... stride) {
this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride"); //this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride");
this.stride$value = stride;
this.stride$set = true; this.stride$set = true;
return self(); return self();
} }
@ -394,7 +396,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
* @param padding kernel size * @param padding kernel size
*/ */
public B padding(int... padding) { public B padding(int... padding) {
this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding"); //this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding");
this.padding$value = padding;
this.padding$set = true; this.padding$set = true;
return self(); return self();
} }
@ -404,7 +407,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
* @param dilation kernel size * @param dilation kernel size
*/ */
public B dilation(int... dilation) { public B dilation(int... dilation) {
this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation"); //this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation");
this.dilation$value = dilation;
this.dilation$set = true; this.dilation$set = true;
return self(); return self();
} }

View File

@ -20,14 +20,19 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Map; import java.util.Map;
import java.util.stream.IntStream;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import lombok.extern.jackson.Jacksonized; import lombok.extern.jackson.Jacksonized;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer; import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer;
@ -84,6 +89,8 @@ public class Deconvolution2D extends ConvolutionLayer {
boolean initializeParams, boolean initializeParams,
DataType networkDataType) { DataType networkDataType) {
setNetConfiguration(conf); setNetConfiguration(conf);
LayerValidation.assertNInNOutSet("Deconvolution2D", getName(), layerIndex, getNIn(), getNOut()); LayerValidation.assertNInNOutSet("Deconvolution2D", getName(), layerIndex, getNIn(), getNOut());
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
runInheritance(); runInheritance();
@ -127,11 +134,25 @@ public class Deconvolution2D extends ConvolutionLayer {
getName(), getName(),
Deconvolution2DLayer.class); Deconvolution2DLayer.class);
} }
@Slf4j
private static final class Deconvolution2DBuilderImpl private static final class Deconvolution2DBuilderImpl
extends Deconvolution2DBuilder<Deconvolution2D, Deconvolution2DBuilderImpl> { extends Deconvolution2DBuilder<Deconvolution2D, Deconvolution2DBuilderImpl> {
public Deconvolution2D build() { public Deconvolution2D build() {
Deconvolution2D l = new Deconvolution2D(this); Deconvolution2D l = new Deconvolution2D(this);
if( l.getConvolutionMode() == ConvolutionMode.Same
&& IntStream.of(l.getPadding()).sum() != 0) {
log.warn("Invalid input for layer '{}'. "
+ "You cannot have a padding of {} when Convolution Mode is set to 'Same'."
+ " Padding will be ignored."
, l.getName(), l.getPadding());
}
/* strides * (input_size-1) + kernel_size - 2*padding */
//TODO: This is wrong, also depends on convolutionMode, etc ...
/*l.nOut = l.getStride()[0] * (l.getNIn()-1)
+ IntStream.of(l.getKernelSize()).reduce(1, (a,b) -> a*b)
- 2L * IntStream.of(l.getPadding()).sum();
*/
//l.nOut =264;
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }

View File

@ -62,6 +62,7 @@ public abstract class LayerConfiguration
implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration
@Getter @Setter protected String name; @Getter @Setter protected String name;
@Getter @Setter private int index;
@Getter @Setter protected List<LayerConstraint> allParamConstraints; @Getter @Setter protected List<LayerConstraint> allParamConstraints;
@Getter @Setter protected List<LayerConstraint> weightConstraints; @Getter @Setter protected List<LayerConstraint> weightConstraints;
@Getter @Setter protected List<LayerConstraint> biasConstraints; @Getter @Setter protected List<LayerConstraint> biasConstraints;
@ -72,6 +73,7 @@ public abstract class LayerConfiguration
/** The type of the layer, basically defines the base class and its properties */ /** The type of the layer, basically defines the base class and its properties */
@Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN; @Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
/** /**
* Number of parameters this layer has a result of its configuration * Number of parameters this layer has a result of its configuration
* @return number or parameters * @return number or parameters
@ -80,7 +82,6 @@ public abstract class LayerConfiguration
return initializer().numParams(this); return initializer().numParams(this);
} }
/** /**
* A reference to the neural net configuration. This field is excluded from json serialization as * A reference to the neural net configuration. This field is excluded from json serialization as
* well as from equals check to avoid circular referenced. * well as from equals check to avoid circular referenced.

View File

@ -37,122 +37,166 @@ import org.nd4j.linalg.api.shape.Shape;
@Data @Data
@EqualsAndHashCode(exclude = {"shape"}) @EqualsAndHashCode(exclude = {"shape"})
public class FeedForwardToCnnPreProcessor implements InputPreProcessor { public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
private long inputHeight; private long inputHeight;
private long inputWidth; private long inputWidth;
private long numChannels; private long numChannels;
@Getter(AccessLevel.NONE) @Getter(AccessLevel.NONE)
@Setter(AccessLevel.NONE) @Setter(AccessLevel.NONE)
private long[] shape; private long[] shape;
/** /**
* Reshape to a channels x rows x columns tensor * Reshape to a channels x rows x columns tensor
* *
* @param inputHeight the columns * @param inputHeight the columns
* @param inputWidth the rows * @param inputWidth the rows
* @param numChannels the channels * @param numChannels the channels
*/ */
@JsonCreator @JsonCreator
public FeedForwardToCnnPreProcessor(@JsonProperty("inputHeight") long inputHeight, public FeedForwardToCnnPreProcessor(
@JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) { @JsonProperty("inputHeight") long inputHeight,
this.inputHeight = inputHeight; @JsonProperty("inputWidth") long inputWidth,
this.inputWidth = inputWidth; @JsonProperty("numChannels") long numChannels) {
this.numChannels = numChannels; this.inputHeight = inputHeight;
this.inputWidth = inputWidth;
this.numChannels = numChannels;
}
/**
* Reshape to a channels x rows x columns tensor
*
* @param inputHeight the columns
* @param inputWidth the rows
*/
public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) {
this.inputHeight = inputHeight;
this.inputWidth = inputWidth;
this.numChannels = 1;
}
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
this.shape = input.shape();
if (input.rank() == 4) return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
if (input.columns() != inputWidth * inputHeight * numChannels)
throw new IllegalArgumentException(
"Invalid input: expect output columns must be equal to rows "
+ inputHeight
+ " x columns "
+ inputWidth
+ " x channels "
+ numChannels
+ " but was instead "
+ Arrays.toString(input.shape()));
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
return workspaceMgr.leverageTo(
ArrayType.ACTIVATIONS,
input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth));
}
@Override
// return 4 dimensions
public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilons))
epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c');
if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) {
if (epsilons.rank() == 2)
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons); // should never happen
return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
} }
public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) { return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape));
this.inputHeight = inputHeight; }
this.inputWidth = inputWidth;
this.numChannels = 1; @Override
public FeedForwardToCnnPreProcessor clone() {
try {
FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
if (clone.shape != null) clone.shape = clone.shape.clone();
return clone;
} catch (CloneNotSupportedException e) {
throw new RuntimeException(e);
} }
}
@Override @Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { public InputType getOutputType(InputType inputType) {
this.shape = input.shape();
if (input.rank() == 4)
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
if (input.columns() != inputWidth * inputHeight * numChannels) switch (inputType.getType()) {
throw new IllegalArgumentException("Invalid input: expect output columns must be equal to rows " case FF:
+ inputHeight + " x columns " + inputWidth + " x channels " + numChannels InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
+ " but was instead " + Arrays.toString(input.shape())); val expSize = inputHeight * inputWidth * numChannels;
if (c.getSize() != expSize) {
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) throw new IllegalStateException(
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); "Invalid input: expected FeedForward input of size "
+ expSize
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, + " = (d="
input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth)); + numChannels
} + " * w="
+ inputWidth
@Override + " * h="
// return 4 dimensions + inputHeight
public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { + "), got "
if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilons)) + inputType);
epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c');
if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) {
if (epsilons.rank() == 2)
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons); //should never happen
return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
} }
return InputType.convolutional(inputHeight, inputWidth, numChannels);
case CNN:
InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape)); if (c2.getChannels() != numChannels
} || c2.getHeight() != inputHeight
|| c2.getWidth() != inputWidth) {
throw new IllegalStateException(
@Override "Invalid input: Got CNN input type with (d,w,h)=("
public FeedForwardToCnnPreProcessor clone() { + c2.getChannels()
try { + ","
FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone(); + c2.getWidth()
if (clone.shape != null) + ","
clone.shape = clone.shape.clone(); + c2.getHeight()
return clone; + ") but expected ("
} catch (CloneNotSupportedException e) { + numChannels
throw new RuntimeException(e); + ","
+ inputHeight
+ ","
+ inputWidth
+ ")");
} }
} return c2;
case CNNFlat:
@Override InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
public InputType getOutputType(InputType inputType) { if (c3.getDepth() != numChannels
|| c3.getHeight() != inputHeight
switch (inputType.getType()) { || c3.getWidth() != inputWidth) {
case FF: throw new IllegalStateException(
InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType; "Invalid input: Got CNN input type with (d,w,h)=("
val expSize = inputHeight * inputWidth * numChannels; + c3.getDepth()
if (c.getSize() != expSize) { + ","
throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize + c3.getWidth()
+ " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got " + ","
+ inputType); + c3.getHeight()
} + ") but expected ("
return InputType.convolutional(inputHeight, inputWidth, numChannels); + numChannels
case CNN: + ","
InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType; + inputHeight
+ ","
if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) { + inputWidth
throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getChannels() + ")");
+ "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels
+ "," + inputHeight + "," + inputWidth + ")");
}
return c2;
case CNNFlat:
InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c3.getDepth()
+ "," + c3.getWidth() + "," + c3.getHeight() + ") but expected (" + numChannels
+ "," + inputHeight + "," + inputWidth + ")");
}
return c3.getUnflattenedType();
default:
throw new IllegalStateException("Invalid input type: got " + inputType);
} }
return c3.getUnflattenedType();
default:
throw new IllegalStateException("Invalid input type: got " + inputType);
} }
}
@Override @Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, public Pair<INDArray, MaskState> feedForwardMaskArray(
int minibatchSize) { INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
//Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example) // Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example)
return new Pair<>(maskArray, currentMaskState); return new Pair<>(maskArray, currentMaskState);
} }
} }

View File

@ -369,7 +369,7 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl
protected String layerId() { protected String layerId() {
String name = this.layerConfiguration.getName(); String name = this.layerConfiguration.getName();
return "(layer name: " return "(network: " + getNetConfiguration().getName() + " layer name: "
+ (name == null ? "\"\"" : name) + (name == null ? "\"\"" : name)
+ ", layer index: " + ", layer index: "
+ index + index

View File

@ -101,8 +101,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
int[] args = new int[] { int[] args = new int[] {
(int)kH, (int)kW, strides[0], strides[1], (int)kH, (int)kW, strides[0], strides[1],
pad[0], pad[1], dilation[0], dilation[1], sameMode, pad[0], pad[1], dilation[0], dilation[1], sameMode //,
nchw ? 0 : 1 //0 = NCHW; 1 = NHWC //nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
}; };
INDArray delta; INDArray delta;
@ -224,8 +224,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
int[] args = new int[] { int[] args = new int[] {
kH, kW, strides[0], strides[1], kH, kW, strides[0], strides[1],
pad[0], pad[1], dilation[0], dilation[1], sameMode, pad[0], pad[1], dilation[0], dilation[1], sameMode //,
nchw ? 0 : 1 //0 = NCHW; 1 = NHWC //nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
}; };
//DL4J Deconv weights: [inputDepth, outputDepth, kH, kW] //DL4J Deconv weights: [inputDepth, outputDepth, kH, kW]
@ -238,6 +238,20 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
} else { } else {
opInputs = new INDArray[]{input, weights}; opInputs = new INDArray[]{input, weights};
} }
/**
* 2D deconvolution implementation
*
* IntArgs:
* 0: kernel height
* 1: kernel width
* 2: stride height
* 3: stride width
* 4: padding height
* 5: padding width
* 6: dilation height
* 7: dilation width
* 8: same mode: 0 false, 1 true
*/
CustomOp op = DynamicCustomOp.builder("deconv2d") CustomOp op = DynamicCustomOp.builder("deconv2d")
.addInputs(opInputs) .addInputs(opInputs)
.addIntegerArguments(args) .addIntegerArguments(args)

View File

@ -773,7 +773,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork
LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i); LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i);
layers[i] = layers[i] =
lc.instantiate( lc.instantiate(
lc.getNetConfiguration(), this.getNetConfiguration(),
trainingListeners, trainingListeners,
i, i,
paramsView, paramsView,

View File

@ -101,8 +101,10 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer
params.put(GAMMA, createGamma(conf, gammaView, initializeParams)); params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(GAMMA); conf.getNetConfiguration().addNetWideVariable(GAMMA);
conf.addVariable(GAMMA);
params.put(BETA, createBeta(conf, betaView, initializeParams)); params.put(BETA, createBeta(conf, betaView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(BETA); conf.getNetConfiguration().addNetWideVariable(BETA);
conf.addVariable(BETA);
meanOffset = 2 * nOut; meanOffset = 2 * nOut;
} }
@ -125,12 +127,15 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer
params.put(GLOBAL_MEAN, globalMeanView); params.put(GLOBAL_MEAN, globalMeanView);
conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN); conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN);
conf.addVariable(GLOBAL_MEAN);
if(layer.isUseLogStd()){ if(layer.isUseLogStd()){
params.put(GLOBAL_LOG_STD, globalVarView); params.put(GLOBAL_LOG_STD, globalVarView);
conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD); conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD);
conf.addVariable(GLOBAL_LOG_STD);
} else { } else {
params.put(GLOBAL_VAR, globalVarView); params.put(GLOBAL_VAR, globalVarView);
conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR); conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR);
conf.addVariable(GLOBAL_VAR);
} }
return params; return params;

View File

@ -114,11 +114,13 @@ public class ConvolutionParamInitializer extends AbstractParamInitializer {
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY); conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
conf.getNetConfiguration().addNetWideVariable(BIAS_KEY); conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
conf.getNetConfiguration().addNetWideVariable(BIAS_KEY); conf.addVariable(WEIGHT_KEY);
conf.addVariable(BIAS_KEY);
} else { } else {
INDArray weightView = paramsView; INDArray weightView = paramsView;
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY); conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
conf.addVariable(WEIGHT_KEY);
} }
return params; return params;

View File

@ -34,7 +34,7 @@ systemProp.org.gradle.internal.publish.checksums.insecure=true
#for whatever reason we had to add MaxMetaspace and file encoding = utf8, gradle crashed otherwise #for whatever reason we had to add MaxMetaspace and file encoding = utf8, gradle crashed otherwise
org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=768m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 -XX:ErrorFile=/var/log/java/hs_err_pid%p.log org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=768m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 -XX:ErrorFile=/var/log/java/hs_err_pid%p.log
#-DsocksProxyHost=sshtunnel -DsocksProxyPort=8888 -Djava.net.preferIPv4Stack=true
# When configured, Gradle will run in incubating parallel mode. # When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit # This option should only be used with decoupled projects. More details, visit

Binary file not shown.

170
gradlew vendored
View File

@ -69,18 +69,18 @@ app_path=$0
# Need this for daisy-chained symlinks. # Need this for daisy-chained symlinks.
while while
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
[ -h "$app_path" ] [ -h "$app_path" ]
do do
ls=$(ls -ld "$app_path") ls=$( ls -ld "$app_path" )
link=${ls#*' -> '} link=${ls#*' -> '}
case $link in #( case $link in #(
/*) app_path=$link ;; #( /*) app_path=$link ;; #(
*) app_path=$APP_HOME$link ;; *) app_path=$APP_HOME$link ;;
esac esac
done done
APP_HOME=$(cd "${APP_HOME:-./}" && pwd -P) || exit APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
APP_NAME="Gradle" APP_NAME="Gradle"
APP_BASE_NAME=${0##*/} APP_BASE_NAME=${0##*/}
@ -91,15 +91,15 @@ DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value. # Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum MAX_FD=maximum
warn() { warn () {
echo "$*" echo "$*"
} >&2 } >&2
die() { die () {
echo echo
echo "$*" echo "$*"
echo echo
exit 1 exit 1
} >&2 } >&2
# OS specific support (must be 'true' or 'false'). # OS specific support (must be 'true' or 'false').
@ -107,52 +107,51 @@ cygwin=false
msys=false msys=false
darwin=false darwin=false
nonstop=false nonstop=false
case "$(uname)" in #( case "$( uname )" in #(
CYGWIN*) cygwin=true ;; #( CYGWIN* ) cygwin=true ;; #(
Darwin*) darwin=true ;; #( Darwin* ) darwin=true ;; #(
MSYS* | MINGW*) msys=true ;; #( MSYS* | MINGW* ) msys=true ;; #(
NONSTOP*) nonstop=true ;; NONSTOP* ) nonstop=true ;;
esac esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM. # Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ]; then if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ]; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables # IBM's JDK on AIX uses strange locations for the executables
JAVACMD=$JAVA_HOME/jre/sh/java JAVACMD=$JAVA_HOME/jre/sh/java
else else
JAVACMD=$JAVA_HOME/bin/java JAVACMD=$JAVA_HOME/bin/java
fi fi
if [ ! -x "$JAVACMD" ]; then if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the Please set the JAVA_HOME variable in your environment to match the
location of your Java installation." location of your Java installation."
fi fi
else else
JAVACMD=java JAVACMD=java
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the Please set the JAVA_HOME variable in your environment to match the
location of your Java installation." location of your Java installation."
fi fi
# Increase the maximum file descriptors if we can. # Increase the maximum file descriptors if we can.
if ! "$cygwin" && ! "$darwin" && ! "$nonstop"; then if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #( case $MAX_FD in #(
max*) max*)
MAX_FD=$(ulimit -H -n) || MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit" warn "Could not query maximum file descriptor limit"
;; esac
esac case $MAX_FD in #(
case $MAX_FD in #( '' | soft) :;; #(
'' | soft) : ;; #( *)
*) ulimit -n "$MAX_FD" ||
ulimit -n "$MAX_FD" || warn "Could not set maximum file descriptor limit to $MAX_FD"
warn "Could not set maximum file descriptor limit to $MAX_FD" esac
;;
esac
fi fi
# Collect all arguments for the java command, stacking in reverse order: # Collect all arguments for the java command, stacking in reverse order:
@ -164,36 +163,34 @@ fi
# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. # * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
# For Cygwin or MSYS, switch paths to Windows format before running java # For Cygwin or MSYS, switch paths to Windows format before running java
if "$cygwin" || "$msys"; then if "$cygwin" || "$msys" ; then
APP_HOME=$(cygpath --path --mixed "$APP_HOME") APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
CLASSPATH=$(cygpath --path --mixed "$CLASSPATH") CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
JAVACMD=$(cygpath --unix "$JAVACMD") JAVACMD=$( cygpath --unix "$JAVACMD" )
# Now convert the arguments - kludge to limit ourselves to /bin/sh # Now convert the arguments - kludge to limit ourselves to /bin/sh
for arg; do for arg do
if if
case $arg in #( case $arg in #(
-*) false ;; # don't mess with options #( -*) false ;; # don't mess with options #(
/?*) /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath [ -e "$t" ] ;; #(
[ -e "$t" ] *) false ;;
;; #( esac
*) false ;; then
esac arg=$( cygpath --path --ignore --mixed "$arg" )
then fi
arg=$(cygpath --path --ignore --mixed "$arg") # Roll the args list around exactly as many times as the number of
fi # args, so each arg winds up back in the position where it started, but
# Roll the args list around exactly as many times as the number of # possibly modified.
# args, so each arg winds up back in the position where it started, but #
# possibly modified. # NB: a `for` loop captures its iteration list before it begins, so
# # changing the positional parameters here affects neither the number of
# NB: a `for` loop captures its iteration list before it begins, so # iterations, nor the values presented in `arg`.
# changing the positional parameters here affects neither the number of shift # remove old arg
# iterations, nor the values presented in `arg`. set -- "$@" "$arg" # push replacement arg
shift # remove old arg done
set -- "$@" "$arg" # push replacement arg
done
fi fi
# Collect all arguments for the java command; # Collect all arguments for the java command;
@ -203,15 +200,10 @@ fi
# * put everything else in single quotes, so that it's not re-expanded. # * put everything else in single quotes, so that it's not re-expanded.
set -- \ set -- \
"-Dorg.gradle.appname=$APP_BASE_NAME" \ "-Dorg.gradle.appname=$APP_BASE_NAME" \
-classpath "$CLASSPATH" \ -classpath "$CLASSPATH" \
org.gradle.wrapper.GradleWrapperMain \ org.gradle.wrapper.GradleWrapperMain \
"$@" "$@"
# Stop when "xargs" is not available.
if ! command -v xargs >/dev/null 2>&1; then
die "xargs is not available"
fi
# Use "xargs" to parse quoted args. # Use "xargs" to parse quoted args.
# #
@ -233,10 +225,10 @@ fi
# #
eval "set -- $( eval "set -- $(
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
xargs -n1 | xargs -n1 |
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
tr '\n' ' ' tr '\n' ' '
)" '"$@"' )" '"$@"'
exec "$JAVACMD" "$@" exec "$JAVACMD" "$@"

14
gradlew.bat vendored
View File

@ -14,7 +14,7 @@
@rem limitations under the License. @rem limitations under the License.
@rem @rem
@if "%DEBUG%"=="" @echo off @if "%DEBUG%" == "" @echo off
@rem ########################################################################## @rem ##########################################################################
@rem @rem
@rem Gradle startup script for Windows @rem Gradle startup script for Windows
@ -25,7 +25,7 @@
if "%OS%"=="Windows_NT" setlocal if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0 set DIRNAME=%~dp0
if "%DIRNAME%"=="" set DIRNAME=. if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0 set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME% set APP_HOME=%DIRNAME%
@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1 %JAVA_EXE% -version >NUL 2>&1
if %ERRORLEVEL% equ 0 goto execute if "%ERRORLEVEL%" == "0" goto execute
echo. echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
@ -75,15 +75,13 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
:end :end
@rem End local scope for the variables with windows NT shell @rem End local scope for the variables with windows NT shell
if %ERRORLEVEL% equ 0 goto mainEnd if "%ERRORLEVEL%"=="0" goto mainEnd
:fail :fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code! rem the _cmd.exe /c_ return code!
set EXIT_CODE=%ERRORLEVEL% if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
if %EXIT_CODE% equ 0 set EXIT_CODE=1 exit /b 1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%
:mainEnd :mainEnd
if "%OS%"=="Windows_NT" endlocal if "%OS%"=="Windows_NT" endlocal