parent
1b3338f809
commit
dd151aec3f
|
@ -1,473 +1,261 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.gan;
|
||||
|
||||
import static net.brutex.ai.dnn.api.NN.dense;
|
||||
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
import java.util.UUID;
|
||||
import javax.imageio.ImageIO;
|
||||
import javax.swing.ImageIcon;
|
||||
import javax.swing.JFrame;
|
||||
import javax.swing.JLabel;
|
||||
import javax.swing.JPanel;
|
||||
import javax.swing.WindowConstants;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import javax.swing.*;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.image.loader.NativeImageLoader;
|
||||
import org.datavec.image.recordreader.ImageRecordReader;
|
||||
import org.datavec.image.transform.ColorConversionTransform;
|
||||
import org.datavec.image.transform.ImageTransform;
|
||||
import org.datavec.image.transform.PipelineImageTransform;
|
||||
import org.datavec.image.transform.ResizeImageTransform;
|
||||
import org.datavec.image.transform.ShowImageTransform;
|
||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
||||
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
||||
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.Adam;
|
||||
import org.nd4j.linalg.learning.config.IUpdater;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
@Slf4j
|
||||
public class App {
|
||||
private static final double LEARNING_RATE = 0.000002;
|
||||
private static final double GRADIENT_THRESHOLD = 100.0;
|
||||
private static final double LEARNING_RATE = 0.002;
|
||||
private static final double GRADIENT_THRESHOLD = 100.0;
|
||||
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
|
||||
private static final int BATCHSIZE = 128;
|
||||
private static JFrame frame;
|
||||
private static JPanel panel;
|
||||
|
||||
private static final int X_DIM = 20 ;
|
||||
private static final int Y_DIM = 20;
|
||||
private static final int CHANNELS = 1;
|
||||
private static final int batchSize = 1;
|
||||
private static final int INPUT = 10;
|
||||
private static LayerConfiguration[] genLayers() {
|
||||
return new LayerConfiguration[] {
|
||||
dense().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
dense().nIn(256).nOut(512).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
dense().nIn(512).nOut(1024).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
dense().nIn(1024).nOut(784).activation(Activation.TANH).build()
|
||||
};
|
||||
}
|
||||
|
||||
private static final int OUTPUT_PER_PANEL = 16;
|
||||
/**
|
||||
* Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
|
||||
*
|
||||
* @return config
|
||||
*/
|
||||
private static NeuralNetConfiguration generator() {
|
||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.seed(42)
|
||||
.updater(UPDATER)
|
||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.activation(Activation.IDENTITY)
|
||||
.layersFromArray(genLayers())
|
||||
.name("generator")
|
||||
.build();
|
||||
|
||||
private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS;
|
||||
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
|
||||
return conf;
|
||||
}
|
||||
|
||||
private static JFrame frame;
|
||||
private static JFrame frame2;
|
||||
private static JPanel panel;
|
||||
private static JPanel panel2;
|
||||
private static LayerConfiguration[] disLayers() {
|
||||
return new LayerConfiguration[]{
|
||||
dense().nIn(784).nOut(1024).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DropoutLayer.builder(1 - 0.5).build(),
|
||||
dense().nIn(1024).nOut(512).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DropoutLayer.builder(1 - 0.5).build(),
|
||||
dense().nIn(512).nOut(256).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DropoutLayer.builder(1 - 0.5).build(),
|
||||
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
|
||||
};
|
||||
}
|
||||
|
||||
private static final String OUTPUT_DIR = "C:/temp/output/";
|
||||
private static NeuralNetConfiguration discriminator() {
|
||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.seed(42)
|
||||
.updater(UPDATER)
|
||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.activation(Activation.IDENTITY)
|
||||
.layersFromArray(disLayers())
|
||||
.name("discriminator")
|
||||
.build();
|
||||
|
||||
private static LayerConfiguration[] genLayers() {
|
||||
return new LayerConfiguration[] {
|
||||
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
|
||||
ActivationLayer.builder(Activation.LEAKYRELU).build(),
|
||||
return conf;
|
||||
}
|
||||
|
||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DropoutLayer.builder(1 - 0.5).build(),
|
||||
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
private static NeuralNetConfiguration gan() {
|
||||
LayerConfiguration[] genLayers = genLayers();
|
||||
LayerConfiguration[] disLayers = discriminator().getFlattenedLayerConfigurations().stream()
|
||||
.map((layer) -> {
|
||||
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
||||
return FrozenLayerWithBackprop.builder(layer).build();
|
||||
} else {
|
||||
return layer;
|
||||
}
|
||||
}).toArray(LayerConfiguration[]::new);
|
||||
LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers);
|
||||
|
||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build()
|
||||
};
|
||||
}
|
||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.seed(42)
|
||||
.updater(UPDATER)
|
||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.activation(Activation.IDENTITY)
|
||||
.layersFromArray(layers)
|
||||
.name("GAN")
|
||||
.build();
|
||||
|
||||
/**
|
||||
* Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
|
||||
*
|
||||
* @return config
|
||||
*/
|
||||
private static NeuralNetConfiguration generator() {
|
||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.seed(42)
|
||||
.updater(UPDATER)
|
||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||
//.weightInit(WeightInit.XAVIER)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.activation(Activation.IDENTITY)
|
||||
.layersFromArray(genLayers())
|
||||
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
||||
// .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
|
||||
.build();
|
||||
((NeuralNetConfiguration) conf).init();
|
||||
return conf;
|
||||
}
|
||||
|
||||
return conf;
|
||||
}
|
||||
@Test
|
||||
public void runTest() throws Exception {
|
||||
App.main(null);
|
||||
}
|
||||
public static void main(String... args) throws Exception {
|
||||
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
||||
|
||||
private static LayerConfiguration[] disLayers() {
|
||||
return new LayerConfiguration[]{
|
||||
DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DropoutLayer.builder(1 - 0.5).build(),
|
||||
DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DropoutLayer.builder(1 - 0.5).build(),
|
||||
DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DropoutLayer.builder(1 - 0.5).build(),
|
||||
DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DropoutLayer.builder(1 - 0.5).build(),
|
||||
MnistDataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
|
||||
|
||||
OutputLayer.builder().name("dis-output").lossFunction(LossFunction.MCXENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
|
||||
};
|
||||
}
|
||||
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
|
||||
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
|
||||
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
|
||||
gen.init();
|
||||
dis.init();
|
||||
gan.init();
|
||||
|
||||
private static NeuralNetConfiguration discriminator() {
|
||||
copyParams(gen, dis, gan);
|
||||
|
||||
NeuralNetConfiguration conf =
|
||||
NeuralNetConfiguration.builder()
|
||||
.seed(42)
|
||||
.updater(UPDATER)
|
||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
||||
.weightNoise(null)
|
||||
// .weightInitFn(new WeightInitXavier())
|
||||
// .activationFn(new ActivationIdentity())
|
||||
.activation(Activation.IDENTITY)
|
||||
.layersFromArray(disLayers())
|
||||
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
||||
.build();
|
||||
((NeuralNetConfiguration) conf).init();
|
||||
gen.addTrainingListeners(new PerformanceListener(10, true));
|
||||
dis.addTrainingListeners(new PerformanceListener(10, true));
|
||||
gan.addTrainingListeners(new PerformanceListener(10, true));
|
||||
|
||||
return conf;
|
||||
}
|
||||
trainData.reset();
|
||||
|
||||
private static NeuralNetConfiguration gan() {
|
||||
LayerConfiguration[] genLayers = genLayers();
|
||||
LayerConfiguration[] disLayers = Arrays.stream(disLayers())
|
||||
.map((layer) -> {
|
||||
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
||||
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
|
||||
} else {
|
||||
return layer;
|
||||
}
|
||||
}).toArray(LayerConfiguration[]::new);
|
||||
LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers);
|
||||
int j = 0;
|
||||
for (int i = 0; i < 50; i++) {
|
||||
while (trainData.hasNext()) {
|
||||
j++;
|
||||
|
||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.seed(42)
|
||||
.updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() )
|
||||
.gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
|
||||
.gradientNormalizationThreshold( 100 )
|
||||
//.weightInitFn( new WeightInitXavier() ) //this is internal
|
||||
.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
||||
.weightInit( WeightInit.XAVIER)
|
||||
//.activationFn( new ActivationIdentity()) //this is internal
|
||||
.activation( Activation.IDENTITY )
|
||||
.layersFromArray( layers )
|
||||
.inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
||||
.dataType(DataType.FLOAT)
|
||||
.build();
|
||||
((NeuralNetConfiguration) conf).init();
|
||||
return conf;
|
||||
}
|
||||
// generate data
|
||||
INDArray real = trainData.next().getFeatures().muli(2).subi(1);
|
||||
int batchSize = (int) real.shape()[0];
|
||||
|
||||
INDArray fakeIn = Nd4j.rand(batchSize, 100);
|
||||
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
||||
|
||||
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
|
||||
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
|
||||
|
||||
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
||||
|
||||
dis.fit(data);
|
||||
dis.fit(data);
|
||||
|
||||
// Update the discriminator in the GAN network
|
||||
updateGan(gen, dis, gan);
|
||||
|
||||
gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
|
||||
|
||||
|
||||
@Test
|
||||
public void runTest() throws Exception {
|
||||
if(! log.isDebugEnabled()) {
|
||||
log.info("Logging is not set to DEBUG");
|
||||
}
|
||||
else {
|
||||
log.info("Logging is set to DEBUG");
|
||||
}
|
||||
main();
|
||||
}
|
||||
if (j % 10 == 1) {
|
||||
System.out.println("Epoch " + i +" Iteration " + j + " Visualizing...");
|
||||
INDArray[] samples = new INDArray[9];
|
||||
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
|
||||
|
||||
public static void main(String... args) throws Exception {
|
||||
for (int k = 0; k < 9; k++) {
|
||||
INDArray input = fakeSet2.get(k).getFeatures();
|
||||
//samples[k] = gen.output(input, false);
|
||||
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
||||
|
||||
log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m ");
|
||||
Nd4j.getMemoryManager().setAutoGcWindow(500);
|
||||
|
||||
//MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
|
||||
//FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
|
||||
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS());
|
||||
|
||||
|
||||
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
|
||||
|
||||
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
|
||||
ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);
|
||||
|
||||
ImageTransform tr = new PipelineImageTransform.Builder()
|
||||
//.addImageTransform(transform) //convert to GREY SCALE
|
||||
.addImageTransform(transform3)
|
||||
//.addImageTransform(transform2)
|
||||
.build();
|
||||
|
||||
ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
|
||||
imageRecordReader.initialize(fileSplit, tr);
|
||||
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
|
||||
|
||||
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
|
||||
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
|
||||
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
|
||||
gen.init(); log.debug("Generator network: {}", gen);
|
||||
dis.init(); log.debug("Discriminator network: {}", dis);
|
||||
gan.init(); log.info("Complete GAN network: {}", gan);
|
||||
|
||||
|
||||
copyParams(gen, dis, gan);
|
||||
|
||||
//gen.addTrainingListeners(new PerformanceListener(15, true, "GEN"));
|
||||
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
|
||||
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
|
||||
//gan.addTrainingListeners(new ScoreToChartListener("gan"));
|
||||
//dis.setListeners(new ScoreToChartListener("dis"));
|
||||
|
||||
//System.out.println(gan.toString());
|
||||
//gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
|
||||
|
||||
//gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
|
||||
//trainData.reset();
|
||||
|
||||
int j = 0;
|
||||
for (int i = 0; i < 51; i++) { //epoch
|
||||
while (trainData.hasNext()) {
|
||||
j++;
|
||||
|
||||
DataSet next = trainData.next();
|
||||
// generate data
|
||||
INDArray real = next.getFeatures();//.div(255f);
|
||||
|
||||
//start next round if there are not enough images left to have a full batchsize dataset
|
||||
if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
|
||||
log.warn("Your total number of input images is not a multiple of {}, "
|
||||
+ "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
|
||||
break;
|
||||
}
|
||||
visualize(samples);
|
||||
}
|
||||
}
|
||||
trainData.reset();
|
||||
// Copy the GANs generator to gen.
|
||||
//updateGen(gen, gan);
|
||||
}
|
||||
|
||||
//if(i%20 == 0) {
|
||||
frame2 = visualize(new INDArray[]{real}, batchSize,
|
||||
frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
|
||||
//}
|
||||
real.divi(255f);
|
||||
|
||||
// int batchSize = (int) real.shape()[0];
|
||||
|
||||
INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
|
||||
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
|
||||
|
||||
|
||||
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
||||
fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
|
||||
|
||||
//log.info("real has {} items.", real.length());
|
||||
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
|
||||
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
|
||||
|
||||
|
||||
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
||||
|
||||
dis.fit(data);
|
||||
//dis.fit(data);
|
||||
|
||||
// Update the discriminator in the GAN network
|
||||
updateGan(gen, dis, gan);
|
||||
|
||||
//gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
|
||||
gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1)));
|
||||
|
||||
//Visualize and reporting
|
||||
if (j % 10 == 1) {
|
||||
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
|
||||
INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
|
||||
|
||||
|
||||
for (int k = 0; k < samples.length; k++) {
|
||||
//INDArray input = fakeSet2.get(k).getFeatures();
|
||||
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
|
||||
INDArray input = fakeSet2.get(k).getFeatures();
|
||||
input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here
|
||||
|
||||
//samples[k] = gen.output(input, false);
|
||||
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
||||
samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
|
||||
//samples[k] =
|
||||
samples[k].addi(1f).divi(2f).muli(255f);
|
||||
|
||||
}
|
||||
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
|
||||
}
|
||||
}
|
||||
if (trainData.resetSupported()) {
|
||||
trainData.reset();
|
||||
} else {
|
||||
log.error("Trainingdata {} does not support reset.", trainData.toString());
|
||||
}
|
||||
// Copy the GANs generator to gen.
|
||||
updateGen(gen, gan);
|
||||
|
||||
gen.save(new File("mnist-mlp-generator.dlj"));
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||
int genLayerCount = gen.getLayers().length;
|
||||
for (int i = 0; i < gan.getLayers().length; i++) {
|
||||
if (i < genLayerCount) {
|
||||
if(gan.getLayer(i).getParams() != null)
|
||||
gan.getLayer(i).setParams(gen.getLayer(i).getParams());
|
||||
} else {
|
||||
if(gan.getLayer(i).getParams() != null)
|
||||
gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
|
||||
for (int i = 0; i < gen.getLayers().length; i++) {
|
||||
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||
}
|
||||
}
|
||||
|
||||
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||
int genLayerCount = gen.getLayers().length;
|
||||
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
|
||||
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
|
||||
}
|
||||
}
|
||||
|
||||
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
|
||||
if (isOrig) {
|
||||
frame.setTitle("Viz Original");
|
||||
} else {
|
||||
frame.setTitle("Generated");
|
||||
}
|
||||
|
||||
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||
frame.setLayout(new BorderLayout());
|
||||
|
||||
JPanel panelx = new JPanel();
|
||||
|
||||
panelx.setLayout(new GridLayout(4, 4, 8, 8));
|
||||
for (INDArray sample : samples) {
|
||||
for(int i = 0; i<batchElements; i++) {
|
||||
panelx.add(getImage(sample, i, isOrig));
|
||||
}
|
||||
}
|
||||
frame.add(panelx, BorderLayout.CENTER);
|
||||
frame.setVisible(true);
|
||||
|
||||
frame.revalidate();
|
||||
frame.setMinimumSize(new Dimension(300, 20));
|
||||
frame.pack();
|
||||
return frame;
|
||||
}
|
||||
|
||||
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
|
||||
final BufferedImage bi;
|
||||
if(CHANNELS>1) {
|
||||
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
|
||||
} else {
|
||||
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
|
||||
}
|
||||
final int imageSize = X_DIM * Y_DIM;
|
||||
final int offset = batchElement * imageSize;
|
||||
int pxl = offset * CHANNELS; //where to start in the INDArray
|
||||
|
||||
//Image in NCHW - channels first format
|
||||
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
|
||||
for (int y = 0; y < Y_DIM; y++) { // step through the columns x
|
||||
for (int x = 0; x < X_DIM; x++) { //step through the rows y
|
||||
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
|
||||
bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
|
||||
pxl++; //next item in INDArray
|
||||
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||
int genLayerCount = gen.getLayers().length;
|
||||
for (int i = 0; i < gan.getLayers().length; i++) {
|
||||
if (i < genLayerCount) {
|
||||
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||
} else {
|
||||
dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ImageIcon orig = new ImageIcon(bi);
|
||||
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
|
||||
ImageIcon scaled = new ImageIcon(imageScaled);
|
||||
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
|
||||
return new JLabel(scaled);
|
||||
|
||||
}
|
||||
|
||||
private static void saveImage(Image image, int batchElement, boolean isOrig) {
|
||||
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
|
||||
|
||||
try {
|
||||
// Save the images to disk
|
||||
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
|
||||
|
||||
log.debug("Images saved successfully.");
|
||||
} catch (IOException e) {
|
||||
log.error("Error saving the images: {}", e.getMessage());
|
||||
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
|
||||
for (int i = 0; i < gen.getLayers().length; i++) {
|
||||
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||
}
|
||||
}
|
||||
}
|
||||
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
|
||||
File directory = new File(outputDirectory);
|
||||
if (!directory.exists()) {
|
||||
directory.mkdir();
|
||||
|
||||
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||
int genLayerCount = gen.getLayers().length;
|
||||
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
|
||||
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
|
||||
}
|
||||
}
|
||||
|
||||
private static void visualize(INDArray[] samples) {
|
||||
if (frame == null) {
|
||||
frame = new JFrame();
|
||||
frame.setTitle("Viz");
|
||||
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||
frame.setLayout(new BorderLayout());
|
||||
|
||||
panel = new JPanel();
|
||||
|
||||
panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
|
||||
frame.add(panel, BorderLayout.CENTER);
|
||||
frame.setVisible(true);
|
||||
}
|
||||
|
||||
File outputFile = new File(directory, fileName);
|
||||
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
|
||||
}
|
||||
panel.removeAll();
|
||||
|
||||
public static BufferedImage imageToBufferedImage(Image image) {
|
||||
if (image instanceof BufferedImage) {
|
||||
return (BufferedImage) image;
|
||||
for (INDArray sample : samples) {
|
||||
panel.add(getImage(sample));
|
||||
}
|
||||
|
||||
// Create a buffered image with the same dimensions and transparency as the original image
|
||||
BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
|
||||
|
||||
// Draw the original image onto the buffered image
|
||||
Graphics2D g2d = bufferedImage.createGraphics();
|
||||
g2d.drawImage(image, 0, 0, null);
|
||||
g2d.dispose();
|
||||
|
||||
return bufferedImage;
|
||||
frame.revalidate();
|
||||
frame.pack();
|
||||
}
|
||||
|
||||
private static JLabel getImage(INDArray tensor) {
|
||||
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
|
||||
for (int i = 0; i < 784; i++) {
|
||||
int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
|
||||
bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
|
||||
}
|
||||
ImageIcon orig = new ImageIcon(bi);
|
||||
Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
|
||||
|
||||
ImageIcon scaled = new ImageIcon(imageScaled);
|
||||
|
||||
return new JLabel(scaled);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,343 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.gan;
|
||||
|
||||
|
||||
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
import java.util.List;
|
||||
import javax.imageio.ImageIO;
|
||||
import javax.swing.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.image.loader.NativeImageLoader;
|
||||
import org.datavec.image.recordreader.ImageRecordReader;
|
||||
import org.datavec.image.transform.*;
|
||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
@Slf4j
|
||||
public class App2 {
|
||||
|
||||
final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
|
||||
static final float COLORSPACE = 255f;
|
||||
static final int DIMENSIONS = 28;
|
||||
static final int CHANNELS = 1;
|
||||
final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
|
||||
final int OUTPUT_PER_PANEL = 10;
|
||||
|
||||
final boolean BIAS = true;
|
||||
|
||||
static final int BATCHSIZE=128;
|
||||
|
||||
private JFrame frame2, frame;
|
||||
static final String OUTPUT_DIR = "d:/out/";
|
||||
|
||||
final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
|
||||
final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);
|
||||
|
||||
@Test
|
||||
void runTest() throws IOException {
|
||||
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
||||
|
||||
MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
|
||||
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
|
||||
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
|
||||
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
|
||||
ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
|
||||
|
||||
ImageTransform tr = new PipelineImageTransform.Builder()
|
||||
.addImageTransform(transform) //convert to GREY SCALE
|
||||
.addImageTransform(transform3)
|
||||
//.addImageTransform(transform2)
|
||||
.build();
|
||||
|
||||
ImageRecordReader imageRecordReader = new ImageRecordReader(DIMENSIONS, DIMENSIONS, CHANNELS);
|
||||
imageRecordReader.initialize(fileSplit, tr);
|
||||
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, BATCHSIZE );
|
||||
trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
|
||||
|
||||
MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
|
||||
MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());
|
||||
|
||||
LayerConfiguration[] disLayers = App2Config.discriminator().getFlattenedLayerConfigurations().stream()
|
||||
.map((layer) -> {
|
||||
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
||||
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
|
||||
} else {
|
||||
return layer;
|
||||
}
|
||||
}).toArray(LayerConfiguration[]::new);
|
||||
|
||||
NeuralNetConfiguration netConfiguration =
|
||||
NeuralNetConfiguration.builder()
|
||||
.name("GAN")
|
||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||
.gradientNormalizationThreshold(100)
|
||||
.updater(App2Config.UPDATER)
|
||||
.innerConfigurations(new ArrayList<>(List.of(App2Config.generator())))
|
||||
.layersFromList(new ArrayList<>(Arrays.asList(disLayers)))
|
||||
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||
// .inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
|
||||
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
|
||||
// .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||
//.inputPreProcessor(2, new CnnToFeedForwardPreProcessor())
|
||||
//.dataType(DataType.FLOAT)
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork gan = new MultiLayerNetwork(netConfiguration );
|
||||
|
||||
dis.init(); log.debug("Discriminator network: {}", dis);
|
||||
gen.init(); log.debug("Generator network: {}", gen);
|
||||
gan.init(); log.debug("GAN network: {}", gan);
|
||||
|
||||
|
||||
log.info("Generator Summary:\n{}", gen.summary());
|
||||
log.info("GAN Summary:\n{}", gan.summary());
|
||||
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
|
||||
gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
|
||||
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
|
||||
|
||||
int j = 0;
|
||||
for (int i = 0; i < 51; i++) { //epoch
|
||||
while (trainData.hasNext()) {
|
||||
j++;
|
||||
DataSet next = trainData.next();
|
||||
// generate data
|
||||
INDArray real = next.getFeatures(); //.muli(2).subi(1);;//.div(255f);
|
||||
|
||||
//start next round if there are not enough images left to have a full batchsize dataset
|
||||
if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
|
||||
log.warn("Your total number of input images is not a multiple of {}, "
|
||||
+ "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
|
||||
break;
|
||||
}
|
||||
|
||||
//if(i%20 == 0) {
|
||||
|
||||
// frame2 = visualize(new INDArray[]{real}, BATCHSIZE,
|
||||
// frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
|
||||
//}
|
||||
//real.divi(255f);
|
||||
|
||||
// int batchSize = (int) real.shape()[0];
|
||||
|
||||
//INDArray fakeIn = Nd4j.rand(BATCHSIZE, CHANNELS, DIMENSIONS, DIMENSIONS);
|
||||
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
|
||||
INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
|
||||
|
||||
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
||||
// when generator has TANH as activation - value range is -1 to 1
|
||||
// when generator has SIGMOID, then range is 0 to 1
|
||||
fake.addi(1f).divi(2f);
|
||||
|
||||
DataSet realSet = new DataSet(real, label_real);
|
||||
DataSet fakeSet = new DataSet(fake, label_fake);
|
||||
|
||||
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
||||
|
||||
dis.fit(data);
|
||||
dis.fit(data);
|
||||
// Update the discriminator in the GAN network
|
||||
updateGan(gen, dis, gan);
|
||||
|
||||
gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_fake));
|
||||
|
||||
//Visualize and reporting
|
||||
if (j % 10 == 1) {
|
||||
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
|
||||
INDArray[] samples = BATCHSIZE > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[BATCHSIZE];
|
||||
|
||||
|
||||
for (int k = 0; k < samples.length; k++) {
|
||||
DataSet fakeSet2 = new DataSet(fakeIn, label_fake);
|
||||
INDArray input = fakeSet2.get(k).getFeatures();
|
||||
|
||||
//input = input.reshape(1,CHANNELS, DIMENSIONS, DIMENSIONS); //batch size will be 1 here for images
|
||||
input = input.reshape(1, App2Config.INPUT);
|
||||
|
||||
//samples[k] = gen.output(input, false);
|
||||
samples[k] = gen.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
||||
samples[k] = samples[k].reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS);
|
||||
//samples[k] =
|
||||
//samples[k].muli(255f);
|
||||
|
||||
}
|
||||
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
|
||||
}
|
||||
}
|
||||
|
||||
if (trainData.resetSupported()) {
|
||||
trainData.reset();
|
||||
} else {
|
||||
log.error("Trainingdata {} does not support reset.", trainData.toString());
|
||||
}
|
||||
// Copy the GANs generator to gen.
|
||||
updateGen(gen, gan);
|
||||
log.info("Updated GAN's generator from gen.");
|
||||
gen.save(new File("mnist-mlp-generator.dlj"));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
|
||||
if (isOrig) {
|
||||
frame.setTitle("Viz Original");
|
||||
} else {
|
||||
frame.setTitle("Generated");
|
||||
}
|
||||
|
||||
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||
frame.setLayout(new BorderLayout());
|
||||
|
||||
JPanel panelx = new JPanel();
|
||||
|
||||
panelx.setLayout(new GridLayout(4, 4, 8, 8));
|
||||
for (INDArray sample : samples) {
|
||||
for(int i = 0; i<batchElements; i++) {
|
||||
panelx.add(getImage(sample, i, isOrig));
|
||||
}
|
||||
}
|
||||
frame.add(panelx, BorderLayout.CENTER);
|
||||
frame.setVisible(true);
|
||||
|
||||
frame.revalidate();
|
||||
frame.setMinimumSize(new Dimension(300, 20));
|
||||
frame.pack();
|
||||
return frame;
|
||||
}
|
||||
|
||||
|
||||
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
|
||||
final BufferedImage bi;
|
||||
if(CHANNELS >1) {
|
||||
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
|
||||
} else {
|
||||
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
|
||||
}
|
||||
final int imageSize = DIMENSIONS * DIMENSIONS;
|
||||
final int offset = batchElement * imageSize;
|
||||
int pxl = offset * CHANNELS; //where to start in the INDArray
|
||||
|
||||
//Image in NCHW - channels first format
|
||||
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
|
||||
for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x
|
||||
for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y
|
||||
float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
|
||||
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl);
|
||||
bi.getRaster().setSample(x, y, c, f_pxl);
|
||||
pxl++; //next item in INDArray
|
||||
}
|
||||
}
|
||||
}
|
||||
ImageIcon orig = new ImageIcon(bi);
|
||||
Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT);
|
||||
ImageIcon scaled = new ImageIcon(imageScaled);
|
||||
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
|
||||
return new JLabel(scaled);
|
||||
|
||||
}
|
||||
|
||||
private static void saveImage(Image image, int batchElement, boolean isOrig) {
|
||||
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
|
||||
|
||||
try {
|
||||
// Save the images to disk
|
||||
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
|
||||
|
||||
log.debug("Images saved successfully.");
|
||||
} catch (IOException e) {
|
||||
log.error("Error saving the images: {}", e.getMessage());
|
||||
}
|
||||
}
|
||||
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
|
||||
File directory = new File(outputDirectory);
|
||||
if (!directory.exists()) {
|
||||
directory.mkdir();
|
||||
}
|
||||
|
||||
File outputFile = new File(directory, fileName);
|
||||
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
|
||||
}
|
||||
|
||||
public static BufferedImage imageToBufferedImage(Image image) {
|
||||
if (image instanceof BufferedImage) {
|
||||
return (BufferedImage) image;
|
||||
}
|
||||
|
||||
// Create a buffered image with the same dimensions and transparency as the original image
|
||||
BufferedImage bufferedImage;
|
||||
if (CHANNELS > 1) {
|
||||
bufferedImage =
|
||||
new BufferedImage(
|
||||
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
|
||||
} else {
|
||||
bufferedImage =
|
||||
new BufferedImage(
|
||||
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
|
||||
}
|
||||
|
||||
// Draw the original image onto the buffered image
|
||||
Graphics2D g2d = bufferedImage.createGraphics();
|
||||
g2d.drawImage(image, 0, 0, null);
|
||||
g2d.dispose();
|
||||
|
||||
return bufferedImage;
|
||||
}
|
||||
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
|
||||
for (int i = 0; i < gen.getLayers().length; i++) {
|
||||
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||
}
|
||||
}
|
||||
|
||||
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||
int genLayerCount = gen.getLayers().length;
|
||||
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
|
||||
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,176 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.gan;
|
||||
|
||||
import static net.brutex.ai.dnn.api.NN.*;
|
||||
|
||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||
import org.nd4j.linalg.learning.config.Adam;
|
||||
import org.nd4j.linalg.learning.config.IUpdater;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
public class App2Config {
|
||||
|
||||
public static final int INPUT = 100;
|
||||
public static final int X_DIM = 28;
|
||||
public static final int y_DIM = 28;
|
||||
public static final int CHANNELS = 1;
|
||||
public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
|
||||
|
||||
static LayerConfiguration[] genLayerConfig() {
|
||||
return new LayerConfiguration[] {
|
||||
/*
|
||||
DenseLayer.builder().name("L-0").nIn(INPUT).nOut(INPUT + (INPUT / 2)).activation(Activation.RELU).build(),
|
||||
ActivationLayer.builder().activation(Activation.RELU).build(), /*
|
||||
Deconvolution2D.builder().name("L-Deconv-01").nIn(CHANNELS).nOut(CHANNELS)
|
||||
.kernelSize(2,2)
|
||||
.stride(1,1)
|
||||
.padding(0,0)
|
||||
.convolutionMode(ConvolutionMode.Truncate)
|
||||
.activation(Activation.RELU)
|
||||
.hasBias(BIAS).build(),
|
||||
//BatchNormalization.builder().nOut(CHANNELS).build(),
|
||||
Deconvolution2D.builder().name("L-Deconv-02").nIn(CHANNELS).nOut(CHANNELS)
|
||||
.kernelSize(2,2)
|
||||
.stride(2,2)
|
||||
.padding(0,0)
|
||||
.convolutionMode(ConvolutionMode.Truncate)
|
||||
.activation(Activation.RELU)
|
||||
.hasBias(BIAS).build(),
|
||||
//BatchNormalization.builder().name("L-batch").nOut(CHANNELS).build(),
|
||||
|
||||
|
||||
DenseLayer.builder().name("L-x").nIn(INPUT + (INPUT / 2)).nOut(2 * INPUT).build(),
|
||||
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
|
||||
DenseLayer.builder().name("L-x").nIn(2 * INPUT).nOut(3 * INPUT).build(),
|
||||
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
|
||||
DenseLayer.builder().name("L-x").nIn(3 * INPUT).nOut(2 * INPUT).build(),
|
||||
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
|
||||
// DropoutLayer.builder(0.001).build(),
|
||||
DenseLayer.builder().nIn(2 * INPUT).nOut(INPUT).activation(Activation.TANH).build() */
|
||||
|
||||
dense().nIn(INPUT).nOut(256).weightInit(WeightInit.NORMAL).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
dense().nIn(256).nOut(512).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
dense().nIn(512).nOut(1024).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
dense().nIn(1024).nOut(784).activation(Activation.TANH).build(),
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
static LayerConfiguration[] disLayerConfig() {
|
||||
return new LayerConfiguration[] {/*
|
||||
Convolution2D.builder().nIn(CHANNELS).kernelSize(2,2).padding(1,1).stride(1,1).nOut(CHANNELS)
|
||||
.build(),
|
||||
Convolution2D.builder().nIn(CHANNELS).kernelSize(3,3).padding(1,1).stride(2,2).nOut(CHANNELS)
|
||||
.build(),
|
||||
ActivationLayer.builder().activation(Activation.LEAKYRELU).build(),
|
||||
BatchNormalization.builder().build(),
|
||||
OutputLayer.builder().nOut(1).lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||
.activation(Activation.SIGMOID)
|
||||
.build()
|
||||
|
||||
|
||||
dense().name("L-dense").nIn(INPUT).nOut(INPUT).build(),
|
||||
ActivationLayer.builder().activation(Activation.RELU).build(),
|
||||
DropoutLayer.builder(0.5).build(),
|
||||
|
||||
DenseLayer.builder().nIn(INPUT).nOut(INPUT/2).build(),
|
||||
ActivationLayer.builder().activation(Activation.RELU).build(),
|
||||
DropoutLayer.builder(0.5).build(),
|
||||
|
||||
DenseLayer.builder().nIn(INPUT/2).nOut(INPUT/4).build(),
|
||||
ActivationLayer.builder().activation(Activation.RELU).build(),
|
||||
DropoutLayer.builder(0.5).build(),
|
||||
|
||||
OutputLayer.builder().nIn(INPUT/4).nOut(1).lossFunction(LossFunctions.LossFunction.XENT)
|
||||
.activation(Activation.SIGMOID)
|
||||
.build() */
|
||||
dense().nIn(784).nOut(1024).hasBias(true).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DropoutLayer.builder(1 - 0.5).build(),
|
||||
dense().nIn(1024).nOut(512).hasBias(true).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DropoutLayer.builder(1 - 0.5).build(),
|
||||
dense().nIn(512).nOut(256).hasBias(true).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DropoutLayer.builder(1 - 0.5).build(),
|
||||
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
static NeuralNetConfiguration generator() {
|
||||
NeuralNetConfiguration conf =
|
||||
NeuralNetConfiguration.builder()
|
||||
.name("generator")
|
||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||
.gradientNormalizationThreshold(100)
|
||||
.seed(42)
|
||||
.updater(UPDATER)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
||||
.weightNoise(null)
|
||||
// .weightInitFn(new WeightInitXavier())
|
||||
// .activationFn(new ActivationIdentity())
|
||||
.activation(Activation.IDENTITY)
|
||||
.layersFromArray(App2Config.genLayerConfig())
|
||||
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
|
||||
//.inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||
//.inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
|
||||
|
||||
.build();
|
||||
conf.init();
|
||||
return conf;
|
||||
}
|
||||
|
||||
static NeuralNetConfiguration discriminator() {
|
||||
NeuralNetConfiguration conf =
|
||||
NeuralNetConfiguration.builder()
|
||||
.name("discriminator")
|
||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||
.gradientNormalizationThreshold(100)
|
||||
.seed(42)
|
||||
.updater(UPDATER)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
// .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
||||
.weightNoise(null)
|
||||
// .weightInitFn(new WeightInitXavier())
|
||||
// .activationFn(new ActivationIdentity())
|
||||
.activation(Activation.IDENTITY)
|
||||
.layersFromArray(disLayerConfig())
|
||||
//.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
|
||||
//.dataType(DataType.FLOAT)
|
||||
.build();
|
||||
conf.init();
|
||||
return conf;
|
||||
}
|
||||
}
|
|
@ -43,223 +43,255 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils.impo
|
|||
@Slf4j
|
||||
public class KerasSequentialModel extends KerasModel {
|
||||
|
||||
/**
|
||||
* (Recommended) Builder-pattern constructor for Sequential model.
|
||||
*
|
||||
* @param modelBuilder builder object
|
||||
* @throws IOException I/O exception
|
||||
* @throws InvalidKerasConfigurationException Invalid Keras configuration
|
||||
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
|
||||
*/
|
||||
public KerasSequentialModel(KerasModelBuilder modelBuilder)
|
||||
throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
|
||||
this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(),
|
||||
modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(),
|
||||
modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape());
|
||||
/**
|
||||
* (Recommended) Builder-pattern constructor for Sequential model.
|
||||
*
|
||||
* @param modelBuilder builder object
|
||||
* @throws IOException I/O exception
|
||||
* @throws InvalidKerasConfigurationException Invalid Keras configuration
|
||||
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
|
||||
*/
|
||||
public KerasSequentialModel(KerasModelBuilder modelBuilder)
|
||||
throws UnsupportedKerasConfigurationException,
|
||||
IOException,
|
||||
InvalidKerasConfigurationException {
|
||||
this(
|
||||
modelBuilder.getModelJson(),
|
||||
modelBuilder.getModelYaml(),
|
||||
modelBuilder.getWeightsArchive(),
|
||||
modelBuilder.getWeightsRoot(),
|
||||
modelBuilder.getTrainingJson(),
|
||||
modelBuilder.getTrainingArchive(),
|
||||
modelBuilder.isEnforceTrainingConfig(),
|
||||
modelBuilder.getInputShape());
|
||||
}
|
||||
|
||||
/**
|
||||
* (Not recommended) Constructor for Sequential model from model configuration (JSON or YAML),
|
||||
* training configuration (JSON), weights, and "training mode" boolean indicator. When built in
|
||||
* training mode, certain unsupported configurations (e.g., unknown regularizers) will throw
|
||||
* Exceptions. When enforceTrainingConfig=false, these will generate warnings but will be
|
||||
* otherwise ignored.
|
||||
*
|
||||
* @param modelJson model configuration JSON string
|
||||
* @param modelYaml model configuration YAML string
|
||||
* @param trainingJson training configuration JSON string
|
||||
* @throws IOException I/O exception
|
||||
*/
|
||||
public KerasSequentialModel(
|
||||
String modelJson,
|
||||
String modelYaml,
|
||||
Hdf5Archive weightsArchive,
|
||||
String weightsRoot,
|
||||
String trainingJson,
|
||||
Hdf5Archive trainingArchive,
|
||||
boolean enforceTrainingConfig,
|
||||
int[] inputShape)
|
||||
throws IOException,
|
||||
InvalidKerasConfigurationException,
|
||||
UnsupportedKerasConfigurationException {
|
||||
|
||||
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
|
||||
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
|
||||
this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
|
||||
this.enforceTrainingConfig = enforceTrainingConfig;
|
||||
|
||||
/* Determine model configuration type. */
|
||||
if (!modelConfig.containsKey(config.getFieldClassName()))
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"Could not determine Keras model class (no "
|
||||
+ config.getFieldClassName()
|
||||
+ " field found)");
|
||||
this.className = (String) modelConfig.get(config.getFieldClassName());
|
||||
if (!this.className.equals(config.getFieldClassNameSequential()))
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"Model class name must be "
|
||||
+ config.getFieldClassNameSequential()
|
||||
+ " (found "
|
||||
+ this.className
|
||||
+ ")");
|
||||
|
||||
/* Process layer configurations. */
|
||||
if (!modelConfig.containsKey(config.getModelFieldConfig()))
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"Could not find layer configurations (no "
|
||||
+ config.getModelFieldConfig()
|
||||
+ " field found)");
|
||||
|
||||
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations.
|
||||
// For consistency
|
||||
// "config" is now an object containing a "name" and "layers", the latter contain the same data
|
||||
// as before.
|
||||
// This change only affects Sequential models.
|
||||
List<Object> layerList;
|
||||
try {
|
||||
layerList = (List<Object>) modelConfig.get(config.getModelFieldConfig());
|
||||
} catch (Exception e) {
|
||||
HashMap layerMap = (HashMap<String, Object>) modelConfig.get(config.getModelFieldConfig());
|
||||
layerList = (List<Object>) layerMap.get("layers");
|
||||
}
|
||||
|
||||
/**
|
||||
* (Not recommended) Constructor for Sequential model from model configuration
|
||||
* (JSON or YAML), training configuration (JSON), weights, and "training mode"
|
||||
* boolean indicator. When built in training mode, certain unsupported configurations
|
||||
* (e.g., unknown regularizers) will throw Exceptions. When enforceTrainingConfig=false, these
|
||||
* will generate warnings but will be otherwise ignored.
|
||||
*
|
||||
* @param modelJson model configuration JSON string
|
||||
* @param modelYaml model configuration YAML string
|
||||
* @param trainingJson training configuration JSON string
|
||||
* @throws IOException I/O exception
|
||||
*/
|
||||
public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
|
||||
String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig,
|
||||
int[] inputShape)
|
||||
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = prepareLayers(layerList);
|
||||
this.layers = layerPair.getFirst();
|
||||
this.layersOrdered = layerPair.getSecond();
|
||||
|
||||
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
|
||||
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
|
||||
this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
|
||||
this.enforceTrainingConfig = enforceTrainingConfig;
|
||||
KerasLayer inputLayer;
|
||||
if (this.layersOrdered.get(0) instanceof KerasInput) {
|
||||
inputLayer = this.layersOrdered.get(0);
|
||||
} else {
|
||||
/* Add placeholder input layer and update lists of input and output layers. */
|
||||
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
|
||||
Preconditions.checkState(
|
||||
ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!");
|
||||
inputLayer = new KerasInput("input1", firstLayerInputShape);
|
||||
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
|
||||
this.layers.put(inputLayer.getName(), inputLayer);
|
||||
this.layersOrdered.add(0, inputLayer);
|
||||
}
|
||||
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
|
||||
this.outputLayerNames =
|
||||
new ArrayList<>(
|
||||
Collections.singletonList(
|
||||
this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
|
||||
|
||||
/* Determine model configuration type. */
|
||||
if (!modelConfig.containsKey(config.getFieldClassName()))
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"Could not determine Keras model class (no " + config.getFieldClassName() + " field found)");
|
||||
this.className = (String) modelConfig.get(config.getFieldClassName());
|
||||
if (!this.className.equals(config.getFieldClassNameSequential()))
|
||||
throw new InvalidKerasConfigurationException("Model class name must be " + config.getFieldClassNameSequential()
|
||||
+ " (found " + this.className + ")");
|
||||
|
||||
/* Process layer configurations. */
|
||||
if (!modelConfig.containsKey(config.getModelFieldConfig()))
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"Could not find layer configurations (no " + config.getModelFieldConfig() + " field found)");
|
||||
|
||||
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations. For consistency
|
||||
// "config" is now an object containing a "name" and "layers", the latter contain the same data as before.
|
||||
// This change only affects Sequential models.
|
||||
List<Object> layerList;
|
||||
try {
|
||||
layerList = (List<Object>) modelConfig.get(config.getModelFieldConfig());
|
||||
} catch (Exception e) {
|
||||
HashMap layerMap = (HashMap<String, Object>) modelConfig.get(config.getModelFieldConfig());
|
||||
layerList = (List<Object>) layerMap.get("layers");
|
||||
}
|
||||
|
||||
Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair =
|
||||
prepareLayers(layerList);
|
||||
this.layers = layerPair.getFirst();
|
||||
this.layersOrdered = layerPair.getSecond();
|
||||
|
||||
KerasLayer inputLayer;
|
||||
if (this.layersOrdered.get(0) instanceof KerasInput) {
|
||||
inputLayer = this.layersOrdered.get(0);
|
||||
} else {
|
||||
/* Add placeholder input layer and update lists of input and output layers. */
|
||||
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
|
||||
Preconditions.checkState(ArrayUtil.prod(firstLayerInputShape) > 0,"Input shape must not be zero!");
|
||||
inputLayer = new KerasInput("input1", firstLayerInputShape);
|
||||
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
|
||||
this.layers.put(inputLayer.getName(), inputLayer);
|
||||
this.layersOrdered.add(0, inputLayer);
|
||||
}
|
||||
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
|
||||
this.outputLayerNames = new ArrayList<>(
|
||||
Collections.singletonList(this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
|
||||
|
||||
/* Update each layer's inbound layer list to include (only) previous layer. */
|
||||
KerasLayer prevLayer = null;
|
||||
for (KerasLayer layer : this.layersOrdered) {
|
||||
if (prevLayer != null)
|
||||
layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName()));
|
||||
prevLayer = layer;
|
||||
}
|
||||
|
||||
/* Import training configuration. */
|
||||
if (enforceTrainingConfig) {
|
||||
if (trainingJson != null)
|
||||
importTrainingConfiguration(trainingJson);
|
||||
else log.warn("If enforceTrainingConfig is true, a training " +
|
||||
"configuration object has to be provided. Usually the only practical way to do this is to store" +
|
||||
" your keras model with `model.save('model_path.h5'. If you store model config and weights" +
|
||||
" separately no training configuration is attached.");
|
||||
}
|
||||
|
||||
this.outputTypes = inferOutputTypes(inputShape);
|
||||
|
||||
if (weightsArchive != null)
|
||||
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
|
||||
/* Update each layer's inbound layer list to include (only) previous layer. */
|
||||
KerasLayer prevLayer = null;
|
||||
for (KerasLayer layer : this.layersOrdered) {
|
||||
if (prevLayer != null)
|
||||
layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName()));
|
||||
prevLayer = layer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Default constructor
|
||||
*/
|
||||
public KerasSequentialModel() {
|
||||
super();
|
||||
/* Import training configuration. */
|
||||
if (enforceTrainingConfig) {
|
||||
if (trainingJson != null) importTrainingConfiguration(trainingJson);
|
||||
else
|
||||
log.warn(
|
||||
"If enforceTrainingConfig is true, a training "
|
||||
+ "configuration object has to be provided. Usually the only practical way to do this is to store"
|
||||
+ " your keras model with `model.save('model_path.h5'. If you store model config and weights"
|
||||
+ " separately no training configuration is attached.");
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure a NeuralNetConfiguration from this Keras Sequential model configuration.
|
||||
*
|
||||
* @return NeuralNetConfiguration
|
||||
*/
|
||||
public NeuralNetConfiguration getNeuralNetConfiguration()
|
||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
if (!this.className.equals(config.getFieldClassNameSequential()))
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
|
||||
if (this.inputLayerNames.size() != 1)
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
|
||||
if (this.outputLayerNames.size() != 1)
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
|
||||
this.outputTypes = inferOutputTypes(inputShape);
|
||||
|
||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder();
|
||||
if (weightsArchive != null)
|
||||
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
|
||||
}
|
||||
|
||||
if (optimizer != null) {
|
||||
modelBuilder.updater(optimizer);
|
||||
/** Default constructor */
|
||||
public KerasSequentialModel() {
|
||||
super();
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure a NeuralNetConfiguration from this Keras Sequential model configuration.
|
||||
*
|
||||
* @return NeuralNetConfiguration
|
||||
*/
|
||||
public NeuralNetConfiguration getNeuralNetConfiguration()
|
||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
if (!this.className.equals(config.getFieldClassNameSequential()))
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
|
||||
if (this.inputLayerNames.size() != 1)
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
|
||||
if (this.outputLayerNames.size() != 1)
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
|
||||
|
||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder =
|
||||
NeuralNetConfiguration.builder();
|
||||
|
||||
if (optimizer != null) {
|
||||
modelBuilder.updater(optimizer);
|
||||
}
|
||||
|
||||
// don't forcibly override for keras import
|
||||
modelBuilder.overrideNinUponBuild(false);
|
||||
/* Add layers one at a time. */
|
||||
KerasLayer prevLayer = null;
|
||||
int layerIndex = 0;
|
||||
for (KerasLayer layer : this.layersOrdered) {
|
||||
if (layer.isLayer()) {
|
||||
int nbInbound = layer.getInboundLayerNames().size();
|
||||
if (nbInbound != 1)
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
|
||||
+ nbInbound
|
||||
+ " for layer "
|
||||
+ layer.getName()
|
||||
+ ")");
|
||||
if (prevLayer != null) {
|
||||
InputType[] inputTypes = new InputType[1];
|
||||
InputPreProcessor preprocessor;
|
||||
if (prevLayer.isInputPreProcessor()) {
|
||||
inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
|
||||
preprocessor = prevLayer.getInputPreprocessor(inputTypes);
|
||||
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
||||
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
|
||||
} else {
|
||||
inputTypes[0] = this.outputTypes.get(prevLayer.getName());
|
||||
preprocessor = layer.getInputPreprocessor(inputTypes);
|
||||
if (preprocessor != null) {
|
||||
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
||||
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
|
||||
} else layer.getLayer().setNIn(inputTypes[0], modelBuilder.isOverrideNinUponBuild());
|
||||
}
|
||||
if (preprocessor != null) {
|
||||
|
||||
Map<Integer, InputPreProcessor> map = new HashMap<>();
|
||||
map.put(layerIndex, preprocessor);
|
||||
modelBuilder.inputPreProcessors(map);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//don't forcibly override for keras import
|
||||
modelBuilder.overrideNinUponBuild(false);
|
||||
/* Add layers one at a time. */
|
||||
KerasLayer prevLayer = null;
|
||||
int layerIndex = 0;
|
||||
for (KerasLayer layer : this.layersOrdered) {
|
||||
if (layer.isLayer()) {
|
||||
int nbInbound = layer.getInboundLayerNames().size();
|
||||
if (nbInbound != 1)
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
|
||||
+ nbInbound + " for layer " + layer.getName() + ")");
|
||||
if (prevLayer != null) {
|
||||
InputType[] inputTypes = new InputType[1];
|
||||
InputPreProcessor preprocessor;
|
||||
if (prevLayer.isInputPreProcessor()) {
|
||||
inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
|
||||
preprocessor = prevLayer.getInputPreprocessor(inputTypes);
|
||||
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
||||
layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
|
||||
} else {
|
||||
inputTypes[0] = this.outputTypes.get(prevLayer.getName());
|
||||
preprocessor = layer.getInputPreprocessor(inputTypes);
|
||||
if(preprocessor != null) {
|
||||
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
||||
layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
|
||||
}
|
||||
else
|
||||
layer.getLayer().setNIn(inputTypes[0],modelBuilder.isOverrideNinUponBuild());
|
||||
|
||||
}
|
||||
if (preprocessor != null)
|
||||
modelBuilder.inputPreProcessor(layerIndex, preprocessor);
|
||||
|
||||
|
||||
}
|
||||
|
||||
modelBuilder.layer(layerIndex++, layer.getLayer());
|
||||
} else if (layer.getVertex() != null)
|
||||
throw new InvalidKerasConfigurationException("Cannot add vertex to NeuralNetConfiguration (class name "
|
||||
+ layer.getClassName() + ", layer name " + layer.getName() + ")");
|
||||
prevLayer = layer;
|
||||
}
|
||||
|
||||
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */
|
||||
if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
|
||||
modelBuilder.backpropType(BackpropType.TruncatedBPTT)
|
||||
.tbpttFwdLength(truncatedBPTT)
|
||||
.tbpttBackLength(truncatedBPTT);
|
||||
else
|
||||
modelBuilder.backpropType(BackpropType.Standard);
|
||||
|
||||
NeuralNetConfiguration build = modelBuilder.build();
|
||||
|
||||
|
||||
return build;
|
||||
modelBuilder.layer(layerIndex++, layer.getLayer());
|
||||
} else if (layer.getVertex() != null)
|
||||
throw new InvalidKerasConfigurationException(
|
||||
"Cannot add vertex to NeuralNetConfiguration (class name "
|
||||
+ layer.getClassName()
|
||||
+ ", layer name "
|
||||
+ layer.getName()
|
||||
+ ")");
|
||||
prevLayer = layer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a MultiLayerNetwork from this Keras Sequential model configuration.
|
||||
*
|
||||
* @return MultiLayerNetwork
|
||||
*/
|
||||
public MultiLayerNetwork getMultiLayerNetwork()
|
||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
return getMultiLayerNetwork(true);
|
||||
}
|
||||
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */
|
||||
if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
|
||||
modelBuilder
|
||||
.backpropType(BackpropType.TruncatedBPTT)
|
||||
.tbpttFwdLength(truncatedBPTT)
|
||||
.tbpttBackLength(truncatedBPTT);
|
||||
else modelBuilder.backpropType(BackpropType.Standard);
|
||||
|
||||
/**
|
||||
* Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights.
|
||||
*
|
||||
* @return MultiLayerNetwork
|
||||
*/
|
||||
public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
|
||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration());
|
||||
model.init();
|
||||
if (importWeights)
|
||||
model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers);
|
||||
return model;
|
||||
}
|
||||
NeuralNetConfiguration build = modelBuilder.build();
|
||||
|
||||
return build;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a MultiLayerNetwork from this Keras Sequential model configuration.
|
||||
*
|
||||
* @return MultiLayerNetwork
|
||||
*/
|
||||
public MultiLayerNetwork getMultiLayerNetwork()
|
||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
return getMultiLayerNetwork(true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights.
|
||||
*
|
||||
* @return MultiLayerNetwork
|
||||
*/
|
||||
public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
|
||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration());
|
||||
model.init();
|
||||
if (importWeights)
|
||||
model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers);
|
||||
return model;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ package net.brutex.ai.dnn.api;
|
|||
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
|
||||
/**
|
||||
* A fluent API to configure and create artificial neural networks
|
||||
|
@ -30,9 +31,11 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationB
|
|||
public class NN {
|
||||
|
||||
|
||||
public static NeuralNetConfigurationBuilder<?, ?> net() {
|
||||
public static NeuralNetConfigurationBuilder<?, ?> nn() {
|
||||
return NeuralNetConfiguration.builder();
|
||||
}
|
||||
|
||||
public static DenseLayer.DenseLayerBuilder<?,?> dense() { return DenseLayer.builder(); }
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -152,7 +152,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
|||
@Getter @Setter @NonNull @lombok.Builder.Default
|
||||
protected BackpropType backpropType = BackpropType.Standard;
|
||||
|
||||
@Getter @lombok.Builder.Default
|
||||
@Getter @Setter @Singular
|
||||
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
|
||||
/**
|
||||
* When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated)
|
||||
|
@ -524,12 +524,11 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
|||
* @param processor what to use to preProcess the data.
|
||||
* @return builder pattern
|
||||
*/
|
||||
public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
|
||||
if(inputPreProcessors$value==null) inputPreProcessors$value=new LinkedHashMap<>();
|
||||
inputPreProcessors$value.put(layer, processor);
|
||||
inputPreProcessors$set = true;
|
||||
return self();
|
||||
}
|
||||
//public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
|
||||
// inputPreProcessors$value.put(layer, processor);
|
||||
// inputPreProcessors$set = true;
|
||||
// return self();
|
||||
// }
|
||||
|
||||
/**
|
||||
* Set layer at index
|
||||
|
|
|
@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
|||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.*;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
|
@ -317,6 +318,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
|||
@NonNull
|
||||
InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType);
|
||||
if (inputPreProcessor != null) {
|
||||
inputPreProcessors = new HashMap<>(inputPreProcessors);
|
||||
inputPreProcessors.put(i, inputPreProcessor);
|
||||
}
|
||||
}
|
||||
|
@ -538,6 +540,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
|||
obj.getClass().getSimpleName());
|
||||
}
|
||||
});
|
||||
// make sure the indexes are sequenced properly
|
||||
AtomicInteger i = new AtomicInteger();
|
||||
ret.forEach(obj -> {
|
||||
obj.setIndex(i.getAndIncrement());
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -219,7 +219,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
|||
throw new IllegalStateException(
|
||||
"Invalid input for Convolution layer (layer name=\""
|
||||
+ getName()
|
||||
+ "\"): Expected CNN input, got "
|
||||
+ "\" at index '"+getIndex()+"') : Expected CNN input, got "
|
||||
+ inputType);
|
||||
}
|
||||
|
||||
|
@ -372,7 +372,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
|||
* @param kernelSize kernel size
|
||||
*/
|
||||
public B kernelSize(int... kernelSize) {
|
||||
this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize");
|
||||
//this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize");
|
||||
this.kernelSize$value = kernelSize;
|
||||
this.kernelSize$set = true;
|
||||
return self();
|
||||
}
|
||||
|
@ -383,7 +384,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
|||
* @param stride kernel size
|
||||
*/
|
||||
public B stride(int... stride) {
|
||||
this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride");
|
||||
//this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride");
|
||||
this.stride$value = stride;
|
||||
this.stride$set = true;
|
||||
return self();
|
||||
}
|
||||
|
@ -394,7 +396,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
|||
* @param padding kernel size
|
||||
*/
|
||||
public B padding(int... padding) {
|
||||
this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding");
|
||||
//this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding");
|
||||
this.padding$value = padding;
|
||||
this.padding$set = true;
|
||||
return self();
|
||||
}
|
||||
|
@ -404,7 +407,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
|||
* @param dilation kernel size
|
||||
*/
|
||||
public B dilation(int... dilation) {
|
||||
this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation");
|
||||
//this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation");
|
||||
this.dilation$value = dilation;
|
||||
this.dilation$set = true;
|
||||
return self();
|
||||
}
|
||||
|
|
|
@ -20,14 +20,19 @@
|
|||
|
||||
package org.deeplearning4j.nn.conf.layers;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Map;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import lombok.extern.jackson.Jacksonized;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer;
|
||||
|
@ -84,6 +89,8 @@ public class Deconvolution2D extends ConvolutionLayer {
|
|||
boolean initializeParams,
|
||||
DataType networkDataType) {
|
||||
setNetConfiguration(conf);
|
||||
|
||||
|
||||
LayerValidation.assertNInNOutSet("Deconvolution2D", getName(), layerIndex, getNIn(), getNOut());
|
||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||
runInheritance();
|
||||
|
@ -127,11 +134,25 @@ public class Deconvolution2D extends ConvolutionLayer {
|
|||
getName(),
|
||||
Deconvolution2DLayer.class);
|
||||
}
|
||||
|
||||
@Slf4j
|
||||
private static final class Deconvolution2DBuilderImpl
|
||||
extends Deconvolution2DBuilder<Deconvolution2D, Deconvolution2DBuilderImpl> {
|
||||
public Deconvolution2D build() {
|
||||
Deconvolution2D l = new Deconvolution2D(this);
|
||||
if( l.getConvolutionMode() == ConvolutionMode.Same
|
||||
&& IntStream.of(l.getPadding()).sum() != 0) {
|
||||
log.warn("Invalid input for layer '{}'. "
|
||||
+ "You cannot have a padding of {} when Convolution Mode is set to 'Same'."
|
||||
+ " Padding will be ignored."
|
||||
, l.getName(), l.getPadding());
|
||||
}
|
||||
/* strides * (input_size-1) + kernel_size - 2*padding */
|
||||
//TODO: This is wrong, also depends on convolutionMode, etc ...
|
||||
/*l.nOut = l.getStride()[0] * (l.getNIn()-1)
|
||||
+ IntStream.of(l.getKernelSize()).reduce(1, (a,b) -> a*b)
|
||||
- 2L * IntStream.of(l.getPadding()).sum();
|
||||
*/
|
||||
//l.nOut =264;
|
||||
l.initializeConstraints();
|
||||
return l;
|
||||
}
|
||||
|
|
|
@ -62,6 +62,7 @@ public abstract class LayerConfiguration
|
|||
implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration
|
||||
|
||||
@Getter @Setter protected String name;
|
||||
@Getter @Setter private int index;
|
||||
@Getter @Setter protected List<LayerConstraint> allParamConstraints;
|
||||
@Getter @Setter protected List<LayerConstraint> weightConstraints;
|
||||
@Getter @Setter protected List<LayerConstraint> biasConstraints;
|
||||
|
@ -72,6 +73,7 @@ public abstract class LayerConfiguration
|
|||
/** The type of the layer, basically defines the base class and its properties */
|
||||
@Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
|
||||
|
||||
|
||||
/**
|
||||
* Number of parameters this layer has a result of its configuration
|
||||
* @return number or parameters
|
||||
|
@ -80,7 +82,6 @@ public abstract class LayerConfiguration
|
|||
return initializer().numParams(this);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A reference to the neural net configuration. This field is excluded from json serialization as
|
||||
* well as from equals check to avoid circular referenced.
|
||||
|
|
|
@ -37,122 +37,166 @@ import org.nd4j.linalg.api.shape.Shape;
|
|||
@Data
|
||||
@EqualsAndHashCode(exclude = {"shape"})
|
||||
public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
|
||||
private long inputHeight;
|
||||
private long inputWidth;
|
||||
private long numChannels;
|
||||
private long inputHeight;
|
||||
private long inputWidth;
|
||||
private long numChannels;
|
||||
|
||||
@Getter(AccessLevel.NONE)
|
||||
@Setter(AccessLevel.NONE)
|
||||
private long[] shape;
|
||||
@Getter(AccessLevel.NONE)
|
||||
@Setter(AccessLevel.NONE)
|
||||
private long[] shape;
|
||||
|
||||
/**
|
||||
* Reshape to a channels x rows x columns tensor
|
||||
*
|
||||
* @param inputHeight the columns
|
||||
* @param inputWidth the rows
|
||||
* @param numChannels the channels
|
||||
*/
|
||||
@JsonCreator
|
||||
public FeedForwardToCnnPreProcessor(@JsonProperty("inputHeight") long inputHeight,
|
||||
@JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) {
|
||||
this.inputHeight = inputHeight;
|
||||
this.inputWidth = inputWidth;
|
||||
this.numChannels = numChannels;
|
||||
/**
|
||||
* Reshape to a channels x rows x columns tensor
|
||||
*
|
||||
* @param inputHeight the columns
|
||||
* @param inputWidth the rows
|
||||
* @param numChannels the channels
|
||||
*/
|
||||
@JsonCreator
|
||||
public FeedForwardToCnnPreProcessor(
|
||||
@JsonProperty("inputHeight") long inputHeight,
|
||||
@JsonProperty("inputWidth") long inputWidth,
|
||||
@JsonProperty("numChannels") long numChannels) {
|
||||
this.inputHeight = inputHeight;
|
||||
this.inputWidth = inputWidth;
|
||||
this.numChannels = numChannels;
|
||||
}
|
||||
/**
|
||||
* Reshape to a channels x rows x columns tensor
|
||||
*
|
||||
* @param inputHeight the columns
|
||||
* @param inputWidth the rows
|
||||
*/
|
||||
public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) {
|
||||
this.inputHeight = inputHeight;
|
||||
this.inputWidth = inputWidth;
|
||||
this.numChannels = 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||
this.shape = input.shape();
|
||||
if (input.rank() == 4) return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
|
||||
|
||||
if (input.columns() != inputWidth * inputHeight * numChannels)
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid input: expect output columns must be equal to rows "
|
||||
+ inputHeight
|
||||
+ " x columns "
|
||||
+ inputWidth
|
||||
+ " x channels "
|
||||
+ numChannels
|
||||
+ " but was instead "
|
||||
+ Arrays.toString(input.shape()));
|
||||
|
||||
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
|
||||
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
|
||||
|
||||
return workspaceMgr.leverageTo(
|
||||
ArrayType.ACTIVATIONS,
|
||||
input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth));
|
||||
}
|
||||
|
||||
@Override
|
||||
// return 4 dimensions
|
||||
public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||
if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilons))
|
||||
epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c');
|
||||
|
||||
if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) {
|
||||
if (epsilons.rank() == 2)
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons); // should never happen
|
||||
|
||||
return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
|
||||
}
|
||||
|
||||
public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) {
|
||||
this.inputHeight = inputHeight;
|
||||
this.inputWidth = inputWidth;
|
||||
this.numChannels = 1;
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape));
|
||||
}
|
||||
|
||||
@Override
|
||||
public FeedForwardToCnnPreProcessor clone() {
|
||||
try {
|
||||
FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
|
||||
if (clone.shape != null) clone.shape = clone.shape.clone();
|
||||
return clone;
|
||||
} catch (CloneNotSupportedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||
this.shape = input.shape();
|
||||
if (input.rank() == 4)
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
|
||||
@Override
|
||||
public InputType getOutputType(InputType inputType) {
|
||||
|
||||
if (input.columns() != inputWidth * inputHeight * numChannels)
|
||||
throw new IllegalArgumentException("Invalid input: expect output columns must be equal to rows "
|
||||
+ inputHeight + " x columns " + inputWidth + " x channels " + numChannels
|
||||
+ " but was instead " + Arrays.toString(input.shape()));
|
||||
|
||||
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
|
||||
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
|
||||
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS,
|
||||
input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth));
|
||||
}
|
||||
|
||||
@Override
|
||||
// return 4 dimensions
|
||||
public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||
if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilons))
|
||||
epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c');
|
||||
|
||||
if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) {
|
||||
if (epsilons.rank() == 2)
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons); //should never happen
|
||||
|
||||
return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
|
||||
switch (inputType.getType()) {
|
||||
case FF:
|
||||
InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
|
||||
val expSize = inputHeight * inputWidth * numChannels;
|
||||
if (c.getSize() != expSize) {
|
||||
throw new IllegalStateException(
|
||||
"Invalid input: expected FeedForward input of size "
|
||||
+ expSize
|
||||
+ " = (d="
|
||||
+ numChannels
|
||||
+ " * w="
|
||||
+ inputWidth
|
||||
+ " * h="
|
||||
+ inputHeight
|
||||
+ "), got "
|
||||
+ inputType);
|
||||
}
|
||||
return InputType.convolutional(inputHeight, inputWidth, numChannels);
|
||||
case CNN:
|
||||
InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;
|
||||
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape));
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public FeedForwardToCnnPreProcessor clone() {
|
||||
try {
|
||||
FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
|
||||
if (clone.shape != null)
|
||||
clone.shape = clone.shape.clone();
|
||||
return clone;
|
||||
} catch (CloneNotSupportedException e) {
|
||||
throw new RuntimeException(e);
|
||||
if (c2.getChannels() != numChannels
|
||||
|| c2.getHeight() != inputHeight
|
||||
|| c2.getWidth() != inputWidth) {
|
||||
throw new IllegalStateException(
|
||||
"Invalid input: Got CNN input type with (d,w,h)=("
|
||||
+ c2.getChannels()
|
||||
+ ","
|
||||
+ c2.getWidth()
|
||||
+ ","
|
||||
+ c2.getHeight()
|
||||
+ ") but expected ("
|
||||
+ numChannels
|
||||
+ ","
|
||||
+ inputHeight
|
||||
+ ","
|
||||
+ inputWidth
|
||||
+ ")");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(InputType inputType) {
|
||||
|
||||
switch (inputType.getType()) {
|
||||
case FF:
|
||||
InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
|
||||
val expSize = inputHeight * inputWidth * numChannels;
|
||||
if (c.getSize() != expSize) {
|
||||
throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize
|
||||
+ " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got "
|
||||
+ inputType);
|
||||
}
|
||||
return InputType.convolutional(inputHeight, inputWidth, numChannels);
|
||||
case CNN:
|
||||
InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;
|
||||
|
||||
if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
|
||||
throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getChannels()
|
||||
+ "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels
|
||||
+ "," + inputHeight + "," + inputWidth + ")");
|
||||
}
|
||||
return c2;
|
||||
case CNNFlat:
|
||||
InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
|
||||
if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
|
||||
throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c3.getDepth()
|
||||
+ "," + c3.getWidth() + "," + c3.getHeight() + ") but expected (" + numChannels
|
||||
+ "," + inputHeight + "," + inputWidth + ")");
|
||||
}
|
||||
return c3.getUnflattenedType();
|
||||
default:
|
||||
throw new IllegalStateException("Invalid input type: got " + inputType);
|
||||
return c2;
|
||||
case CNNFlat:
|
||||
InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
|
||||
if (c3.getDepth() != numChannels
|
||||
|| c3.getHeight() != inputHeight
|
||||
|| c3.getWidth() != inputWidth) {
|
||||
throw new IllegalStateException(
|
||||
"Invalid input: Got CNN input type with (d,w,h)=("
|
||||
+ c3.getDepth()
|
||||
+ ","
|
||||
+ c3.getWidth()
|
||||
+ ","
|
||||
+ c3.getHeight()
|
||||
+ ") but expected ("
|
||||
+ numChannels
|
||||
+ ","
|
||||
+ inputHeight
|
||||
+ ","
|
||||
+ inputWidth
|
||||
+ ")");
|
||||
}
|
||||
return c3.getUnflattenedType();
|
||||
default:
|
||||
throw new IllegalStateException("Invalid input type: got " + inputType);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
|
||||
int minibatchSize) {
|
||||
//Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example)
|
||||
return new Pair<>(maskArray, currentMaskState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pair<INDArray, MaskState> feedForwardMaskArray(
|
||||
INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
|
||||
// Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example)
|
||||
return new Pair<>(maskArray, currentMaskState);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -369,7 +369,7 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl
|
|||
|
||||
protected String layerId() {
|
||||
String name = this.layerConfiguration.getName();
|
||||
return "(layer name: "
|
||||
return "(network: " + getNetConfiguration().getName() + " layer name: "
|
||||
+ (name == null ? "\"\"" : name)
|
||||
+ ", layer index: "
|
||||
+ index
|
||||
|
|
|
@ -101,8 +101,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
|
|||
|
||||
int[] args = new int[] {
|
||||
(int)kH, (int)kW, strides[0], strides[1],
|
||||
pad[0], pad[1], dilation[0], dilation[1], sameMode,
|
||||
nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
|
||||
pad[0], pad[1], dilation[0], dilation[1], sameMode //,
|
||||
//nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
|
||||
};
|
||||
|
||||
INDArray delta;
|
||||
|
@ -224,8 +224,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
|
|||
|
||||
int[] args = new int[] {
|
||||
kH, kW, strides[0], strides[1],
|
||||
pad[0], pad[1], dilation[0], dilation[1], sameMode,
|
||||
nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
|
||||
pad[0], pad[1], dilation[0], dilation[1], sameMode //,
|
||||
//nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
|
||||
};
|
||||
|
||||
//DL4J Deconv weights: [inputDepth, outputDepth, kH, kW]
|
||||
|
@ -238,6 +238,20 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
|
|||
} else {
|
||||
opInputs = new INDArray[]{input, weights};
|
||||
}
|
||||
/**
|
||||
* 2D deconvolution implementation
|
||||
*
|
||||
* IntArgs:
|
||||
* 0: kernel height
|
||||
* 1: kernel width
|
||||
* 2: stride height
|
||||
* 3: stride width
|
||||
* 4: padding height
|
||||
* 5: padding width
|
||||
* 6: dilation height
|
||||
* 7: dilation width
|
||||
* 8: same mode: 0 false, 1 true
|
||||
*/
|
||||
CustomOp op = DynamicCustomOp.builder("deconv2d")
|
||||
.addInputs(opInputs)
|
||||
.addIntegerArguments(args)
|
||||
|
|
|
@ -773,7 +773,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork
|
|||
LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i);
|
||||
layers[i] =
|
||||
lc.instantiate(
|
||||
lc.getNetConfiguration(),
|
||||
this.getNetConfiguration(),
|
||||
trainingListeners,
|
||||
i,
|
||||
paramsView,
|
||||
|
|
|
@ -101,8 +101,10 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer
|
|||
|
||||
params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
|
||||
conf.getNetConfiguration().addNetWideVariable(GAMMA);
|
||||
conf.addVariable(GAMMA);
|
||||
params.put(BETA, createBeta(conf, betaView, initializeParams));
|
||||
conf.getNetConfiguration().addNetWideVariable(BETA);
|
||||
conf.addVariable(BETA);
|
||||
|
||||
meanOffset = 2 * nOut;
|
||||
}
|
||||
|
@ -125,12 +127,15 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer
|
|||
|
||||
params.put(GLOBAL_MEAN, globalMeanView);
|
||||
conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN);
|
||||
conf.addVariable(GLOBAL_MEAN);
|
||||
if(layer.isUseLogStd()){
|
||||
params.put(GLOBAL_LOG_STD, globalVarView);
|
||||
conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD);
|
||||
conf.addVariable(GLOBAL_LOG_STD);
|
||||
} else {
|
||||
params.put(GLOBAL_VAR, globalVarView);
|
||||
conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR);
|
||||
conf.addVariable(GLOBAL_VAR);
|
||||
}
|
||||
|
||||
return params;
|
||||
|
|
|
@ -114,11 +114,13 @@ public class ConvolutionParamInitializer extends AbstractParamInitializer {
|
|||
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
|
||||
conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
|
||||
conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
|
||||
conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
|
||||
conf.addVariable(WEIGHT_KEY);
|
||||
conf.addVariable(BIAS_KEY);
|
||||
} else {
|
||||
INDArray weightView = paramsView;
|
||||
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
|
||||
conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
|
||||
conf.addVariable(WEIGHT_KEY);
|
||||
}
|
||||
|
||||
return params;
|
||||
|
|
|
@ -34,7 +34,7 @@ systemProp.org.gradle.internal.publish.checksums.insecure=true
|
|||
|
||||
#for whatever reason we had to add MaxMetaspace and file encoding = utf8, gradle crashed otherwise
|
||||
org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=768m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 -XX:ErrorFile=/var/log/java/hs_err_pid%p.log
|
||||
|
||||
#-DsocksProxyHost=sshtunnel -DsocksProxyPort=8888 -Djava.net.preferIPv4Stack=true
|
||||
|
||||
# When configured, Gradle will run in incubating parallel mode.
|
||||
# This option should only be used with decoupled projects. More details, visit
|
||||
|
|
Binary file not shown.
|
@ -69,18 +69,18 @@ app_path=$0
|
|||
|
||||
# Need this for daisy-chained symlinks.
|
||||
while
|
||||
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
|
||||
[ -h "$app_path" ]
|
||||
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
|
||||
[ -h "$app_path" ]
|
||||
do
|
||||
ls=$(ls -ld "$app_path")
|
||||
link=${ls#*' -> '}
|
||||
case $link in #(
|
||||
/*) app_path=$link ;; #(
|
||||
*) app_path=$APP_HOME$link ;;
|
||||
esac
|
||||
ls=$( ls -ld "$app_path" )
|
||||
link=${ls#*' -> '}
|
||||
case $link in #(
|
||||
/*) app_path=$link ;; #(
|
||||
*) app_path=$APP_HOME$link ;;
|
||||
esac
|
||||
done
|
||||
|
||||
APP_HOME=$(cd "${APP_HOME:-./}" && pwd -P) || exit
|
||||
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
|
||||
|
||||
APP_NAME="Gradle"
|
||||
APP_BASE_NAME=${0##*/}
|
||||
|
@ -91,15 +91,15 @@ DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
|
|||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||
MAX_FD=maximum
|
||||
|
||||
warn() {
|
||||
echo "$*"
|
||||
warn () {
|
||||
echo "$*"
|
||||
} >&2
|
||||
|
||||
die() {
|
||||
echo
|
||||
echo "$*"
|
||||
echo
|
||||
exit 1
|
||||
die () {
|
||||
echo
|
||||
echo "$*"
|
||||
echo
|
||||
exit 1
|
||||
} >&2
|
||||
|
||||
# OS specific support (must be 'true' or 'false').
|
||||
|
@ -107,52 +107,51 @@ cygwin=false
|
|||
msys=false
|
||||
darwin=false
|
||||
nonstop=false
|
||||
case "$(uname)" in #(
|
||||
CYGWIN*) cygwin=true ;; #(
|
||||
Darwin*) darwin=true ;; #(
|
||||
MSYS* | MINGW*) msys=true ;; #(
|
||||
NONSTOP*) nonstop=true ;;
|
||||
case "$( uname )" in #(
|
||||
CYGWIN* ) cygwin=true ;; #(
|
||||
Darwin* ) darwin=true ;; #(
|
||||
MSYS* | MINGW* ) msys=true ;; #(
|
||||
NONSTOP* ) nonstop=true ;;
|
||||
esac
|
||||
|
||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||
|
||||
|
||||
# Determine the Java command to use to start the JVM.
|
||||
if [ -n "$JAVA_HOME" ]; then
|
||||
if [ -x "$JAVA_HOME/jre/sh/java" ]; then
|
||||
# IBM's JDK on AIX uses strange locations for the executables
|
||||
JAVACMD=$JAVA_HOME/jre/sh/java
|
||||
else
|
||||
JAVACMD=$JAVA_HOME/bin/java
|
||||
fi
|
||||
if [ ! -x "$JAVACMD" ]; then
|
||||
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
||||
if [ -n "$JAVA_HOME" ] ; then
|
||||
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
||||
# IBM's JDK on AIX uses strange locations for the executables
|
||||
JAVACMD=$JAVA_HOME/jre/sh/java
|
||||
else
|
||||
JAVACMD=$JAVA_HOME/bin/java
|
||||
fi
|
||||
if [ ! -x "$JAVACMD" ] ; then
|
||||
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
fi
|
||||
else
|
||||
JAVACMD=java
|
||||
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
JAVACMD=java
|
||||
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
|
||||
# Increase the maximum file descriptors if we can.
|
||||
if ! "$cygwin" && ! "$darwin" && ! "$nonstop"; then
|
||||
case $MAX_FD in #(
|
||||
max*)
|
||||
MAX_FD=$(ulimit -H -n) ||
|
||||
warn "Could not query maximum file descriptor limit"
|
||||
;;
|
||||
esac
|
||||
case $MAX_FD in #(
|
||||
'' | soft) : ;; #(
|
||||
*)
|
||||
ulimit -n "$MAX_FD" ||
|
||||
warn "Could not set maximum file descriptor limit to $MAX_FD"
|
||||
;;
|
||||
esac
|
||||
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
|
||||
case $MAX_FD in #(
|
||||
max*)
|
||||
MAX_FD=$( ulimit -H -n ) ||
|
||||
warn "Could not query maximum file descriptor limit"
|
||||
esac
|
||||
case $MAX_FD in #(
|
||||
'' | soft) :;; #(
|
||||
*)
|
||||
ulimit -n "$MAX_FD" ||
|
||||
warn "Could not set maximum file descriptor limit to $MAX_FD"
|
||||
esac
|
||||
fi
|
||||
|
||||
# Collect all arguments for the java command, stacking in reverse order:
|
||||
|
@ -164,36 +163,34 @@ fi
|
|||
# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
|
||||
|
||||
# For Cygwin or MSYS, switch paths to Windows format before running java
|
||||
if "$cygwin" || "$msys"; then
|
||||
APP_HOME=$(cygpath --path --mixed "$APP_HOME")
|
||||
CLASSPATH=$(cygpath --path --mixed "$CLASSPATH")
|
||||
if "$cygwin" || "$msys" ; then
|
||||
APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
|
||||
CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
|
||||
|
||||
JAVACMD=$(cygpath --unix "$JAVACMD")
|
||||
JAVACMD=$( cygpath --unix "$JAVACMD" )
|
||||
|
||||
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||
for arg; do
|
||||
if
|
||||
case $arg in #(
|
||||
-*) false ;; # don't mess with options #(
|
||||
/?*)
|
||||
t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
|
||||
[ -e "$t" ]
|
||||
;; #(
|
||||
*) false ;;
|
||||
esac
|
||||
then
|
||||
arg=$(cygpath --path --ignore --mixed "$arg")
|
||||
fi
|
||||
# Roll the args list around exactly as many times as the number of
|
||||
# args, so each arg winds up back in the position where it started, but
|
||||
# possibly modified.
|
||||
#
|
||||
# NB: a `for` loop captures its iteration list before it begins, so
|
||||
# changing the positional parameters here affects neither the number of
|
||||
# iterations, nor the values presented in `arg`.
|
||||
shift # remove old arg
|
||||
set -- "$@" "$arg" # push replacement arg
|
||||
done
|
||||
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||
for arg do
|
||||
if
|
||||
case $arg in #(
|
||||
-*) false ;; # don't mess with options #(
|
||||
/?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
|
||||
[ -e "$t" ] ;; #(
|
||||
*) false ;;
|
||||
esac
|
||||
then
|
||||
arg=$( cygpath --path --ignore --mixed "$arg" )
|
||||
fi
|
||||
# Roll the args list around exactly as many times as the number of
|
||||
# args, so each arg winds up back in the position where it started, but
|
||||
# possibly modified.
|
||||
#
|
||||
# NB: a `for` loop captures its iteration list before it begins, so
|
||||
# changing the positional parameters here affects neither the number of
|
||||
# iterations, nor the values presented in `arg`.
|
||||
shift # remove old arg
|
||||
set -- "$@" "$arg" # push replacement arg
|
||||
done
|
||||
fi
|
||||
|
||||
# Collect all arguments for the java command;
|
||||
|
@ -203,15 +200,10 @@ fi
|
|||
# * put everything else in single quotes, so that it's not re-expanded.
|
||||
|
||||
set -- \
|
||||
"-Dorg.gradle.appname=$APP_BASE_NAME" \
|
||||
-classpath "$CLASSPATH" \
|
||||
org.gradle.wrapper.GradleWrapperMain \
|
||||
"$@"
|
||||
|
||||
# Stop when "xargs" is not available.
|
||||
if ! command -v xargs >/dev/null 2>&1; then
|
||||
die "xargs is not available"
|
||||
fi
|
||||
"-Dorg.gradle.appname=$APP_BASE_NAME" \
|
||||
-classpath "$CLASSPATH" \
|
||||
org.gradle.wrapper.GradleWrapperMain \
|
||||
"$@"
|
||||
|
||||
# Use "xargs" to parse quoted args.
|
||||
#
|
||||
|
@ -233,10 +225,10 @@ fi
|
|||
#
|
||||
|
||||
eval "set -- $(
|
||||
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
|
||||
xargs -n1 |
|
||||
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
|
||||
tr '\n' ' '
|
||||
)" '"$@"'
|
||||
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
|
||||
xargs -n1 |
|
||||
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
|
||||
tr '\n' ' '
|
||||
)" '"$@"'
|
||||
|
||||
exec "$JAVACMD" "$@"
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
@rem limitations under the License.
|
||||
@rem
|
||||
|
||||
@if "%DEBUG%"=="" @echo off
|
||||
@if "%DEBUG%" == "" @echo off
|
||||
@rem ##########################################################################
|
||||
@rem
|
||||
@rem Gradle startup script for Windows
|
||||
|
@ -25,7 +25,7 @@
|
|||
if "%OS%"=="Windows_NT" setlocal
|
||||
|
||||
set DIRNAME=%~dp0
|
||||
if "%DIRNAME%"=="" set DIRNAME=.
|
||||
if "%DIRNAME%" == "" set DIRNAME=.
|
||||
set APP_BASE_NAME=%~n0
|
||||
set APP_HOME=%DIRNAME%
|
||||
|
||||
|
@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
|
|||
|
||||
set JAVA_EXE=java.exe
|
||||
%JAVA_EXE% -version >NUL 2>&1
|
||||
if %ERRORLEVEL% equ 0 goto execute
|
||||
if "%ERRORLEVEL%" == "0" goto execute
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
|
@ -75,15 +75,13 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
|||
|
||||
:end
|
||||
@rem End local scope for the variables with windows NT shell
|
||||
if %ERRORLEVEL% equ 0 goto mainEnd
|
||||
if "%ERRORLEVEL%"=="0" goto mainEnd
|
||||
|
||||
:fail
|
||||
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
||||
rem the _cmd.exe /c_ return code!
|
||||
set EXIT_CODE=%ERRORLEVEL%
|
||||
if %EXIT_CODE% equ 0 set EXIT_CODE=1
|
||||
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
|
||||
exit /b %EXIT_CODE%
|
||||
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
|
||||
exit /b 1
|
||||
|
||||
:mainEnd
|
||||
if "%OS%"=="Windows_NT" endlocal
|
||||
|
|
Loading…
Reference in New Issue