gan example

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-08-07 10:32:39 +02:00
parent 3ea555b645
commit 1c3496ad84
20 changed files with 1267 additions and 838 deletions

View File

@ -1,473 +1,261 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import static net.brutex.ai.dnn.api.NN.dense;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import java.util.UUID;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.WindowConstants;
import lombok.extern.slf4j.Slf4j;
import javax.swing.*;
import org.apache.commons.lang3.ArrayUtils;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.datavec.image.transform.ShowImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
@Slf4j
public class App {
private static final double LEARNING_RATE = 0.000002;
private static final double GRADIENT_THRESHOLD = 100.0;
private static final double LEARNING_RATE = 0.002;
private static final double GRADIENT_THRESHOLD = 100.0;
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
private static final int BATCHSIZE = 128;
private static JFrame frame;
private static JPanel panel;
private static final int X_DIM = 20 ;
private static final int Y_DIM = 20;
private static final int CHANNELS = 1;
private static final int batchSize = 1;
private static final int INPUT = 10;
private static LayerConfiguration[] genLayers() {
return new LayerConfiguration[] {
dense().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(256).nOut(512).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(512).nOut(1024).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(1024).nOut(784).activation(Activation.TANH).build()
};
}
private static final int OUTPUT_PER_PANEL = 16;
/**
* Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
*
* @return config
*/
private static NeuralNetConfiguration generator() {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.layersFromArray(genLayers())
.name("generator")
.build();
private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS;
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
return conf;
}
private static JFrame frame;
private static JFrame frame2;
private static JPanel panel;
private static JPanel panel2;
private static LayerConfiguration[] disLayers() {
return new LayerConfiguration[]{
dense().nIn(784).nOut(1024).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
dense().nIn(1024).nOut(512).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
dense().nIn(512).nOut(256).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
};
}
private static final String OUTPUT_DIR = "C:/temp/output/";
private static NeuralNetConfiguration discriminator() {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.layersFromArray(disLayers())
.name("discriminator")
.build();
private static LayerConfiguration[] genLayers() {
return new LayerConfiguration[] {
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
ActivationLayer.builder(Activation.LEAKYRELU).build(),
return conf;
}
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
private static NeuralNetConfiguration gan() {
LayerConfiguration[] genLayers = genLayers();
LayerConfiguration[] disLayers = discriminator().getFlattenedLayerConfigurations().stream()
.map((layer) -> {
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
return FrozenLayerWithBackprop.builder(layer).build();
} else {
return layer;
}
}).toArray(LayerConfiguration[]::new);
LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers);
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build()
};
}
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.layersFromArray(layers)
.name("GAN")
.build();
/**
* Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
*
* @return config
*/
private static NeuralNetConfiguration generator() {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
//.weightInit(WeightInit.XAVIER)
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.layersFromArray(genLayers())
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
// .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
.build();
((NeuralNetConfiguration) conf).init();
return conf;
}
return conf;
}
@Test
public void runTest() throws Exception {
App.main(null);
}
public static void main(String... args) throws Exception {
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
private static LayerConfiguration[] disLayers() {
return new LayerConfiguration[]{
DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
MnistDataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
OutputLayer.builder().name("dis-output").lossFunction(LossFunction.MCXENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
};
}
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
gen.init();
dis.init();
gan.init();
private static NeuralNetConfiguration discriminator() {
copyParams(gen, dis, gan);
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
.weightInit(WeightInit.XAVIER)
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightNoise(null)
// .weightInitFn(new WeightInitXavier())
// .activationFn(new ActivationIdentity())
.activation(Activation.IDENTITY)
.layersFromArray(disLayers())
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
.build();
((NeuralNetConfiguration) conf).init();
gen.addTrainingListeners(new PerformanceListener(10, true));
dis.addTrainingListeners(new PerformanceListener(10, true));
gan.addTrainingListeners(new PerformanceListener(10, true));
return conf;
}
trainData.reset();
private static NeuralNetConfiguration gan() {
LayerConfiguration[] genLayers = genLayers();
LayerConfiguration[] disLayers = Arrays.stream(disLayers())
.map((layer) -> {
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
} else {
return layer;
}
}).toArray(LayerConfiguration[]::new);
LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers);
int j = 0;
for (int i = 0; i < 50; i++) {
while (trainData.hasNext()) {
j++;
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.seed(42)
.updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() )
.gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold( 100 )
//.weightInitFn( new WeightInitXavier() ) //this is internal
.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightInit( WeightInit.XAVIER)
//.activationFn( new ActivationIdentity()) //this is internal
.activation( Activation.IDENTITY )
.layersFromArray( layers )
.inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
.dataType(DataType.FLOAT)
.build();
((NeuralNetConfiguration) conf).init();
return conf;
}
// generate data
INDArray real = trainData.next().getFeatures().muli(2).subi(1);
int batchSize = (int) real.shape()[0];
INDArray fakeIn = Nd4j.rand(batchSize, 100);
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
dis.fit(data);
dis.fit(data);
// Update the discriminator in the GAN network
updateGan(gen, dis, gan);
gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
@Test
public void runTest() throws Exception {
if(! log.isDebugEnabled()) {
log.info("Logging is not set to DEBUG");
}
else {
log.info("Logging is set to DEBUG");
}
main();
}
if (j % 10 == 1) {
System.out.println("Epoch " + i +" Iteration " + j + " Visualizing...");
INDArray[] samples = new INDArray[9];
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
public static void main(String... args) throws Exception {
for (int k = 0; k < 9; k++) {
INDArray input = fakeSet2.get(k).getFeatures();
//samples[k] = gen.output(input, false);
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m ");
Nd4j.getMemoryManager().setAutoGcWindow(500);
//MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
//FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS());
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);
ImageTransform tr = new PipelineImageTransform.Builder()
//.addImageTransform(transform) //convert to GREY SCALE
.addImageTransform(transform3)
//.addImageTransform(transform2)
.build();
ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
imageRecordReader.initialize(fileSplit, tr);
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
gen.init(); log.debug("Generator network: {}", gen);
dis.init(); log.debug("Discriminator network: {}", dis);
gan.init(); log.info("Complete GAN network: {}", gan);
copyParams(gen, dis, gan);
//gen.addTrainingListeners(new PerformanceListener(15, true, "GEN"));
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
//gan.addTrainingListeners(new ScoreToChartListener("gan"));
//dis.setListeners(new ScoreToChartListener("dis"));
//System.out.println(gan.toString());
//gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
//gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
//trainData.reset();
int j = 0;
for (int i = 0; i < 51; i++) { //epoch
while (trainData.hasNext()) {
j++;
DataSet next = trainData.next();
// generate data
INDArray real = next.getFeatures();//.div(255f);
//start next round if there are not enough images left to have a full batchsize dataset
if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
log.warn("Your total number of input images is not a multiple of {}, "
+ "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
break;
}
visualize(samples);
}
}
trainData.reset();
// Copy the GANs generator to gen.
//updateGen(gen, gan);
}
//if(i%20 == 0) {
frame2 = visualize(new INDArray[]{real}, batchSize,
frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
//}
real.divi(255f);
// int batchSize = (int) real.shape()[0];
INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
//log.info("real has {} items.", real.length());
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
dis.fit(data);
//dis.fit(data);
// Update the discriminator in the GAN network
updateGan(gen, dis, gan);
//gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1)));
//Visualize and reporting
if (j % 10 == 1) {
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
for (int k = 0; k < samples.length; k++) {
//INDArray input = fakeSet2.get(k).getFeatures();
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
INDArray input = fakeSet2.get(k).getFeatures();
input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here
//samples[k] = gen.output(input, false);
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
//samples[k] =
samples[k].addi(1f).divi(2f).muli(255f);
}
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
}
}
if (trainData.resetSupported()) {
trainData.reset();
} else {
log.error("Trainingdata {} does not support reset.", trainData.toString());
}
// Copy the GANs generator to gen.
updateGen(gen, gan);
gen.save(new File("mnist-mlp-generator.dlj"));
}
}
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length;
for (int i = 0; i < gan.getLayers().length; i++) {
if (i < genLayerCount) {
if(gan.getLayer(i).getParams() != null)
gan.getLayer(i).setParams(gen.getLayer(i).getParams());
} else {
if(gan.getLayer(i).getParams() != null)
gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams());
}
}
}
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
for (int i = 0; i < gen.getLayers().length; i++) {
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
}
}
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length;
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
}
}
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
if (isOrig) {
frame.setTitle("Viz Original");
} else {
frame.setTitle("Generated");
}
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout());
JPanel panelx = new JPanel();
panelx.setLayout(new GridLayout(4, 4, 8, 8));
for (INDArray sample : samples) {
for(int i = 0; i<batchElements; i++) {
panelx.add(getImage(sample, i, isOrig));
}
}
frame.add(panelx, BorderLayout.CENTER);
frame.setVisible(true);
frame.revalidate();
frame.setMinimumSize(new Dimension(300, 20));
frame.pack();
return frame;
}
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
final BufferedImage bi;
if(CHANNELS>1) {
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
} else {
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
}
final int imageSize = X_DIM * Y_DIM;
final int offset = batchElement * imageSize;
int pxl = offset * CHANNELS; //where to start in the INDArray
//Image in NCHW - channels first format
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
for (int y = 0; y < Y_DIM; y++) { // step through the columns x
for (int x = 0; x < X_DIM; x++) { //step through the rows y
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
pxl++; //next item in INDArray
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length;
for (int i = 0; i < gan.getLayers().length; i++) {
if (i < genLayerCount) {
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
} else {
dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
}
}
}
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
ImageIcon scaled = new ImageIcon(imageScaled);
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
return new JLabel(scaled);
}
private static void saveImage(Image image, int batchElement, boolean isOrig) {
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
try {
// Save the images to disk
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
log.debug("Images saved successfully.");
} catch (IOException e) {
log.error("Error saving the images: {}", e.getMessage());
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
for (int i = 0; i < gen.getLayers().length; i++) {
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
}
}
}
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
File directory = new File(outputDirectory);
if (!directory.exists()) {
directory.mkdir();
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length;
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
}
}
private static void visualize(INDArray[] samples) {
if (frame == null) {
frame = new JFrame();
frame.setTitle("Viz");
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout());
panel = new JPanel();
panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
frame.add(panel, BorderLayout.CENTER);
frame.setVisible(true);
}
File outputFile = new File(directory, fileName);
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
}
panel.removeAll();
public static BufferedImage imageToBufferedImage(Image image) {
if (image instanceof BufferedImage) {
return (BufferedImage) image;
for (INDArray sample : samples) {
panel.add(getImage(sample));
}
// Create a buffered image with the same dimensions and transparency as the original image
BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
// Draw the original image onto the buffered image
Graphics2D g2d = bufferedImage.createGraphics();
g2d.drawImage(image, 0, 0, null);
g2d.dispose();
return bufferedImage;
frame.revalidate();
frame.pack();
}
private static JLabel getImage(INDArray tensor) {
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
for (int i = 0; i < 784; i++) {
int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
ImageIcon scaled = new ImageIcon(imageScaled);
return new JLabel(scaled);
}
}

View File

@ -0,0 +1,343 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.List;
import javax.imageio.ImageIO;
import javax.swing.*;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.*;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
@Slf4j
public class App2 {
final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
static final float COLORSPACE = 255f;
static final int DIMENSIONS = 28;
static final int CHANNELS = 1;
final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
final int OUTPUT_PER_PANEL = 10;
final boolean BIAS = true;
static final int BATCHSIZE=128;
private JFrame frame2, frame;
static final String OUTPUT_DIR = "d:/out/";
final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);
@Test
void runTest() throws IOException {
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
ImageTransform tr = new PipelineImageTransform.Builder()
.addImageTransform(transform) //convert to GREY SCALE
.addImageTransform(transform3)
//.addImageTransform(transform2)
.build();
ImageRecordReader imageRecordReader = new ImageRecordReader(DIMENSIONS, DIMENSIONS, CHANNELS);
imageRecordReader.initialize(fileSplit, tr);
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, BATCHSIZE );
trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());
LayerConfiguration[] disLayers = App2Config.discriminator().getFlattenedLayerConfigurations().stream()
.map((layer) -> {
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
} else {
return layer;
}
}).toArray(LayerConfiguration[]::new);
NeuralNetConfiguration netConfiguration =
NeuralNetConfiguration.builder()
.name("GAN")
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.updater(App2Config.UPDATER)
.innerConfigurations(new ArrayList<>(List.of(App2Config.generator())))
.layersFromList(new ArrayList<>(Arrays.asList(disLayers)))
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
// .inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
// .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(2, new CnnToFeedForwardPreProcessor())
//.dataType(DataType.FLOAT)
.build();
MultiLayerNetwork gan = new MultiLayerNetwork(netConfiguration );
dis.init(); log.debug("Discriminator network: {}", dis);
gen.init(); log.debug("Generator network: {}", gen);
gan.init(); log.debug("GAN network: {}", gan);
log.info("Generator Summary:\n{}", gen.summary());
log.info("GAN Summary:\n{}", gan.summary());
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
int j = 0;
for (int i = 0; i < 51; i++) { //epoch
while (trainData.hasNext()) {
j++;
DataSet next = trainData.next();
// generate data
INDArray real = next.getFeatures(); //.muli(2).subi(1);;//.div(255f);
//start next round if there are not enough images left to have a full batchsize dataset
if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
log.warn("Your total number of input images is not a multiple of {}, "
+ "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
break;
}
//if(i%20 == 0) {
// frame2 = visualize(new INDArray[]{real}, BATCHSIZE,
// frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
//}
//real.divi(255f);
// int batchSize = (int) real.shape()[0];
//INDArray fakeIn = Nd4j.rand(BATCHSIZE, CHANNELS, DIMENSIONS, DIMENSIONS);
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
// when generator has TANH as activation - value range is -1 to 1
// when generator has SIGMOID, then range is 0 to 1
fake.addi(1f).divi(2f);
DataSet realSet = new DataSet(real, label_real);
DataSet fakeSet = new DataSet(fake, label_fake);
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
dis.fit(data);
dis.fit(data);
// Update the discriminator in the GAN network
updateGan(gen, dis, gan);
gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_fake));
//Visualize and reporting
if (j % 10 == 1) {
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
INDArray[] samples = BATCHSIZE > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[BATCHSIZE];
for (int k = 0; k < samples.length; k++) {
DataSet fakeSet2 = new DataSet(fakeIn, label_fake);
INDArray input = fakeSet2.get(k).getFeatures();
//input = input.reshape(1,CHANNELS, DIMENSIONS, DIMENSIONS); //batch size will be 1 here for images
input = input.reshape(1, App2Config.INPUT);
//samples[k] = gen.output(input, false);
samples[k] = gen.activateSelectedLayers(0, gen.getLayers().length - 1, input);
samples[k] = samples[k].reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS);
//samples[k] =
//samples[k].muli(255f);
}
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
}
}
if (trainData.resetSupported()) {
trainData.reset();
} else {
log.error("Trainingdata {} does not support reset.", trainData.toString());
}
// Copy the GANs generator to gen.
updateGen(gen, gan);
log.info("Updated GAN's generator from gen.");
gen.save(new File("mnist-mlp-generator.dlj"));
}
}
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
if (isOrig) {
frame.setTitle("Viz Original");
} else {
frame.setTitle("Generated");
}
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout());
JPanel panelx = new JPanel();
panelx.setLayout(new GridLayout(4, 4, 8, 8));
for (INDArray sample : samples) {
for(int i = 0; i<batchElements; i++) {
panelx.add(getImage(sample, i, isOrig));
}
}
frame.add(panelx, BorderLayout.CENTER);
frame.setVisible(true);
frame.revalidate();
frame.setMinimumSize(new Dimension(300, 20));
frame.pack();
return frame;
}
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
final BufferedImage bi;
if(CHANNELS >1) {
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
} else {
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
}
final int imageSize = DIMENSIONS * DIMENSIONS;
final int offset = batchElement * imageSize;
int pxl = offset * CHANNELS; //where to start in the INDArray
//Image in NCHW - channels first format
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x
for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y
float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl);
bi.getRaster().setSample(x, y, c, f_pxl);
pxl++; //next item in INDArray
}
}
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT);
ImageIcon scaled = new ImageIcon(imageScaled);
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
return new JLabel(scaled);
}
private static void saveImage(Image image, int batchElement, boolean isOrig) {
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
try {
// Save the images to disk
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
log.debug("Images saved successfully.");
} catch (IOException e) {
log.error("Error saving the images: {}", e.getMessage());
}
}
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
File directory = new File(outputDirectory);
if (!directory.exists()) {
directory.mkdir();
}
File outputFile = new File(directory, fileName);
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
}
public static BufferedImage imageToBufferedImage(Image image) {
if (image instanceof BufferedImage) {
return (BufferedImage) image;
}
// Create a buffered image with the same dimensions and transparency as the original image
BufferedImage bufferedImage;
if (CHANNELS > 1) {
bufferedImage =
new BufferedImage(
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
} else {
bufferedImage =
new BufferedImage(
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
}
// Draw the original image onto the buffered image
Graphics2D g2d = bufferedImage.createGraphics();
g2d.drawImage(image, 0, 0, null);
g2d.dispose();
return bufferedImage;
}
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
for (int i = 0; i < gen.getLayers().length; i++) {
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
}
}
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length;
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
}
}
}

View File

@ -0,0 +1,176 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import static net.brutex.ai.dnn.api.NN.*;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class App2Config {
public static final int INPUT = 100;
public static final int X_DIM = 28;
public static final int y_DIM = 28;
public static final int CHANNELS = 1;
public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
static LayerConfiguration[] genLayerConfig() {
return new LayerConfiguration[] {
/*
DenseLayer.builder().name("L-0").nIn(INPUT).nOut(INPUT + (INPUT / 2)).activation(Activation.RELU).build(),
ActivationLayer.builder().activation(Activation.RELU).build(), /*
Deconvolution2D.builder().name("L-Deconv-01").nIn(CHANNELS).nOut(CHANNELS)
.kernelSize(2,2)
.stride(1,1)
.padding(0,0)
.convolutionMode(ConvolutionMode.Truncate)
.activation(Activation.RELU)
.hasBias(BIAS).build(),
//BatchNormalization.builder().nOut(CHANNELS).build(),
Deconvolution2D.builder().name("L-Deconv-02").nIn(CHANNELS).nOut(CHANNELS)
.kernelSize(2,2)
.stride(2,2)
.padding(0,0)
.convolutionMode(ConvolutionMode.Truncate)
.activation(Activation.RELU)
.hasBias(BIAS).build(),
//BatchNormalization.builder().name("L-batch").nOut(CHANNELS).build(),
DenseLayer.builder().name("L-x").nIn(INPUT + (INPUT / 2)).nOut(2 * INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
DenseLayer.builder().name("L-x").nIn(2 * INPUT).nOut(3 * INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
DenseLayer.builder().name("L-x").nIn(3 * INPUT).nOut(2 * INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
// DropoutLayer.builder(0.001).build(),
DenseLayer.builder().nIn(2 * INPUT).nOut(INPUT).activation(Activation.TANH).build() */
dense().nIn(INPUT).nOut(256).weightInit(WeightInit.NORMAL).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(256).nOut(512).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(512).nOut(1024).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(1024).nOut(784).activation(Activation.TANH).build(),
};
}
static LayerConfiguration[] disLayerConfig() {
return new LayerConfiguration[] {/*
Convolution2D.builder().nIn(CHANNELS).kernelSize(2,2).padding(1,1).stride(1,1).nOut(CHANNELS)
.build(),
Convolution2D.builder().nIn(CHANNELS).kernelSize(3,3).padding(1,1).stride(2,2).nOut(CHANNELS)
.build(),
ActivationLayer.builder().activation(Activation.LEAKYRELU).build(),
BatchNormalization.builder().build(),
OutputLayer.builder().nOut(1).lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SIGMOID)
.build()
dense().name("L-dense").nIn(INPUT).nOut(INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).build(),
DropoutLayer.builder(0.5).build(),
DenseLayer.builder().nIn(INPUT).nOut(INPUT/2).build(),
ActivationLayer.builder().activation(Activation.RELU).build(),
DropoutLayer.builder(0.5).build(),
DenseLayer.builder().nIn(INPUT/2).nOut(INPUT/4).build(),
ActivationLayer.builder().activation(Activation.RELU).build(),
DropoutLayer.builder(0.5).build(),
OutputLayer.builder().nIn(INPUT/4).nOut(1).lossFunction(LossFunctions.LossFunction.XENT)
.activation(Activation.SIGMOID)
.build() */
dense().nIn(784).nOut(1024).hasBias(true).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
dense().nIn(1024).nOut(512).hasBias(true).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
dense().nIn(512).nOut(256).hasBias(true).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
};
}
static NeuralNetConfiguration generator() {
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.name("generator")
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.seed(42)
.updater(UPDATER)
.weightInit(WeightInit.XAVIER)
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightNoise(null)
// .weightInitFn(new WeightInitXavier())
// .activationFn(new ActivationIdentity())
.activation(Activation.IDENTITY)
.layersFromArray(App2Config.genLayerConfig())
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
//.inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
.build();
conf.init();
return conf;
}
static NeuralNetConfiguration discriminator() {
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.name("discriminator")
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.seed(42)
.updater(UPDATER)
.weightInit(WeightInit.XAVIER)
// .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightNoise(null)
// .weightInitFn(new WeightInitXavier())
// .activationFn(new ActivationIdentity())
.activation(Activation.IDENTITY)
.layersFromArray(disLayerConfig())
//.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
//.dataType(DataType.FLOAT)
.build();
conf.init();
return conf;
}
}

View File

@ -43,223 +43,255 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils.impo
@Slf4j
public class KerasSequentialModel extends KerasModel {
/**
* (Recommended) Builder-pattern constructor for Sequential model.
*
* @param modelBuilder builder object
* @throws IOException I/O exception
* @throws InvalidKerasConfigurationException Invalid Keras configuration
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
*/
public KerasSequentialModel(KerasModelBuilder modelBuilder)
throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(),
modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(),
modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape());
/**
* (Recommended) Builder-pattern constructor for Sequential model.
*
* @param modelBuilder builder object
* @throws IOException I/O exception
* @throws InvalidKerasConfigurationException Invalid Keras configuration
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
*/
public KerasSequentialModel(KerasModelBuilder modelBuilder)
throws UnsupportedKerasConfigurationException,
IOException,
InvalidKerasConfigurationException {
this(
modelBuilder.getModelJson(),
modelBuilder.getModelYaml(),
modelBuilder.getWeightsArchive(),
modelBuilder.getWeightsRoot(),
modelBuilder.getTrainingJson(),
modelBuilder.getTrainingArchive(),
modelBuilder.isEnforceTrainingConfig(),
modelBuilder.getInputShape());
}
/**
* (Not recommended) Constructor for Sequential model from model configuration (JSON or YAML),
* training configuration (JSON), weights, and "training mode" boolean indicator. When built in
* training mode, certain unsupported configurations (e.g., unknown regularizers) will throw
* Exceptions. When enforceTrainingConfig=false, these will generate warnings but will be
* otherwise ignored.
*
* @param modelJson model configuration JSON string
* @param modelYaml model configuration YAML string
* @param trainingJson training configuration JSON string
* @throws IOException I/O exception
*/
public KerasSequentialModel(
String modelJson,
String modelYaml,
Hdf5Archive weightsArchive,
String weightsRoot,
String trainingJson,
Hdf5Archive trainingArchive,
boolean enforceTrainingConfig,
int[] inputShape)
throws IOException,
InvalidKerasConfigurationException,
UnsupportedKerasConfigurationException {
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
this.enforceTrainingConfig = enforceTrainingConfig;
/* Determine model configuration type. */
if (!modelConfig.containsKey(config.getFieldClassName()))
throw new InvalidKerasConfigurationException(
"Could not determine Keras model class (no "
+ config.getFieldClassName()
+ " field found)");
this.className = (String) modelConfig.get(config.getFieldClassName());
if (!this.className.equals(config.getFieldClassNameSequential()))
throw new InvalidKerasConfigurationException(
"Model class name must be "
+ config.getFieldClassNameSequential()
+ " (found "
+ this.className
+ ")");
/* Process layer configurations. */
if (!modelConfig.containsKey(config.getModelFieldConfig()))
throw new InvalidKerasConfigurationException(
"Could not find layer configurations (no "
+ config.getModelFieldConfig()
+ " field found)");
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations.
// For consistency
// "config" is now an object containing a "name" and "layers", the latter contain the same data
// as before.
// This change only affects Sequential models.
List<Object> layerList;
try {
layerList = (List<Object>) modelConfig.get(config.getModelFieldConfig());
} catch (Exception e) {
HashMap layerMap = (HashMap<String, Object>) modelConfig.get(config.getModelFieldConfig());
layerList = (List<Object>) layerMap.get("layers");
}
/**
* (Not recommended) Constructor for Sequential model from model configuration
* (JSON or YAML), training configuration (JSON), weights, and "training mode"
* boolean indicator. When built in training mode, certain unsupported configurations
* (e.g., unknown regularizers) will throw Exceptions. When enforceTrainingConfig=false, these
* will generate warnings but will be otherwise ignored.
*
* @param modelJson model configuration JSON string
* @param modelYaml model configuration YAML string
* @param trainingJson training configuration JSON string
* @throws IOException I/O exception
*/
public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig,
int[] inputShape)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = prepareLayers(layerList);
this.layers = layerPair.getFirst();
this.layersOrdered = layerPair.getSecond();
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
this.enforceTrainingConfig = enforceTrainingConfig;
KerasLayer inputLayer;
if (this.layersOrdered.get(0) instanceof KerasInput) {
inputLayer = this.layersOrdered.get(0);
} else {
/* Add placeholder input layer and update lists of input and output layers. */
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
Preconditions.checkState(
ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!");
inputLayer = new KerasInput("input1", firstLayerInputShape);
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
this.layers.put(inputLayer.getName(), inputLayer);
this.layersOrdered.add(0, inputLayer);
}
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
this.outputLayerNames =
new ArrayList<>(
Collections.singletonList(
this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
/* Determine model configuration type. */
if (!modelConfig.containsKey(config.getFieldClassName()))
throw new InvalidKerasConfigurationException(
"Could not determine Keras model class (no " + config.getFieldClassName() + " field found)");
this.className = (String) modelConfig.get(config.getFieldClassName());
if (!this.className.equals(config.getFieldClassNameSequential()))
throw new InvalidKerasConfigurationException("Model class name must be " + config.getFieldClassNameSequential()
+ " (found " + this.className + ")");
/* Process layer configurations. */
if (!modelConfig.containsKey(config.getModelFieldConfig()))
throw new InvalidKerasConfigurationException(
"Could not find layer configurations (no " + config.getModelFieldConfig() + " field found)");
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations. For consistency
// "config" is now an object containing a "name" and "layers", the latter contain the same data as before.
// This change only affects Sequential models.
List<Object> layerList;
try {
layerList = (List<Object>) modelConfig.get(config.getModelFieldConfig());
} catch (Exception e) {
HashMap layerMap = (HashMap<String, Object>) modelConfig.get(config.getModelFieldConfig());
layerList = (List<Object>) layerMap.get("layers");
}
Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair =
prepareLayers(layerList);
this.layers = layerPair.getFirst();
this.layersOrdered = layerPair.getSecond();
KerasLayer inputLayer;
if (this.layersOrdered.get(0) instanceof KerasInput) {
inputLayer = this.layersOrdered.get(0);
} else {
/* Add placeholder input layer and update lists of input and output layers. */
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
Preconditions.checkState(ArrayUtil.prod(firstLayerInputShape) > 0,"Input shape must not be zero!");
inputLayer = new KerasInput("input1", firstLayerInputShape);
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
this.layers.put(inputLayer.getName(), inputLayer);
this.layersOrdered.add(0, inputLayer);
}
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
this.outputLayerNames = new ArrayList<>(
Collections.singletonList(this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
/* Update each layer's inbound layer list to include (only) previous layer. */
KerasLayer prevLayer = null;
for (KerasLayer layer : this.layersOrdered) {
if (prevLayer != null)
layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName()));
prevLayer = layer;
}
/* Import training configuration. */
if (enforceTrainingConfig) {
if (trainingJson != null)
importTrainingConfiguration(trainingJson);
else log.warn("If enforceTrainingConfig is true, a training " +
"configuration object has to be provided. Usually the only practical way to do this is to store" +
" your keras model with `model.save('model_path.h5'. If you store model config and weights" +
" separately no training configuration is attached.");
}
this.outputTypes = inferOutputTypes(inputShape);
if (weightsArchive != null)
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
/* Update each layer's inbound layer list to include (only) previous layer. */
KerasLayer prevLayer = null;
for (KerasLayer layer : this.layersOrdered) {
if (prevLayer != null)
layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName()));
prevLayer = layer;
}
/**
* Default constructor
*/
public KerasSequentialModel() {
super();
/* Import training configuration. */
if (enforceTrainingConfig) {
if (trainingJson != null) importTrainingConfiguration(trainingJson);
else
log.warn(
"If enforceTrainingConfig is true, a training "
+ "configuration object has to be provided. Usually the only practical way to do this is to store"
+ " your keras model with `model.save('model_path.h5'. If you store model config and weights"
+ " separately no training configuration is attached.");
}
/**
* Configure a NeuralNetConfiguration from this Keras Sequential model configuration.
*
* @return NeuralNetConfiguration
*/
public NeuralNetConfiguration getNeuralNetConfiguration()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
if (!this.className.equals(config.getFieldClassNameSequential()))
throw new InvalidKerasConfigurationException(
"Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
if (this.inputLayerNames.size() != 1)
throw new InvalidKerasConfigurationException(
"MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
if (this.outputLayerNames.size() != 1)
throw new InvalidKerasConfigurationException(
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
this.outputTypes = inferOutputTypes(inputShape);
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder();
if (weightsArchive != null)
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
}
if (optimizer != null) {
modelBuilder.updater(optimizer);
/** Default constructor */
public KerasSequentialModel() {
super();
}
/**
* Configure a NeuralNetConfiguration from this Keras Sequential model configuration.
*
* @return NeuralNetConfiguration
*/
public NeuralNetConfiguration getNeuralNetConfiguration()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
if (!this.className.equals(config.getFieldClassNameSequential()))
throw new InvalidKerasConfigurationException(
"Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
if (this.inputLayerNames.size() != 1)
throw new InvalidKerasConfigurationException(
"MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
if (this.outputLayerNames.size() != 1)
throw new InvalidKerasConfigurationException(
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder =
NeuralNetConfiguration.builder();
if (optimizer != null) {
modelBuilder.updater(optimizer);
}
// don't forcibly override for keras import
modelBuilder.overrideNinUponBuild(false);
/* Add layers one at a time. */
KerasLayer prevLayer = null;
int layerIndex = 0;
for (KerasLayer layer : this.layersOrdered) {
if (layer.isLayer()) {
int nbInbound = layer.getInboundLayerNames().size();
if (nbInbound != 1)
throw new InvalidKerasConfigurationException(
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
+ nbInbound
+ " for layer "
+ layer.getName()
+ ")");
if (prevLayer != null) {
InputType[] inputTypes = new InputType[1];
InputPreProcessor preprocessor;
if (prevLayer.isInputPreProcessor()) {
inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
preprocessor = prevLayer.getInputPreprocessor(inputTypes);
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
} else {
inputTypes[0] = this.outputTypes.get(prevLayer.getName());
preprocessor = layer.getInputPreprocessor(inputTypes);
if (preprocessor != null) {
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
} else layer.getLayer().setNIn(inputTypes[0], modelBuilder.isOverrideNinUponBuild());
}
if (preprocessor != null) {
Map<Integer, InputPreProcessor> map = new HashMap<>();
map.put(layerIndex, preprocessor);
modelBuilder.inputPreProcessors(map);
}
}
//don't forcibly override for keras import
modelBuilder.overrideNinUponBuild(false);
/* Add layers one at a time. */
KerasLayer prevLayer = null;
int layerIndex = 0;
for (KerasLayer layer : this.layersOrdered) {
if (layer.isLayer()) {
int nbInbound = layer.getInboundLayerNames().size();
if (nbInbound != 1)
throw new InvalidKerasConfigurationException(
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
+ nbInbound + " for layer " + layer.getName() + ")");
if (prevLayer != null) {
InputType[] inputTypes = new InputType[1];
InputPreProcessor preprocessor;
if (prevLayer.isInputPreProcessor()) {
inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
preprocessor = prevLayer.getInputPreprocessor(inputTypes);
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
} else {
inputTypes[0] = this.outputTypes.get(prevLayer.getName());
preprocessor = layer.getInputPreprocessor(inputTypes);
if(preprocessor != null) {
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
}
else
layer.getLayer().setNIn(inputTypes[0],modelBuilder.isOverrideNinUponBuild());
}
if (preprocessor != null)
modelBuilder.inputPreProcessor(layerIndex, preprocessor);
}
modelBuilder.layer(layerIndex++, layer.getLayer());
} else if (layer.getVertex() != null)
throw new InvalidKerasConfigurationException("Cannot add vertex to NeuralNetConfiguration (class name "
+ layer.getClassName() + ", layer name " + layer.getName() + ")");
prevLayer = layer;
}
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */
if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
modelBuilder.backpropType(BackpropType.TruncatedBPTT)
.tbpttFwdLength(truncatedBPTT)
.tbpttBackLength(truncatedBPTT);
else
modelBuilder.backpropType(BackpropType.Standard);
NeuralNetConfiguration build = modelBuilder.build();
return build;
modelBuilder.layer(layerIndex++, layer.getLayer());
} else if (layer.getVertex() != null)
throw new InvalidKerasConfigurationException(
"Cannot add vertex to NeuralNetConfiguration (class name "
+ layer.getClassName()
+ ", layer name "
+ layer.getName()
+ ")");
prevLayer = layer;
}
/**
* Build a MultiLayerNetwork from this Keras Sequential model configuration.
*
* @return MultiLayerNetwork
*/
public MultiLayerNetwork getMultiLayerNetwork()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
return getMultiLayerNetwork(true);
}
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */
if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
modelBuilder
.backpropType(BackpropType.TruncatedBPTT)
.tbpttFwdLength(truncatedBPTT)
.tbpttBackLength(truncatedBPTT);
else modelBuilder.backpropType(BackpropType.Standard);
/**
* Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights.
*
* @return MultiLayerNetwork
*/
public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration());
model.init();
if (importWeights)
model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers);
return model;
}
NeuralNetConfiguration build = modelBuilder.build();
return build;
}
/**
* Build a MultiLayerNetwork from this Keras Sequential model configuration.
*
* @return MultiLayerNetwork
*/
public MultiLayerNetwork getMultiLayerNetwork()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
return getMultiLayerNetwork(true);
}
/**
* Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights.
*
* @return MultiLayerNetwork
*/
public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration());
model.init();
if (importWeights)
model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers);
return model;
}
}

View File

@ -23,6 +23,7 @@ package net.brutex.ai.dnn.api;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
/**
* A fluent API to configure and create artificial neural networks
@ -30,9 +31,11 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationB
public class NN {
public static NeuralNetConfigurationBuilder<?, ?> net() {
public static NeuralNetConfigurationBuilder<?, ?> nn() {
return NeuralNetConfiguration.builder();
}
public static DenseLayer.DenseLayerBuilder<?,?> dense() { return DenseLayer.builder(); }
}

View File

@ -152,7 +152,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
@Getter @Setter @NonNull @lombok.Builder.Default
protected BackpropType backpropType = BackpropType.Standard;
@Getter @lombok.Builder.Default
@Getter @Setter @Singular
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
/**
* When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated)
@ -524,12 +524,11 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
* @param processor what to use to preProcess the data.
* @return builder pattern
*/
public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
if(inputPreProcessors$value==null) inputPreProcessors$value=new LinkedHashMap<>();
inputPreProcessors$value.put(layer, processor);
inputPreProcessors$set = true;
return self();
}
//public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
// inputPreProcessors$value.put(layer, processor);
// inputPreProcessors$set = true;
// return self();
// }
/**
* Set layer at index

View File

@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import lombok.*;
import lombok.experimental.SuperBuilder;
@ -317,6 +318,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
@NonNull
InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType);
if (inputPreProcessor != null) {
inputPreProcessors = new HashMap<>(inputPreProcessors);
inputPreProcessors.put(i, inputPreProcessor);
}
}
@ -538,6 +540,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
obj.getClass().getSimpleName());
}
});
// make sure the indexes are sequenced properly
AtomicInteger i = new AtomicInteger();
ret.forEach(obj -> {
obj.setIndex(i.getAndIncrement());
});
return ret;
}

View File

@ -219,7 +219,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
throw new IllegalStateException(
"Invalid input for Convolution layer (layer name=\""
+ getName()
+ "\"): Expected CNN input, got "
+ "\" at index '"+getIndex()+"') : Expected CNN input, got "
+ inputType);
}
@ -372,7 +372,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
* @param kernelSize kernel size
*/
public B kernelSize(int... kernelSize) {
this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize");
//this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize");
this.kernelSize$value = kernelSize;
this.kernelSize$set = true;
return self();
}
@ -383,7 +384,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
* @param stride kernel size
*/
public B stride(int... stride) {
this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride");
//this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride");
this.stride$value = stride;
this.stride$set = true;
return self();
}
@ -394,7 +396,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
* @param padding kernel size
*/
public B padding(int... padding) {
this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding");
//this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding");
this.padding$value = padding;
this.padding$set = true;
return self();
}
@ -404,7 +407,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
* @param dilation kernel size
*/
public B dilation(int... dilation) {
this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation");
//this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation");
this.dilation$value = dilation;
this.dilation$set = true;
return self();
}

View File

@ -20,14 +20,19 @@
package org.deeplearning4j.nn.conf.layers;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.stream.IntStream;
import lombok.*;
import lombok.experimental.SuperBuilder;
import lombok.extern.jackson.Jacksonized;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer;
@ -84,6 +89,8 @@ public class Deconvolution2D extends ConvolutionLayer {
boolean initializeParams,
DataType networkDataType) {
setNetConfiguration(conf);
LayerValidation.assertNInNOutSet("Deconvolution2D", getName(), layerIndex, getNIn(), getNOut());
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
runInheritance();
@ -127,11 +134,25 @@ public class Deconvolution2D extends ConvolutionLayer {
getName(),
Deconvolution2DLayer.class);
}
@Slf4j
private static final class Deconvolution2DBuilderImpl
extends Deconvolution2DBuilder<Deconvolution2D, Deconvolution2DBuilderImpl> {
public Deconvolution2D build() {
Deconvolution2D l = new Deconvolution2D(this);
if( l.getConvolutionMode() == ConvolutionMode.Same
&& IntStream.of(l.getPadding()).sum() != 0) {
log.warn("Invalid input for layer '{}'. "
+ "You cannot have a padding of {} when Convolution Mode is set to 'Same'."
+ " Padding will be ignored."
, l.getName(), l.getPadding());
}
/* strides * (input_size-1) + kernel_size - 2*padding */
//TODO: This is wrong, also depends on convolutionMode, etc ...
/*l.nOut = l.getStride()[0] * (l.getNIn()-1)
+ IntStream.of(l.getKernelSize()).reduce(1, (a,b) -> a*b)
- 2L * IntStream.of(l.getPadding()).sum();
*/
//l.nOut =264;
l.initializeConstraints();
return l;
}

View File

@ -62,6 +62,7 @@ public abstract class LayerConfiguration
implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration
@Getter @Setter protected String name;
@Getter @Setter private int index;
@Getter @Setter protected List<LayerConstraint> allParamConstraints;
@Getter @Setter protected List<LayerConstraint> weightConstraints;
@Getter @Setter protected List<LayerConstraint> biasConstraints;
@ -72,6 +73,7 @@ public abstract class LayerConfiguration
/** The type of the layer, basically defines the base class and its properties */
@Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
/**
* Number of parameters this layer has a result of its configuration
* @return number or parameters
@ -80,7 +82,6 @@ public abstract class LayerConfiguration
return initializer().numParams(this);
}
/**
* A reference to the neural net configuration. This field is excluded from json serialization as
* well as from equals check to avoid circular referenced.

View File

@ -37,122 +37,166 @@ import org.nd4j.linalg.api.shape.Shape;
@Data
@EqualsAndHashCode(exclude = {"shape"})
public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
private long inputHeight;
private long inputWidth;
private long numChannels;
private long inputHeight;
private long inputWidth;
private long numChannels;
@Getter(AccessLevel.NONE)
@Setter(AccessLevel.NONE)
private long[] shape;
@Getter(AccessLevel.NONE)
@Setter(AccessLevel.NONE)
private long[] shape;
/**
* Reshape to a channels x rows x columns tensor
*
* @param inputHeight the columns
* @param inputWidth the rows
* @param numChannels the channels
*/
@JsonCreator
public FeedForwardToCnnPreProcessor(@JsonProperty("inputHeight") long inputHeight,
@JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) {
this.inputHeight = inputHeight;
this.inputWidth = inputWidth;
this.numChannels = numChannels;
/**
* Reshape to a channels x rows x columns tensor
*
* @param inputHeight the columns
* @param inputWidth the rows
* @param numChannels the channels
*/
@JsonCreator
public FeedForwardToCnnPreProcessor(
@JsonProperty("inputHeight") long inputHeight,
@JsonProperty("inputWidth") long inputWidth,
@JsonProperty("numChannels") long numChannels) {
this.inputHeight = inputHeight;
this.inputWidth = inputWidth;
this.numChannels = numChannels;
}
/**
* Reshape to a channels x rows x columns tensor
*
* @param inputHeight the columns
* @param inputWidth the rows
*/
public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) {
this.inputHeight = inputHeight;
this.inputWidth = inputWidth;
this.numChannels = 1;
}
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
this.shape = input.shape();
if (input.rank() == 4) return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
if (input.columns() != inputWidth * inputHeight * numChannels)
throw new IllegalArgumentException(
"Invalid input: expect output columns must be equal to rows "
+ inputHeight
+ " x columns "
+ inputWidth
+ " x channels "
+ numChannels
+ " but was instead "
+ Arrays.toString(input.shape()));
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
return workspaceMgr.leverageTo(
ArrayType.ACTIVATIONS,
input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth));
}
@Override
// return 4 dimensions
public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilons))
epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c');
if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) {
if (epsilons.rank() == 2)
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons); // should never happen
return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
}
public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) {
this.inputHeight = inputHeight;
this.inputWidth = inputWidth;
this.numChannels = 1;
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape));
}
@Override
public FeedForwardToCnnPreProcessor clone() {
try {
FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
if (clone.shape != null) clone.shape = clone.shape.clone();
return clone;
} catch (CloneNotSupportedException e) {
throw new RuntimeException(e);
}
}
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
this.shape = input.shape();
if (input.rank() == 4)
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
@Override
public InputType getOutputType(InputType inputType) {
if (input.columns() != inputWidth * inputHeight * numChannels)
throw new IllegalArgumentException("Invalid input: expect output columns must be equal to rows "
+ inputHeight + " x columns " + inputWidth + " x channels " + numChannels
+ " but was instead " + Arrays.toString(input.shape()));
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS,
input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth));
}
@Override
// return 4 dimensions
public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilons))
epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c');
if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) {
if (epsilons.rank() == 2)
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons); //should never happen
return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
switch (inputType.getType()) {
case FF:
InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
val expSize = inputHeight * inputWidth * numChannels;
if (c.getSize() != expSize) {
throw new IllegalStateException(
"Invalid input: expected FeedForward input of size "
+ expSize
+ " = (d="
+ numChannels
+ " * w="
+ inputWidth
+ " * h="
+ inputHeight
+ "), got "
+ inputType);
}
return InputType.convolutional(inputHeight, inputWidth, numChannels);
case CNN:
InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape));
}
@Override
public FeedForwardToCnnPreProcessor clone() {
try {
FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
if (clone.shape != null)
clone.shape = clone.shape.clone();
return clone;
} catch (CloneNotSupportedException e) {
throw new RuntimeException(e);
if (c2.getChannels() != numChannels
|| c2.getHeight() != inputHeight
|| c2.getWidth() != inputWidth) {
throw new IllegalStateException(
"Invalid input: Got CNN input type with (d,w,h)=("
+ c2.getChannels()
+ ","
+ c2.getWidth()
+ ","
+ c2.getHeight()
+ ") but expected ("
+ numChannels
+ ","
+ inputHeight
+ ","
+ inputWidth
+ ")");
}
}
@Override
public InputType getOutputType(InputType inputType) {
switch (inputType.getType()) {
case FF:
InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
val expSize = inputHeight * inputWidth * numChannels;
if (c.getSize() != expSize) {
throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize
+ " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got "
+ inputType);
}
return InputType.convolutional(inputHeight, inputWidth, numChannels);
case CNN:
InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;
if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getChannels()
+ "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels
+ "," + inputHeight + "," + inputWidth + ")");
}
return c2;
case CNNFlat:
InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c3.getDepth()
+ "," + c3.getWidth() + "," + c3.getHeight() + ") but expected (" + numChannels
+ "," + inputHeight + "," + inputWidth + ")");
}
return c3.getUnflattenedType();
default:
throw new IllegalStateException("Invalid input type: got " + inputType);
return c2;
case CNNFlat:
InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
if (c3.getDepth() != numChannels
|| c3.getHeight() != inputHeight
|| c3.getWidth() != inputWidth) {
throw new IllegalStateException(
"Invalid input: Got CNN input type with (d,w,h)=("
+ c3.getDepth()
+ ","
+ c3.getWidth()
+ ","
+ c3.getHeight()
+ ") but expected ("
+ numChannels
+ ","
+ inputHeight
+ ","
+ inputWidth
+ ")");
}
return c3.getUnflattenedType();
default:
throw new IllegalStateException("Invalid input type: got " + inputType);
}
}
@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
int minibatchSize) {
//Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example)
return new Pair<>(maskArray, currentMaskState);
}
@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(
INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
// Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example)
return new Pair<>(maskArray, currentMaskState);
}
}

View File

@ -369,7 +369,7 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl
protected String layerId() {
String name = this.layerConfiguration.getName();
return "(layer name: "
return "(network: " + getNetConfiguration().getName() + " layer name: "
+ (name == null ? "\"\"" : name)
+ ", layer index: "
+ index

View File

@ -101,8 +101,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
int[] args = new int[] {
(int)kH, (int)kW, strides[0], strides[1],
pad[0], pad[1], dilation[0], dilation[1], sameMode,
nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
pad[0], pad[1], dilation[0], dilation[1], sameMode //,
//nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
};
INDArray delta;
@ -224,8 +224,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
int[] args = new int[] {
kH, kW, strides[0], strides[1],
pad[0], pad[1], dilation[0], dilation[1], sameMode,
nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
pad[0], pad[1], dilation[0], dilation[1], sameMode //,
//nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
};
//DL4J Deconv weights: [inputDepth, outputDepth, kH, kW]
@ -238,6 +238,20 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
} else {
opInputs = new INDArray[]{input, weights};
}
/**
* 2D deconvolution implementation
*
* IntArgs:
* 0: kernel height
* 1: kernel width
* 2: stride height
* 3: stride width
* 4: padding height
* 5: padding width
* 6: dilation height
* 7: dilation width
* 8: same mode: 0 false, 1 true
*/
CustomOp op = DynamicCustomOp.builder("deconv2d")
.addInputs(opInputs)
.addIntegerArguments(args)

View File

@ -773,7 +773,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork
LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i);
layers[i] =
lc.instantiate(
lc.getNetConfiguration(),
this.getNetConfiguration(),
trainingListeners,
i,
paramsView,

View File

@ -101,8 +101,10 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer
params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(GAMMA);
conf.addVariable(GAMMA);
params.put(BETA, createBeta(conf, betaView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(BETA);
conf.addVariable(BETA);
meanOffset = 2 * nOut;
}
@ -125,12 +127,15 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer
params.put(GLOBAL_MEAN, globalMeanView);
conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN);
conf.addVariable(GLOBAL_MEAN);
if(layer.isUseLogStd()){
params.put(GLOBAL_LOG_STD, globalVarView);
conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD);
conf.addVariable(GLOBAL_LOG_STD);
} else {
params.put(GLOBAL_VAR, globalVarView);
conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR);
conf.addVariable(GLOBAL_VAR);
}
return params;

View File

@ -114,11 +114,13 @@ public class ConvolutionParamInitializer extends AbstractParamInitializer {
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
conf.addVariable(WEIGHT_KEY);
conf.addVariable(BIAS_KEY);
} else {
INDArray weightView = paramsView;
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
conf.addVariable(WEIGHT_KEY);
}
return params;

View File

@ -34,7 +34,7 @@ systemProp.org.gradle.internal.publish.checksums.insecure=true
#for whatever reason we had to add MaxMetaspace and file encoding = utf8, gradle crashed otherwise
org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=768m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 -XX:ErrorFile=/var/log/java/hs_err_pid%p.log
#-DsocksProxyHost=sshtunnel -DsocksProxyPort=8888 -Djava.net.preferIPv4Stack=true
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit

Binary file not shown.

170
gradlew vendored
View File

@ -69,18 +69,18 @@ app_path=$0
# Need this for daisy-chained symlinks.
while
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
[ -h "$app_path" ]
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
[ -h "$app_path" ]
do
ls=$(ls -ld "$app_path")
link=${ls#*' -> '}
case $link in #(
/*) app_path=$link ;; #(
*) app_path=$APP_HOME$link ;;
esac
ls=$( ls -ld "$app_path" )
link=${ls#*' -> '}
case $link in #(
/*) app_path=$link ;; #(
*) app_path=$APP_HOME$link ;;
esac
done
APP_HOME=$(cd "${APP_HOME:-./}" && pwd -P) || exit
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
APP_NAME="Gradle"
APP_BASE_NAME=${0##*/}
@ -91,15 +91,15 @@ DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum
warn() {
echo "$*"
warn () {
echo "$*"
} >&2
die() {
echo
echo "$*"
echo
exit 1
die () {
echo
echo "$*"
echo
exit 1
} >&2
# OS specific support (must be 'true' or 'false').
@ -107,52 +107,51 @@ cygwin=false
msys=false
darwin=false
nonstop=false
case "$(uname)" in #(
CYGWIN*) cygwin=true ;; #(
Darwin*) darwin=true ;; #(
MSYS* | MINGW*) msys=true ;; #(
NONSTOP*) nonstop=true ;;
case "$( uname )" in #(
CYGWIN* ) cygwin=true ;; #(
Darwin* ) darwin=true ;; #(
MSYS* | MINGW* ) msys=true ;; #(
NONSTOP* ) nonstop=true ;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ]; then
if [ -x "$JAVA_HOME/jre/sh/java" ]; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD=$JAVA_HOME/jre/sh/java
else
JAVACMD=$JAVA_HOME/bin/java
fi
if [ ! -x "$JAVACMD" ]; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD=$JAVA_HOME/jre/sh/java
else
JAVACMD=$JAVA_HOME/bin/java
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
fi
else
JAVACMD=java
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
JAVACMD=java
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if ! "$cygwin" && ! "$darwin" && ! "$nonstop"; then
case $MAX_FD in #(
max*)
MAX_FD=$(ulimit -H -n) ||
warn "Could not query maximum file descriptor limit"
;;
esac
case $MAX_FD in #(
'' | soft) : ;; #(
*)
ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
;;
esac
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #(
max*)
MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit"
esac
case $MAX_FD in #(
'' | soft) :;; #(
*)
ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
esac
fi
# Collect all arguments for the java command, stacking in reverse order:
@ -164,36 +163,34 @@ fi
# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
# For Cygwin or MSYS, switch paths to Windows format before running java
if "$cygwin" || "$msys"; then
APP_HOME=$(cygpath --path --mixed "$APP_HOME")
CLASSPATH=$(cygpath --path --mixed "$CLASSPATH")
if "$cygwin" || "$msys" ; then
APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
JAVACMD=$(cygpath --unix "$JAVACMD")
JAVACMD=$( cygpath --unix "$JAVACMD" )
# Now convert the arguments - kludge to limit ourselves to /bin/sh
for arg; do
if
case $arg in #(
-*) false ;; # don't mess with options #(
/?*)
t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
[ -e "$t" ]
;; #(
*) false ;;
esac
then
arg=$(cygpath --path --ignore --mixed "$arg")
fi
# Roll the args list around exactly as many times as the number of
# args, so each arg winds up back in the position where it started, but
# possibly modified.
#
# NB: a `for` loop captures its iteration list before it begins, so
# changing the positional parameters here affects neither the number of
# iterations, nor the values presented in `arg`.
shift # remove old arg
set -- "$@" "$arg" # push replacement arg
done
# Now convert the arguments - kludge to limit ourselves to /bin/sh
for arg do
if
case $arg in #(
-*) false ;; # don't mess with options #(
/?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
[ -e "$t" ] ;; #(
*) false ;;
esac
then
arg=$( cygpath --path --ignore --mixed "$arg" )
fi
# Roll the args list around exactly as many times as the number of
# args, so each arg winds up back in the position where it started, but
# possibly modified.
#
# NB: a `for` loop captures its iteration list before it begins, so
# changing the positional parameters here affects neither the number of
# iterations, nor the values presented in `arg`.
shift # remove old arg
set -- "$@" "$arg" # push replacement arg
done
fi
# Collect all arguments for the java command;
@ -203,15 +200,10 @@ fi
# * put everything else in single quotes, so that it's not re-expanded.
set -- \
"-Dorg.gradle.appname=$APP_BASE_NAME" \
-classpath "$CLASSPATH" \
org.gradle.wrapper.GradleWrapperMain \
"$@"
# Stop when "xargs" is not available.
if ! command -v xargs >/dev/null 2>&1; then
die "xargs is not available"
fi
"-Dorg.gradle.appname=$APP_BASE_NAME" \
-classpath "$CLASSPATH" \
org.gradle.wrapper.GradleWrapperMain \
"$@"
# Use "xargs" to parse quoted args.
#
@ -233,10 +225,10 @@ fi
#
eval "set -- $(
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
xargs -n1 |
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
tr '\n' ' '
)" '"$@"'
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
xargs -n1 |
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
tr '\n' ' '
)" '"$@"'
exec "$JAVACMD" "$@"

14
gradlew.bat vendored
View File

@ -14,7 +14,7 @@
@rem limitations under the License.
@rem
@if "%DEBUG%"=="" @echo off
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@ -25,7 +25,7 @@
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%"=="" set DIRNAME=.
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if %ERRORLEVEL% equ 0 goto execute
if "%ERRORLEVEL%" == "0" goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
@ -75,15 +75,13 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
:end
@rem End local scope for the variables with windows NT shell
if %ERRORLEVEL% equ 0 goto mainEnd
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
set EXIT_CODE=%ERRORLEVEL%
if %EXIT_CODE% equ 0 set EXIT_CODE=1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal