parent
3ea555b645
commit
1c3496ad84
|
@ -1,473 +1,261 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
package net.brutex.gan;
|
package net.brutex.gan;
|
||||||
|
|
||||||
|
import static net.brutex.ai.dnn.api.NN.dense;
|
||||||
|
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
import java.awt.image.BufferedImage;
|
import java.awt.image.BufferedImage;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Random;
|
import javax.swing.*;
|
||||||
import java.util.UUID;
|
|
||||||
import javax.imageio.ImageIO;
|
|
||||||
import javax.swing.ImageIcon;
|
|
||||||
import javax.swing.JFrame;
|
|
||||||
import javax.swing.JLabel;
|
|
||||||
import javax.swing.JPanel;
|
|
||||||
import javax.swing.WindowConstants;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.datavec.api.split.FileSplit;
|
|
||||||
import org.datavec.image.loader.NativeImageLoader;
|
|
||||||
import org.datavec.image.recordreader.ImageRecordReader;
|
|
||||||
import org.datavec.image.transform.ColorConversionTransform;
|
|
||||||
import org.datavec.image.transform.ImageTransform;
|
|
||||||
import org.datavec.image.transform.PipelineImageTransform;
|
|
||||||
import org.datavec.image.transform.ResizeImageTransform;
|
|
||||||
import org.datavec.image.transform.ShowImageTransform;
|
|
||||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
||||||
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
|
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
|
||||||
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
||||||
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class App {
|
public class App {
|
||||||
private static final double LEARNING_RATE = 0.000002;
|
private static final double LEARNING_RATE = 0.002;
|
||||||
private static final double GRADIENT_THRESHOLD = 100.0;
|
private static final double GRADIENT_THRESHOLD = 100.0;
|
||||||
|
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
|
||||||
|
private static final int BATCHSIZE = 128;
|
||||||
|
private static JFrame frame;
|
||||||
|
private static JPanel panel;
|
||||||
|
|
||||||
private static final int X_DIM = 20 ;
|
private static LayerConfiguration[] genLayers() {
|
||||||
private static final int Y_DIM = 20;
|
return new LayerConfiguration[] {
|
||||||
private static final int CHANNELS = 1;
|
dense().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
|
||||||
private static final int batchSize = 1;
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
private static final int INPUT = 10;
|
dense().nIn(256).nOut(512).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(512).nOut(1024).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(1024).nOut(784).activation(Activation.TANH).build()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
private static final int OUTPUT_PER_PANEL = 16;
|
/**
|
||||||
|
* Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
|
||||||
|
*
|
||||||
|
* @return config
|
||||||
|
*/
|
||||||
|
private static NeuralNetConfiguration generator() {
|
||||||
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
|
.seed(42)
|
||||||
|
.updater(UPDATER)
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
.activation(Activation.IDENTITY)
|
||||||
|
.layersFromArray(genLayers())
|
||||||
|
.name("generator")
|
||||||
|
.build();
|
||||||
|
|
||||||
private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS;
|
return conf;
|
||||||
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
|
}
|
||||||
|
|
||||||
private static JFrame frame;
|
private static LayerConfiguration[] disLayers() {
|
||||||
private static JFrame frame2;
|
return new LayerConfiguration[]{
|
||||||
private static JPanel panel;
|
dense().nIn(784).nOut(1024).build(),
|
||||||
private static JPanel panel2;
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
dense().nIn(1024).nOut(512).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
dense().nIn(512).nOut(256).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
private static final String OUTPUT_DIR = "C:/temp/output/";
|
private static NeuralNetConfiguration discriminator() {
|
||||||
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
|
.seed(42)
|
||||||
|
.updater(UPDATER)
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
.activation(Activation.IDENTITY)
|
||||||
|
.layersFromArray(disLayers())
|
||||||
|
.name("discriminator")
|
||||||
|
.build();
|
||||||
|
|
||||||
private static LayerConfiguration[] genLayers() {
|
return conf;
|
||||||
return new LayerConfiguration[] {
|
}
|
||||||
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
|
|
||||||
ActivationLayer.builder(Activation.LEAKYRELU).build(),
|
|
||||||
|
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
private static NeuralNetConfiguration gan() {
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
LayerConfiguration[] genLayers = genLayers();
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
LayerConfiguration[] disLayers = discriminator().getFlattenedLayerConfigurations().stream()
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
|
.map((layer) -> {
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
||||||
|
return FrozenLayerWithBackprop.builder(layer).build();
|
||||||
|
} else {
|
||||||
|
return layer;
|
||||||
|
}
|
||||||
|
}).toArray(LayerConfiguration[]::new);
|
||||||
|
LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers);
|
||||||
|
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
};
|
.seed(42)
|
||||||
}
|
.updater(UPDATER)
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
.activation(Activation.IDENTITY)
|
||||||
|
.layersFromArray(layers)
|
||||||
|
.name("GAN")
|
||||||
|
.build();
|
||||||
|
|
||||||
/**
|
return conf;
|
||||||
* Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
|
}
|
||||||
*
|
|
||||||
* @return config
|
|
||||||
*/
|
|
||||||
private static NeuralNetConfiguration generator() {
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
|
||||||
.seed(42)
|
|
||||||
.updater(UPDATER)
|
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
|
||||||
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
|
||||||
//.weightInit(WeightInit.XAVIER)
|
|
||||||
.weightInit(WeightInit.XAVIER)
|
|
||||||
.activation(Activation.IDENTITY)
|
|
||||||
.layersFromArray(genLayers())
|
|
||||||
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
|
||||||
// .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
|
|
||||||
.build();
|
|
||||||
((NeuralNetConfiguration) conf).init();
|
|
||||||
|
|
||||||
return conf;
|
@Test
|
||||||
}
|
public void runTest() throws Exception {
|
||||||
|
App.main(null);
|
||||||
|
}
|
||||||
|
public static void main(String... args) throws Exception {
|
||||||
|
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
||||||
|
|
||||||
private static LayerConfiguration[] disLayers() {
|
MnistDataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
|
||||||
return new LayerConfiguration[]{
|
|
||||||
DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network
|
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
|
||||||
DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
|
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
|
||||||
DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
|
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
|
||||||
DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
|
||||||
|
|
||||||
OutputLayer.builder().name("dis-output").lossFunction(LossFunction.MCXENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
|
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
|
||||||
};
|
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
|
||||||
}
|
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
|
||||||
|
gen.init();
|
||||||
|
dis.init();
|
||||||
|
gan.init();
|
||||||
|
|
||||||
private static NeuralNetConfiguration discriminator() {
|
copyParams(gen, dis, gan);
|
||||||
|
|
||||||
NeuralNetConfiguration conf =
|
gen.addTrainingListeners(new PerformanceListener(10, true));
|
||||||
NeuralNetConfiguration.builder()
|
dis.addTrainingListeners(new PerformanceListener(10, true));
|
||||||
.seed(42)
|
gan.addTrainingListeners(new PerformanceListener(10, true));
|
||||||
.updater(UPDATER)
|
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
|
||||||
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
|
||||||
.weightInit(WeightInit.XAVIER)
|
|
||||||
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
|
||||||
.weightNoise(null)
|
|
||||||
// .weightInitFn(new WeightInitXavier())
|
|
||||||
// .activationFn(new ActivationIdentity())
|
|
||||||
.activation(Activation.IDENTITY)
|
|
||||||
.layersFromArray(disLayers())
|
|
||||||
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
|
||||||
.build();
|
|
||||||
((NeuralNetConfiguration) conf).init();
|
|
||||||
|
|
||||||
return conf;
|
trainData.reset();
|
||||||
}
|
|
||||||
|
|
||||||
private static NeuralNetConfiguration gan() {
|
int j = 0;
|
||||||
LayerConfiguration[] genLayers = genLayers();
|
for (int i = 0; i < 50; i++) {
|
||||||
LayerConfiguration[] disLayers = Arrays.stream(disLayers())
|
while (trainData.hasNext()) {
|
||||||
.map((layer) -> {
|
j++;
|
||||||
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
|
||||||
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
|
|
||||||
} else {
|
|
||||||
return layer;
|
|
||||||
}
|
|
||||||
}).toArray(LayerConfiguration[]::new);
|
|
||||||
LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers);
|
|
||||||
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
// generate data
|
||||||
.seed(42)
|
INDArray real = trainData.next().getFeatures().muli(2).subi(1);
|
||||||
.updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() )
|
int batchSize = (int) real.shape()[0];
|
||||||
.gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
|
|
||||||
.gradientNormalizationThreshold( 100 )
|
INDArray fakeIn = Nd4j.rand(batchSize, 100);
|
||||||
//.weightInitFn( new WeightInitXavier() ) //this is internal
|
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
||||||
.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
|
||||||
.weightInit( WeightInit.XAVIER)
|
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
|
||||||
//.activationFn( new ActivationIdentity()) //this is internal
|
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
|
||||||
.activation( Activation.IDENTITY )
|
|
||||||
.layersFromArray( layers )
|
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
||||||
.inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
|
||||||
.dataType(DataType.FLOAT)
|
dis.fit(data);
|
||||||
.build();
|
dis.fit(data);
|
||||||
((NeuralNetConfiguration) conf).init();
|
|
||||||
return conf;
|
// Update the discriminator in the GAN network
|
||||||
}
|
updateGan(gen, dis, gan);
|
||||||
|
|
||||||
|
gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
|
||||||
|
|
||||||
|
|
||||||
@Test
|
if (j % 10 == 1) {
|
||||||
public void runTest() throws Exception {
|
System.out.println("Epoch " + i +" Iteration " + j + " Visualizing...");
|
||||||
if(! log.isDebugEnabled()) {
|
INDArray[] samples = new INDArray[9];
|
||||||
log.info("Logging is not set to DEBUG");
|
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
|
||||||
}
|
|
||||||
else {
|
|
||||||
log.info("Logging is set to DEBUG");
|
|
||||||
}
|
|
||||||
main();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void main(String... args) throws Exception {
|
for (int k = 0; k < 9; k++) {
|
||||||
|
INDArray input = fakeSet2.get(k).getFeatures();
|
||||||
|
//samples[k] = gen.output(input, false);
|
||||||
|
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
||||||
|
|
||||||
log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m ");
|
}
|
||||||
Nd4j.getMemoryManager().setAutoGcWindow(500);
|
visualize(samples);
|
||||||
|
}
|
||||||
//MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
|
}
|
||||||
//FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
|
trainData.reset();
|
||||||
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS());
|
// Copy the GANs generator to gen.
|
||||||
|
//updateGen(gen, gan);
|
||||||
|
|
||||||
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
|
|
||||||
|
|
||||||
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
|
|
||||||
ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);
|
|
||||||
|
|
||||||
ImageTransform tr = new PipelineImageTransform.Builder()
|
|
||||||
//.addImageTransform(transform) //convert to GREY SCALE
|
|
||||||
.addImageTransform(transform3)
|
|
||||||
//.addImageTransform(transform2)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
|
|
||||||
imageRecordReader.initialize(fileSplit, tr);
|
|
||||||
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
|
|
||||||
|
|
||||||
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
|
|
||||||
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
|
|
||||||
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
|
|
||||||
gen.init(); log.debug("Generator network: {}", gen);
|
|
||||||
dis.init(); log.debug("Discriminator network: {}", dis);
|
|
||||||
gan.init(); log.info("Complete GAN network: {}", gan);
|
|
||||||
|
|
||||||
|
|
||||||
copyParams(gen, dis, gan);
|
|
||||||
|
|
||||||
//gen.addTrainingListeners(new PerformanceListener(15, true, "GEN"));
|
|
||||||
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
|
|
||||||
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
|
|
||||||
//gan.addTrainingListeners(new ScoreToChartListener("gan"));
|
|
||||||
//dis.setListeners(new ScoreToChartListener("dis"));
|
|
||||||
|
|
||||||
//System.out.println(gan.toString());
|
|
||||||
//gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
|
|
||||||
|
|
||||||
//gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
|
|
||||||
//trainData.reset();
|
|
||||||
|
|
||||||
int j = 0;
|
|
||||||
for (int i = 0; i < 51; i++) { //epoch
|
|
||||||
while (trainData.hasNext()) {
|
|
||||||
j++;
|
|
||||||
|
|
||||||
DataSet next = trainData.next();
|
|
||||||
// generate data
|
|
||||||
INDArray real = next.getFeatures();//.div(255f);
|
|
||||||
|
|
||||||
//start next round if there are not enough images left to have a full batchsize dataset
|
|
||||||
if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
|
|
||||||
log.warn("Your total number of input images is not a multiple of {}, "
|
|
||||||
+ "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//if(i%20 == 0) {
|
|
||||||
frame2 = visualize(new INDArray[]{real}, batchSize,
|
|
||||||
frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
|
|
||||||
//}
|
|
||||||
real.divi(255f);
|
|
||||||
|
|
||||||
// int batchSize = (int) real.shape()[0];
|
|
||||||
|
|
||||||
INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
|
|
||||||
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
|
|
||||||
|
|
||||||
|
|
||||||
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
|
||||||
fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
|
|
||||||
|
|
||||||
//log.info("real has {} items.", real.length());
|
|
||||||
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
|
|
||||||
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
|
|
||||||
|
|
||||||
|
|
||||||
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
|
||||||
|
|
||||||
dis.fit(data);
|
|
||||||
//dis.fit(data);
|
|
||||||
|
|
||||||
// Update the discriminator in the GAN network
|
|
||||||
updateGan(gen, dis, gan);
|
|
||||||
|
|
||||||
//gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
|
|
||||||
gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1)));
|
|
||||||
|
|
||||||
//Visualize and reporting
|
|
||||||
if (j % 10 == 1) {
|
|
||||||
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
|
|
||||||
INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
|
|
||||||
|
|
||||||
|
|
||||||
for (int k = 0; k < samples.length; k++) {
|
|
||||||
//INDArray input = fakeSet2.get(k).getFeatures();
|
|
||||||
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
|
|
||||||
INDArray input = fakeSet2.get(k).getFeatures();
|
|
||||||
input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here
|
|
||||||
|
|
||||||
//samples[k] = gen.output(input, false);
|
|
||||||
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
|
||||||
samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
|
|
||||||
//samples[k] =
|
|
||||||
samples[k].addi(1f).divi(2f).muli(255f);
|
|
||||||
|
|
||||||
}
|
|
||||||
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (trainData.resetSupported()) {
|
|
||||||
trainData.reset();
|
|
||||||
} else {
|
|
||||||
log.error("Trainingdata {} does not support reset.", trainData.toString());
|
|
||||||
}
|
|
||||||
// Copy the GANs generator to gen.
|
// Copy the GANs generator to gen.
|
||||||
updateGen(gen, gan);
|
updateGen(gen, gan);
|
||||||
|
|
||||||
gen.save(new File("mnist-mlp-generator.dlj"));
|
gen.save(new File("mnist-mlp-generator.dlj"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||||
|
int genLayerCount = gen.getLayers().length;
|
||||||
|
for (int i = 0; i < gan.getLayers().length; i++) {
|
||||||
}
|
if (i < genLayerCount) {
|
||||||
|
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||||
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
} else {
|
||||||
int genLayerCount = gen.getLayers().length;
|
dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
|
||||||
for (int i = 0; i < gan.getLayers().length; i++) {
|
}
|
||||||
if (i < genLayerCount) {
|
|
||||||
if(gan.getLayer(i).getParams() != null)
|
|
||||||
gan.getLayer(i).setParams(gen.getLayer(i).getParams());
|
|
||||||
} else {
|
|
||||||
if(gan.getLayer(i).getParams() != null)
|
|
||||||
gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
|
|
||||||
for (int i = 0; i < gen.getLayers().length; i++) {
|
|
||||||
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
|
||||||
int genLayerCount = gen.getLayers().length;
|
|
||||||
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
|
|
||||||
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
|
|
||||||
if (isOrig) {
|
|
||||||
frame.setTitle("Viz Original");
|
|
||||||
} else {
|
|
||||||
frame.setTitle("Generated");
|
|
||||||
}
|
|
||||||
|
|
||||||
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
|
||||||
frame.setLayout(new BorderLayout());
|
|
||||||
|
|
||||||
JPanel panelx = new JPanel();
|
|
||||||
|
|
||||||
panelx.setLayout(new GridLayout(4, 4, 8, 8));
|
|
||||||
for (INDArray sample : samples) {
|
|
||||||
for(int i = 0; i<batchElements; i++) {
|
|
||||||
panelx.add(getImage(sample, i, isOrig));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
frame.add(panelx, BorderLayout.CENTER);
|
|
||||||
frame.setVisible(true);
|
|
||||||
|
|
||||||
frame.revalidate();
|
|
||||||
frame.setMinimumSize(new Dimension(300, 20));
|
|
||||||
frame.pack();
|
|
||||||
return frame;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
|
|
||||||
final BufferedImage bi;
|
|
||||||
if(CHANNELS>1) {
|
|
||||||
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
|
|
||||||
} else {
|
|
||||||
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
|
|
||||||
}
|
|
||||||
final int imageSize = X_DIM * Y_DIM;
|
|
||||||
final int offset = batchElement * imageSize;
|
|
||||||
int pxl = offset * CHANNELS; //where to start in the INDArray
|
|
||||||
|
|
||||||
//Image in NCHW - channels first format
|
|
||||||
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
|
|
||||||
for (int y = 0; y < Y_DIM; y++) { // step through the columns x
|
|
||||||
for (int x = 0; x < X_DIM; x++) { //step through the rows y
|
|
||||||
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
|
|
||||||
bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
|
|
||||||
pxl++; //next item in INDArray
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
ImageIcon orig = new ImageIcon(bi);
|
|
||||||
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
|
|
||||||
ImageIcon scaled = new ImageIcon(imageScaled);
|
|
||||||
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
|
|
||||||
return new JLabel(scaled);
|
|
||||||
|
|
||||||
}
|
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
|
||||||
|
for (int i = 0; i < gen.getLayers().length; i++) {
|
||||||
|
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static void saveImage(Image image, int batchElement, boolean isOrig) {
|
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||||
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
|
int genLayerCount = gen.getLayers().length;
|
||||||
|
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
|
||||||
|
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
private static void visualize(INDArray[] samples) {
|
||||||
// Save the images to disk
|
if (frame == null) {
|
||||||
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
|
frame = new JFrame();
|
||||||
|
frame.setTitle("Viz");
|
||||||
|
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||||
|
frame.setLayout(new BorderLayout());
|
||||||
|
|
||||||
log.debug("Images saved successfully.");
|
panel = new JPanel();
|
||||||
} catch (IOException e) {
|
|
||||||
log.error("Error saving the images: {}", e.getMessage());
|
panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
|
||||||
|
frame.add(panel, BorderLayout.CENTER);
|
||||||
|
frame.setVisible(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
panel.removeAll();
|
||||||
|
|
||||||
|
for (INDArray sample : samples) {
|
||||||
|
panel.add(getImage(sample));
|
||||||
|
}
|
||||||
|
|
||||||
|
frame.revalidate();
|
||||||
|
frame.pack();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static JLabel getImage(INDArray tensor) {
|
||||||
|
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
|
||||||
|
for (int i = 0; i < 784; i++) {
|
||||||
|
int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
|
||||||
|
bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
|
||||||
|
}
|
||||||
|
ImageIcon orig = new ImageIcon(bi);
|
||||||
|
Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
|
||||||
|
|
||||||
|
ImageIcon scaled = new ImageIcon(imageScaled);
|
||||||
|
|
||||||
|
return new JLabel(scaled);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
|
|
||||||
File directory = new File(outputDirectory);
|
|
||||||
if (!directory.exists()) {
|
|
||||||
directory.mkdir();
|
|
||||||
}
|
|
||||||
|
|
||||||
File outputFile = new File(directory, fileName);
|
|
||||||
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static BufferedImage imageToBufferedImage(Image image) {
|
|
||||||
if (image instanceof BufferedImage) {
|
|
||||||
return (BufferedImage) image;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a buffered image with the same dimensions and transparency as the original image
|
|
||||||
BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
|
|
||||||
|
|
||||||
// Draw the original image onto the buffered image
|
|
||||||
Graphics2D g2d = bufferedImage.createGraphics();
|
|
||||||
g2d.drawImage(image, 0, 0, null);
|
|
||||||
g2d.dispose();
|
|
||||||
|
|
||||||
return bufferedImage;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
|
@ -0,0 +1,343 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.gan;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import java.awt.*;
|
||||||
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.List;
|
||||||
|
import javax.imageio.ImageIO;
|
||||||
|
import javax.swing.*;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.datavec.api.split.FileSplit;
|
||||||
|
import org.datavec.image.loader.NativeImageLoader;
|
||||||
|
import org.datavec.image.recordreader.ImageRecordReader;
|
||||||
|
import org.datavec.image.transform.*;
|
||||||
|
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||||
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class App2 {
|
||||||
|
|
||||||
|
final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
|
||||||
|
static final float COLORSPACE = 255f;
|
||||||
|
static final int DIMENSIONS = 28;
|
||||||
|
static final int CHANNELS = 1;
|
||||||
|
final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
|
||||||
|
final int OUTPUT_PER_PANEL = 10;
|
||||||
|
|
||||||
|
final boolean BIAS = true;
|
||||||
|
|
||||||
|
static final int BATCHSIZE=128;
|
||||||
|
|
||||||
|
private JFrame frame2, frame;
|
||||||
|
static final String OUTPUT_DIR = "d:/out/";
|
||||||
|
|
||||||
|
final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
|
||||||
|
final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void runTest() throws IOException {
|
||||||
|
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
||||||
|
|
||||||
|
MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
|
||||||
|
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
|
||||||
|
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
|
||||||
|
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
|
||||||
|
ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
|
||||||
|
|
||||||
|
ImageTransform tr = new PipelineImageTransform.Builder()
|
||||||
|
.addImageTransform(transform) //convert to GREY SCALE
|
||||||
|
.addImageTransform(transform3)
|
||||||
|
//.addImageTransform(transform2)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ImageRecordReader imageRecordReader = new ImageRecordReader(DIMENSIONS, DIMENSIONS, CHANNELS);
|
||||||
|
imageRecordReader.initialize(fileSplit, tr);
|
||||||
|
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, BATCHSIZE );
|
||||||
|
trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
|
||||||
|
|
||||||
|
MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
|
||||||
|
MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());
|
||||||
|
|
||||||
|
LayerConfiguration[] disLayers = App2Config.discriminator().getFlattenedLayerConfigurations().stream()
|
||||||
|
.map((layer) -> {
|
||||||
|
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
||||||
|
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
|
||||||
|
} else {
|
||||||
|
return layer;
|
||||||
|
}
|
||||||
|
}).toArray(LayerConfiguration[]::new);
|
||||||
|
|
||||||
|
NeuralNetConfiguration netConfiguration =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.name("GAN")
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(100)
|
||||||
|
.updater(App2Config.UPDATER)
|
||||||
|
.innerConfigurations(new ArrayList<>(List.of(App2Config.generator())))
|
||||||
|
.layersFromList(new ArrayList<>(Arrays.asList(disLayers)))
|
||||||
|
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
// .inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
|
||||||
|
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
|
||||||
|
// .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
//.inputPreProcessor(2, new CnnToFeedForwardPreProcessor())
|
||||||
|
//.dataType(DataType.FLOAT)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
MultiLayerNetwork gan = new MultiLayerNetwork(netConfiguration );
|
||||||
|
|
||||||
|
dis.init(); log.debug("Discriminator network: {}", dis);
|
||||||
|
gen.init(); log.debug("Generator network: {}", gen);
|
||||||
|
gan.init(); log.debug("GAN network: {}", gan);
|
||||||
|
|
||||||
|
|
||||||
|
log.info("Generator Summary:\n{}", gen.summary());
|
||||||
|
log.info("GAN Summary:\n{}", gan.summary());
|
||||||
|
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
|
||||||
|
gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
|
||||||
|
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
|
||||||
|
|
||||||
|
int j = 0;
|
||||||
|
for (int i = 0; i < 51; i++) { //epoch
|
||||||
|
while (trainData.hasNext()) {
|
||||||
|
j++;
|
||||||
|
DataSet next = trainData.next();
|
||||||
|
// generate data
|
||||||
|
INDArray real = next.getFeatures(); //.muli(2).subi(1);;//.div(255f);
|
||||||
|
|
||||||
|
//start next round if there are not enough images left to have a full batchsize dataset
|
||||||
|
if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
|
||||||
|
log.warn("Your total number of input images is not a multiple of {}, "
|
||||||
|
+ "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
//if(i%20 == 0) {
|
||||||
|
|
||||||
|
// frame2 = visualize(new INDArray[]{real}, BATCHSIZE,
|
||||||
|
// frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
|
||||||
|
//}
|
||||||
|
//real.divi(255f);
|
||||||
|
|
||||||
|
// int batchSize = (int) real.shape()[0];
|
||||||
|
|
||||||
|
//INDArray fakeIn = Nd4j.rand(BATCHSIZE, CHANNELS, DIMENSIONS, DIMENSIONS);
|
||||||
|
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
|
||||||
|
INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
|
||||||
|
|
||||||
|
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
||||||
|
// when generator has TANH as activation - value range is -1 to 1
|
||||||
|
// when generator has SIGMOID, then range is 0 to 1
|
||||||
|
fake.addi(1f).divi(2f);
|
||||||
|
|
||||||
|
DataSet realSet = new DataSet(real, label_real);
|
||||||
|
DataSet fakeSet = new DataSet(fake, label_fake);
|
||||||
|
|
||||||
|
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
||||||
|
|
||||||
|
dis.fit(data);
|
||||||
|
dis.fit(data);
|
||||||
|
// Update the discriminator in the GAN network
|
||||||
|
updateGan(gen, dis, gan);
|
||||||
|
|
||||||
|
gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_fake));
|
||||||
|
|
||||||
|
//Visualize and reporting
|
||||||
|
if (j % 10 == 1) {
|
||||||
|
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
|
||||||
|
INDArray[] samples = BATCHSIZE > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[BATCHSIZE];
|
||||||
|
|
||||||
|
|
||||||
|
for (int k = 0; k < samples.length; k++) {
|
||||||
|
DataSet fakeSet2 = new DataSet(fakeIn, label_fake);
|
||||||
|
INDArray input = fakeSet2.get(k).getFeatures();
|
||||||
|
|
||||||
|
//input = input.reshape(1,CHANNELS, DIMENSIONS, DIMENSIONS); //batch size will be 1 here for images
|
||||||
|
input = input.reshape(1, App2Config.INPUT);
|
||||||
|
|
||||||
|
//samples[k] = gen.output(input, false);
|
||||||
|
samples[k] = gen.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
||||||
|
samples[k] = samples[k].reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS);
|
||||||
|
//samples[k] =
|
||||||
|
//samples[k].muli(255f);
|
||||||
|
|
||||||
|
}
|
||||||
|
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (trainData.resetSupported()) {
|
||||||
|
trainData.reset();
|
||||||
|
} else {
|
||||||
|
log.error("Trainingdata {} does not support reset.", trainData.toString());
|
||||||
|
}
|
||||||
|
// Copy the GANs generator to gen.
|
||||||
|
updateGen(gen, gan);
|
||||||
|
log.info("Updated GAN's generator from gen.");
|
||||||
|
gen.save(new File("mnist-mlp-generator.dlj"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
|
||||||
|
if (isOrig) {
|
||||||
|
frame.setTitle("Viz Original");
|
||||||
|
} else {
|
||||||
|
frame.setTitle("Generated");
|
||||||
|
}
|
||||||
|
|
||||||
|
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||||
|
frame.setLayout(new BorderLayout());
|
||||||
|
|
||||||
|
JPanel panelx = new JPanel();
|
||||||
|
|
||||||
|
panelx.setLayout(new GridLayout(4, 4, 8, 8));
|
||||||
|
for (INDArray sample : samples) {
|
||||||
|
for(int i = 0; i<batchElements; i++) {
|
||||||
|
panelx.add(getImage(sample, i, isOrig));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
frame.add(panelx, BorderLayout.CENTER);
|
||||||
|
frame.setVisible(true);
|
||||||
|
|
||||||
|
frame.revalidate();
|
||||||
|
frame.setMinimumSize(new Dimension(300, 20));
|
||||||
|
frame.pack();
|
||||||
|
return frame;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
|
||||||
|
final BufferedImage bi;
|
||||||
|
if(CHANNELS >1) {
|
||||||
|
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
|
||||||
|
} else {
|
||||||
|
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
|
||||||
|
}
|
||||||
|
final int imageSize = DIMENSIONS * DIMENSIONS;
|
||||||
|
final int offset = batchElement * imageSize;
|
||||||
|
int pxl = offset * CHANNELS; //where to start in the INDArray
|
||||||
|
|
||||||
|
//Image in NCHW - channels first format
|
||||||
|
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
|
||||||
|
for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x
|
||||||
|
for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y
|
||||||
|
float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
|
||||||
|
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl);
|
||||||
|
bi.getRaster().setSample(x, y, c, f_pxl);
|
||||||
|
pxl++; //next item in INDArray
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ImageIcon orig = new ImageIcon(bi);
|
||||||
|
Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT);
|
||||||
|
ImageIcon scaled = new ImageIcon(imageScaled);
|
||||||
|
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
|
||||||
|
return new JLabel(scaled);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void saveImage(Image image, int batchElement, boolean isOrig) {
|
||||||
|
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Save the images to disk
|
||||||
|
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
|
||||||
|
|
||||||
|
log.debug("Images saved successfully.");
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("Error saving the images: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
|
||||||
|
File directory = new File(outputDirectory);
|
||||||
|
if (!directory.exists()) {
|
||||||
|
directory.mkdir();
|
||||||
|
}
|
||||||
|
|
||||||
|
File outputFile = new File(directory, fileName);
|
||||||
|
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static BufferedImage imageToBufferedImage(Image image) {
|
||||||
|
if (image instanceof BufferedImage) {
|
||||||
|
return (BufferedImage) image;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a buffered image with the same dimensions and transparency as the original image
|
||||||
|
BufferedImage bufferedImage;
|
||||||
|
if (CHANNELS > 1) {
|
||||||
|
bufferedImage =
|
||||||
|
new BufferedImage(
|
||||||
|
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
|
||||||
|
} else {
|
||||||
|
bufferedImage =
|
||||||
|
new BufferedImage(
|
||||||
|
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw the original image onto the buffered image
|
||||||
|
Graphics2D g2d = bufferedImage.createGraphics();
|
||||||
|
g2d.drawImage(image, 0, 0, null);
|
||||||
|
g2d.dispose();
|
||||||
|
|
||||||
|
return bufferedImage;
|
||||||
|
}
|
||||||
|
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
|
||||||
|
for (int i = 0; i < gen.getLayers().length; i++) {
|
||||||
|
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||||
|
int genLayerCount = gen.getLayers().length;
|
||||||
|
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
|
||||||
|
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,176 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.gan;
|
||||||
|
|
||||||
|
import static net.brutex.ai.dnn.api.NN.*;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||||
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
public class App2Config {
|
||||||
|
|
||||||
|
public static final int INPUT = 100;
|
||||||
|
public static final int X_DIM = 28;
|
||||||
|
public static final int y_DIM = 28;
|
||||||
|
public static final int CHANNELS = 1;
|
||||||
|
public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
|
||||||
|
|
||||||
|
static LayerConfiguration[] genLayerConfig() {
|
||||||
|
return new LayerConfiguration[] {
|
||||||
|
/*
|
||||||
|
DenseLayer.builder().name("L-0").nIn(INPUT).nOut(INPUT + (INPUT / 2)).activation(Activation.RELU).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).build(), /*
|
||||||
|
Deconvolution2D.builder().name("L-Deconv-01").nIn(CHANNELS).nOut(CHANNELS)
|
||||||
|
.kernelSize(2,2)
|
||||||
|
.stride(1,1)
|
||||||
|
.padding(0,0)
|
||||||
|
.convolutionMode(ConvolutionMode.Truncate)
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.hasBias(BIAS).build(),
|
||||||
|
//BatchNormalization.builder().nOut(CHANNELS).build(),
|
||||||
|
Deconvolution2D.builder().name("L-Deconv-02").nIn(CHANNELS).nOut(CHANNELS)
|
||||||
|
.kernelSize(2,2)
|
||||||
|
.stride(2,2)
|
||||||
|
.padding(0,0)
|
||||||
|
.convolutionMode(ConvolutionMode.Truncate)
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.hasBias(BIAS).build(),
|
||||||
|
//BatchNormalization.builder().name("L-batch").nOut(CHANNELS).build(),
|
||||||
|
|
||||||
|
|
||||||
|
DenseLayer.builder().name("L-x").nIn(INPUT + (INPUT / 2)).nOut(2 * INPUT).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
|
||||||
|
DenseLayer.builder().name("L-x").nIn(2 * INPUT).nOut(3 * INPUT).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
|
||||||
|
DenseLayer.builder().name("L-x").nIn(3 * INPUT).nOut(2 * INPUT).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
|
||||||
|
// DropoutLayer.builder(0.001).build(),
|
||||||
|
DenseLayer.builder().nIn(2 * INPUT).nOut(INPUT).activation(Activation.TANH).build() */
|
||||||
|
|
||||||
|
dense().nIn(INPUT).nOut(256).weightInit(WeightInit.NORMAL).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(256).nOut(512).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(512).nOut(1024).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(1024).nOut(784).activation(Activation.TANH).build(),
|
||||||
|
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
static LayerConfiguration[] disLayerConfig() {
|
||||||
|
return new LayerConfiguration[] {/*
|
||||||
|
Convolution2D.builder().nIn(CHANNELS).kernelSize(2,2).padding(1,1).stride(1,1).nOut(CHANNELS)
|
||||||
|
.build(),
|
||||||
|
Convolution2D.builder().nIn(CHANNELS).kernelSize(3,3).padding(1,1).stride(2,2).nOut(CHANNELS)
|
||||||
|
.build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.LEAKYRELU).build(),
|
||||||
|
BatchNormalization.builder().build(),
|
||||||
|
OutputLayer.builder().nOut(1).lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SIGMOID)
|
||||||
|
.build()
|
||||||
|
|
||||||
|
|
||||||
|
dense().name("L-dense").nIn(INPUT).nOut(INPUT).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).build(),
|
||||||
|
DropoutLayer.builder(0.5).build(),
|
||||||
|
|
||||||
|
DenseLayer.builder().nIn(INPUT).nOut(INPUT/2).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).build(),
|
||||||
|
DropoutLayer.builder(0.5).build(),
|
||||||
|
|
||||||
|
DenseLayer.builder().nIn(INPUT/2).nOut(INPUT/4).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).build(),
|
||||||
|
DropoutLayer.builder(0.5).build(),
|
||||||
|
|
||||||
|
OutputLayer.builder().nIn(INPUT/4).nOut(1).lossFunction(LossFunctions.LossFunction.XENT)
|
||||||
|
.activation(Activation.SIGMOID)
|
||||||
|
.build() */
|
||||||
|
dense().nIn(784).nOut(1024).hasBias(true).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
dense().nIn(1024).nOut(512).hasBias(true).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
dense().nIn(512).nOut(256).hasBias(true).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static NeuralNetConfiguration generator() {
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.name("generator")
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(100)
|
||||||
|
.seed(42)
|
||||||
|
.updater(UPDATER)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
||||||
|
.weightNoise(null)
|
||||||
|
// .weightInitFn(new WeightInitXavier())
|
||||||
|
// .activationFn(new ActivationIdentity())
|
||||||
|
.activation(Activation.IDENTITY)
|
||||||
|
.layersFromArray(App2Config.genLayerConfig())
|
||||||
|
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
|
||||||
|
//.inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
//.inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
|
||||||
|
|
||||||
|
.build();
|
||||||
|
conf.init();
|
||||||
|
return conf;
|
||||||
|
}
|
||||||
|
|
||||||
|
static NeuralNetConfiguration discriminator() {
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.name("discriminator")
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(100)
|
||||||
|
.seed(42)
|
||||||
|
.updater(UPDATER)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
// .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
||||||
|
.weightNoise(null)
|
||||||
|
// .weightInitFn(new WeightInitXavier())
|
||||||
|
// .activationFn(new ActivationIdentity())
|
||||||
|
.activation(Activation.IDENTITY)
|
||||||
|
.layersFromArray(disLayerConfig())
|
||||||
|
//.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
|
||||||
|
//.dataType(DataType.FLOAT)
|
||||||
|
.build();
|
||||||
|
conf.init();
|
||||||
|
return conf;
|
||||||
|
}
|
||||||
|
}
|
|
@ -43,223 +43,255 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils.impo
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class KerasSequentialModel extends KerasModel {
|
public class KerasSequentialModel extends KerasModel {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* (Recommended) Builder-pattern constructor for Sequential model.
|
* (Recommended) Builder-pattern constructor for Sequential model.
|
||||||
*
|
*
|
||||||
* @param modelBuilder builder object
|
* @param modelBuilder builder object
|
||||||
* @throws IOException I/O exception
|
* @throws IOException I/O exception
|
||||||
* @throws InvalidKerasConfigurationException Invalid Keras configuration
|
* @throws InvalidKerasConfigurationException Invalid Keras configuration
|
||||||
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
|
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
|
||||||
*/
|
*/
|
||||||
public KerasSequentialModel(KerasModelBuilder modelBuilder)
|
public KerasSequentialModel(KerasModelBuilder modelBuilder)
|
||||||
throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
|
throws UnsupportedKerasConfigurationException,
|
||||||
this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(),
|
IOException,
|
||||||
modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(),
|
InvalidKerasConfigurationException {
|
||||||
modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape());
|
this(
|
||||||
|
modelBuilder.getModelJson(),
|
||||||
|
modelBuilder.getModelYaml(),
|
||||||
|
modelBuilder.getWeightsArchive(),
|
||||||
|
modelBuilder.getWeightsRoot(),
|
||||||
|
modelBuilder.getTrainingJson(),
|
||||||
|
modelBuilder.getTrainingArchive(),
|
||||||
|
modelBuilder.isEnforceTrainingConfig(),
|
||||||
|
modelBuilder.getInputShape());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* (Not recommended) Constructor for Sequential model from model configuration (JSON or YAML),
|
||||||
|
* training configuration (JSON), weights, and "training mode" boolean indicator. When built in
|
||||||
|
* training mode, certain unsupported configurations (e.g., unknown regularizers) will throw
|
||||||
|
* Exceptions. When enforceTrainingConfig=false, these will generate warnings but will be
|
||||||
|
* otherwise ignored.
|
||||||
|
*
|
||||||
|
* @param modelJson model configuration JSON string
|
||||||
|
* @param modelYaml model configuration YAML string
|
||||||
|
* @param trainingJson training configuration JSON string
|
||||||
|
* @throws IOException I/O exception
|
||||||
|
*/
|
||||||
|
public KerasSequentialModel(
|
||||||
|
String modelJson,
|
||||||
|
String modelYaml,
|
||||||
|
Hdf5Archive weightsArchive,
|
||||||
|
String weightsRoot,
|
||||||
|
String trainingJson,
|
||||||
|
Hdf5Archive trainingArchive,
|
||||||
|
boolean enforceTrainingConfig,
|
||||||
|
int[] inputShape)
|
||||||
|
throws IOException,
|
||||||
|
InvalidKerasConfigurationException,
|
||||||
|
UnsupportedKerasConfigurationException {
|
||||||
|
|
||||||
|
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
|
||||||
|
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
|
||||||
|
this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
|
||||||
|
this.enforceTrainingConfig = enforceTrainingConfig;
|
||||||
|
|
||||||
|
/* Determine model configuration type. */
|
||||||
|
if (!modelConfig.containsKey(config.getFieldClassName()))
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"Could not determine Keras model class (no "
|
||||||
|
+ config.getFieldClassName()
|
||||||
|
+ " field found)");
|
||||||
|
this.className = (String) modelConfig.get(config.getFieldClassName());
|
||||||
|
if (!this.className.equals(config.getFieldClassNameSequential()))
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"Model class name must be "
|
||||||
|
+ config.getFieldClassNameSequential()
|
||||||
|
+ " (found "
|
||||||
|
+ this.className
|
||||||
|
+ ")");
|
||||||
|
|
||||||
|
/* Process layer configurations. */
|
||||||
|
if (!modelConfig.containsKey(config.getModelFieldConfig()))
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"Could not find layer configurations (no "
|
||||||
|
+ config.getModelFieldConfig()
|
||||||
|
+ " field found)");
|
||||||
|
|
||||||
|
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations.
|
||||||
|
// For consistency
|
||||||
|
// "config" is now an object containing a "name" and "layers", the latter contain the same data
|
||||||
|
// as before.
|
||||||
|
// This change only affects Sequential models.
|
||||||
|
List<Object> layerList;
|
||||||
|
try {
|
||||||
|
layerList = (List<Object>) modelConfig.get(config.getModelFieldConfig());
|
||||||
|
} catch (Exception e) {
|
||||||
|
HashMap layerMap = (HashMap<String, Object>) modelConfig.get(config.getModelFieldConfig());
|
||||||
|
layerList = (List<Object>) layerMap.get("layers");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = prepareLayers(layerList);
|
||||||
* (Not recommended) Constructor for Sequential model from model configuration
|
this.layers = layerPair.getFirst();
|
||||||
* (JSON or YAML), training configuration (JSON), weights, and "training mode"
|
this.layersOrdered = layerPair.getSecond();
|
||||||
* boolean indicator. When built in training mode, certain unsupported configurations
|
|
||||||
* (e.g., unknown regularizers) will throw Exceptions. When enforceTrainingConfig=false, these
|
|
||||||
* will generate warnings but will be otherwise ignored.
|
|
||||||
*
|
|
||||||
* @param modelJson model configuration JSON string
|
|
||||||
* @param modelYaml model configuration YAML string
|
|
||||||
* @param trainingJson training configuration JSON string
|
|
||||||
* @throws IOException I/O exception
|
|
||||||
*/
|
|
||||||
public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
|
|
||||||
String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig,
|
|
||||||
int[] inputShape)
|
|
||||||
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
|
||||||
|
|
||||||
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
|
KerasLayer inputLayer;
|
||||||
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
|
if (this.layersOrdered.get(0) instanceof KerasInput) {
|
||||||
this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
|
inputLayer = this.layersOrdered.get(0);
|
||||||
this.enforceTrainingConfig = enforceTrainingConfig;
|
} else {
|
||||||
|
/* Add placeholder input layer and update lists of input and output layers. */
|
||||||
|
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
|
||||||
|
Preconditions.checkState(
|
||||||
|
ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!");
|
||||||
|
inputLayer = new KerasInput("input1", firstLayerInputShape);
|
||||||
|
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
|
||||||
|
this.layers.put(inputLayer.getName(), inputLayer);
|
||||||
|
this.layersOrdered.add(0, inputLayer);
|
||||||
|
}
|
||||||
|
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
|
||||||
|
this.outputLayerNames =
|
||||||
|
new ArrayList<>(
|
||||||
|
Collections.singletonList(
|
||||||
|
this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
|
||||||
|
|
||||||
/* Determine model configuration type. */
|
/* Update each layer's inbound layer list to include (only) previous layer. */
|
||||||
if (!modelConfig.containsKey(config.getFieldClassName()))
|
KerasLayer prevLayer = null;
|
||||||
throw new InvalidKerasConfigurationException(
|
for (KerasLayer layer : this.layersOrdered) {
|
||||||
"Could not determine Keras model class (no " + config.getFieldClassName() + " field found)");
|
if (prevLayer != null)
|
||||||
this.className = (String) modelConfig.get(config.getFieldClassName());
|
layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName()));
|
||||||
if (!this.className.equals(config.getFieldClassNameSequential()))
|
prevLayer = layer;
|
||||||
throw new InvalidKerasConfigurationException("Model class name must be " + config.getFieldClassNameSequential()
|
|
||||||
+ " (found " + this.className + ")");
|
|
||||||
|
|
||||||
/* Process layer configurations. */
|
|
||||||
if (!modelConfig.containsKey(config.getModelFieldConfig()))
|
|
||||||
throw new InvalidKerasConfigurationException(
|
|
||||||
"Could not find layer configurations (no " + config.getModelFieldConfig() + " field found)");
|
|
||||||
|
|
||||||
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations. For consistency
|
|
||||||
// "config" is now an object containing a "name" and "layers", the latter contain the same data as before.
|
|
||||||
// This change only affects Sequential models.
|
|
||||||
List<Object> layerList;
|
|
||||||
try {
|
|
||||||
layerList = (List<Object>) modelConfig.get(config.getModelFieldConfig());
|
|
||||||
} catch (Exception e) {
|
|
||||||
HashMap layerMap = (HashMap<String, Object>) modelConfig.get(config.getModelFieldConfig());
|
|
||||||
layerList = (List<Object>) layerMap.get("layers");
|
|
||||||
}
|
|
||||||
|
|
||||||
Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair =
|
|
||||||
prepareLayers(layerList);
|
|
||||||
this.layers = layerPair.getFirst();
|
|
||||||
this.layersOrdered = layerPair.getSecond();
|
|
||||||
|
|
||||||
KerasLayer inputLayer;
|
|
||||||
if (this.layersOrdered.get(0) instanceof KerasInput) {
|
|
||||||
inputLayer = this.layersOrdered.get(0);
|
|
||||||
} else {
|
|
||||||
/* Add placeholder input layer and update lists of input and output layers. */
|
|
||||||
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
|
|
||||||
Preconditions.checkState(ArrayUtil.prod(firstLayerInputShape) > 0,"Input shape must not be zero!");
|
|
||||||
inputLayer = new KerasInput("input1", firstLayerInputShape);
|
|
||||||
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
|
|
||||||
this.layers.put(inputLayer.getName(), inputLayer);
|
|
||||||
this.layersOrdered.add(0, inputLayer);
|
|
||||||
}
|
|
||||||
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
|
|
||||||
this.outputLayerNames = new ArrayList<>(
|
|
||||||
Collections.singletonList(this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
|
|
||||||
|
|
||||||
/* Update each layer's inbound layer list to include (only) previous layer. */
|
|
||||||
KerasLayer prevLayer = null;
|
|
||||||
for (KerasLayer layer : this.layersOrdered) {
|
|
||||||
if (prevLayer != null)
|
|
||||||
layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName()));
|
|
||||||
prevLayer = layer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Import training configuration. */
|
|
||||||
if (enforceTrainingConfig) {
|
|
||||||
if (trainingJson != null)
|
|
||||||
importTrainingConfiguration(trainingJson);
|
|
||||||
else log.warn("If enforceTrainingConfig is true, a training " +
|
|
||||||
"configuration object has to be provided. Usually the only practical way to do this is to store" +
|
|
||||||
" your keras model with `model.save('model_path.h5'. If you store model config and weights" +
|
|
||||||
" separately no training configuration is attached.");
|
|
||||||
}
|
|
||||||
|
|
||||||
this.outputTypes = inferOutputTypes(inputShape);
|
|
||||||
|
|
||||||
if (weightsArchive != null)
|
|
||||||
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/* Import training configuration. */
|
||||||
* Default constructor
|
if (enforceTrainingConfig) {
|
||||||
*/
|
if (trainingJson != null) importTrainingConfiguration(trainingJson);
|
||||||
public KerasSequentialModel() {
|
else
|
||||||
super();
|
log.warn(
|
||||||
|
"If enforceTrainingConfig is true, a training "
|
||||||
|
+ "configuration object has to be provided. Usually the only practical way to do this is to store"
|
||||||
|
+ " your keras model with `model.save('model_path.h5'. If you store model config and weights"
|
||||||
|
+ " separately no training configuration is attached.");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
this.outputTypes = inferOutputTypes(inputShape);
|
||||||
* Configure a NeuralNetConfiguration from this Keras Sequential model configuration.
|
|
||||||
*
|
|
||||||
* @return NeuralNetConfiguration
|
|
||||||
*/
|
|
||||||
public NeuralNetConfiguration getNeuralNetConfiguration()
|
|
||||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
|
||||||
if (!this.className.equals(config.getFieldClassNameSequential()))
|
|
||||||
throw new InvalidKerasConfigurationException(
|
|
||||||
"Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
|
|
||||||
if (this.inputLayerNames.size() != 1)
|
|
||||||
throw new InvalidKerasConfigurationException(
|
|
||||||
"MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
|
|
||||||
if (this.outputLayerNames.size() != 1)
|
|
||||||
throw new InvalidKerasConfigurationException(
|
|
||||||
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
|
|
||||||
|
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder();
|
if (weightsArchive != null)
|
||||||
|
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
|
||||||
|
}
|
||||||
|
|
||||||
if (optimizer != null) {
|
/** Default constructor */
|
||||||
modelBuilder.updater(optimizer);
|
public KerasSequentialModel() {
|
||||||
|
super();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configure a NeuralNetConfiguration from this Keras Sequential model configuration.
|
||||||
|
*
|
||||||
|
* @return NeuralNetConfiguration
|
||||||
|
*/
|
||||||
|
public NeuralNetConfiguration getNeuralNetConfiguration()
|
||||||
|
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||||
|
if (!this.className.equals(config.getFieldClassNameSequential()))
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
|
||||||
|
if (this.inputLayerNames.size() != 1)
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
|
||||||
|
if (this.outputLayerNames.size() != 1)
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
|
||||||
|
|
||||||
|
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder =
|
||||||
|
NeuralNetConfiguration.builder();
|
||||||
|
|
||||||
|
if (optimizer != null) {
|
||||||
|
modelBuilder.updater(optimizer);
|
||||||
|
}
|
||||||
|
|
||||||
|
// don't forcibly override for keras import
|
||||||
|
modelBuilder.overrideNinUponBuild(false);
|
||||||
|
/* Add layers one at a time. */
|
||||||
|
KerasLayer prevLayer = null;
|
||||||
|
int layerIndex = 0;
|
||||||
|
for (KerasLayer layer : this.layersOrdered) {
|
||||||
|
if (layer.isLayer()) {
|
||||||
|
int nbInbound = layer.getInboundLayerNames().size();
|
||||||
|
if (nbInbound != 1)
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
|
||||||
|
+ nbInbound
|
||||||
|
+ " for layer "
|
||||||
|
+ layer.getName()
|
||||||
|
+ ")");
|
||||||
|
if (prevLayer != null) {
|
||||||
|
InputType[] inputTypes = new InputType[1];
|
||||||
|
InputPreProcessor preprocessor;
|
||||||
|
if (prevLayer.isInputPreProcessor()) {
|
||||||
|
inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
|
||||||
|
preprocessor = prevLayer.getInputPreprocessor(inputTypes);
|
||||||
|
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
||||||
|
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
|
||||||
|
} else {
|
||||||
|
inputTypes[0] = this.outputTypes.get(prevLayer.getName());
|
||||||
|
preprocessor = layer.getInputPreprocessor(inputTypes);
|
||||||
|
if (preprocessor != null) {
|
||||||
|
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
||||||
|
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
|
||||||
|
} else layer.getLayer().setNIn(inputTypes[0], modelBuilder.isOverrideNinUponBuild());
|
||||||
|
}
|
||||||
|
if (preprocessor != null) {
|
||||||
|
|
||||||
|
Map<Integer, InputPreProcessor> map = new HashMap<>();
|
||||||
|
map.put(layerIndex, preprocessor);
|
||||||
|
modelBuilder.inputPreProcessors(map);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
modelBuilder.layer(layerIndex++, layer.getLayer());
|
||||||
//don't forcibly override for keras import
|
} else if (layer.getVertex() != null)
|
||||||
modelBuilder.overrideNinUponBuild(false);
|
throw new InvalidKerasConfigurationException(
|
||||||
/* Add layers one at a time. */
|
"Cannot add vertex to NeuralNetConfiguration (class name "
|
||||||
KerasLayer prevLayer = null;
|
+ layer.getClassName()
|
||||||
int layerIndex = 0;
|
+ ", layer name "
|
||||||
for (KerasLayer layer : this.layersOrdered) {
|
+ layer.getName()
|
||||||
if (layer.isLayer()) {
|
+ ")");
|
||||||
int nbInbound = layer.getInboundLayerNames().size();
|
prevLayer = layer;
|
||||||
if (nbInbound != 1)
|
|
||||||
throw new InvalidKerasConfigurationException(
|
|
||||||
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
|
|
||||||
+ nbInbound + " for layer " + layer.getName() + ")");
|
|
||||||
if (prevLayer != null) {
|
|
||||||
InputType[] inputTypes = new InputType[1];
|
|
||||||
InputPreProcessor preprocessor;
|
|
||||||
if (prevLayer.isInputPreProcessor()) {
|
|
||||||
inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
|
|
||||||
preprocessor = prevLayer.getInputPreprocessor(inputTypes);
|
|
||||||
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
|
||||||
layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
|
|
||||||
} else {
|
|
||||||
inputTypes[0] = this.outputTypes.get(prevLayer.getName());
|
|
||||||
preprocessor = layer.getInputPreprocessor(inputTypes);
|
|
||||||
if(preprocessor != null) {
|
|
||||||
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
|
||||||
layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
|
|
||||||
}
|
|
||||||
else
|
|
||||||
layer.getLayer().setNIn(inputTypes[0],modelBuilder.isOverrideNinUponBuild());
|
|
||||||
|
|
||||||
}
|
|
||||||
if (preprocessor != null)
|
|
||||||
modelBuilder.inputPreProcessor(layerIndex, preprocessor);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
modelBuilder.layer(layerIndex++, layer.getLayer());
|
|
||||||
} else if (layer.getVertex() != null)
|
|
||||||
throw new InvalidKerasConfigurationException("Cannot add vertex to NeuralNetConfiguration (class name "
|
|
||||||
+ layer.getClassName() + ", layer name " + layer.getName() + ")");
|
|
||||||
prevLayer = layer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */
|
|
||||||
if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
|
|
||||||
modelBuilder.backpropType(BackpropType.TruncatedBPTT)
|
|
||||||
.tbpttFwdLength(truncatedBPTT)
|
|
||||||
.tbpttBackLength(truncatedBPTT);
|
|
||||||
else
|
|
||||||
modelBuilder.backpropType(BackpropType.Standard);
|
|
||||||
|
|
||||||
NeuralNetConfiguration build = modelBuilder.build();
|
|
||||||
|
|
||||||
|
|
||||||
return build;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */
|
||||||
* Build a MultiLayerNetwork from this Keras Sequential model configuration.
|
if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
|
||||||
*
|
modelBuilder
|
||||||
* @return MultiLayerNetwork
|
.backpropType(BackpropType.TruncatedBPTT)
|
||||||
*/
|
.tbpttFwdLength(truncatedBPTT)
|
||||||
public MultiLayerNetwork getMultiLayerNetwork()
|
.tbpttBackLength(truncatedBPTT);
|
||||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
else modelBuilder.backpropType(BackpropType.Standard);
|
||||||
return getMultiLayerNetwork(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
NeuralNetConfiguration build = modelBuilder.build();
|
||||||
* Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights.
|
|
||||||
*
|
return build;
|
||||||
* @return MultiLayerNetwork
|
}
|
||||||
*/
|
|
||||||
public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
|
/**
|
||||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
* Build a MultiLayerNetwork from this Keras Sequential model configuration.
|
||||||
MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration());
|
*
|
||||||
model.init();
|
* @return MultiLayerNetwork
|
||||||
if (importWeights)
|
*/
|
||||||
model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers);
|
public MultiLayerNetwork getMultiLayerNetwork()
|
||||||
return model;
|
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||||
}
|
return getMultiLayerNetwork(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights.
|
||||||
|
*
|
||||||
|
* @return MultiLayerNetwork
|
||||||
|
*/
|
||||||
|
public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
|
||||||
|
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||||
|
MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration());
|
||||||
|
model.init();
|
||||||
|
if (importWeights)
|
||||||
|
model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers);
|
||||||
|
return model;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ package net.brutex.ai.dnn.api;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A fluent API to configure and create artificial neural networks
|
* A fluent API to configure and create artificial neural networks
|
||||||
|
@ -30,9 +31,11 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationB
|
||||||
public class NN {
|
public class NN {
|
||||||
|
|
||||||
|
|
||||||
public static NeuralNetConfigurationBuilder<?, ?> net() {
|
public static NeuralNetConfigurationBuilder<?, ?> nn() {
|
||||||
return NeuralNetConfiguration.builder();
|
return NeuralNetConfiguration.builder();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static DenseLayer.DenseLayerBuilder<?,?> dense() { return DenseLayer.builder(); }
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -152,7 +152,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
@Getter @Setter @NonNull @lombok.Builder.Default
|
@Getter @Setter @NonNull @lombok.Builder.Default
|
||||||
protected BackpropType backpropType = BackpropType.Standard;
|
protected BackpropType backpropType = BackpropType.Standard;
|
||||||
|
|
||||||
@Getter @lombok.Builder.Default
|
@Getter @Setter @Singular
|
||||||
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
|
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
|
||||||
/**
|
/**
|
||||||
* When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated)
|
* When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated)
|
||||||
|
@ -524,12 +524,11 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
* @param processor what to use to preProcess the data.
|
* @param processor what to use to preProcess the data.
|
||||||
* @return builder pattern
|
* @return builder pattern
|
||||||
*/
|
*/
|
||||||
public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
|
//public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
|
||||||
if(inputPreProcessors$value==null) inputPreProcessors$value=new LinkedHashMap<>();
|
// inputPreProcessors$value.put(layer, processor);
|
||||||
inputPreProcessors$value.put(layer, processor);
|
// inputPreProcessors$set = true;
|
||||||
inputPreProcessors$set = true;
|
// return self();
|
||||||
return self();
|
// }
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set layer at index
|
* Set layer at index
|
||||||
|
|
|
@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import com.fasterxml.jackson.databind.*;
|
import com.fasterxml.jackson.databind.*;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
|
@ -317,6 +318,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
@NonNull
|
@NonNull
|
||||||
InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType);
|
InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType);
|
||||||
if (inputPreProcessor != null) {
|
if (inputPreProcessor != null) {
|
||||||
|
inputPreProcessors = new HashMap<>(inputPreProcessors);
|
||||||
inputPreProcessors.put(i, inputPreProcessor);
|
inputPreProcessors.put(i, inputPreProcessor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -538,6 +540,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
obj.getClass().getSimpleName());
|
obj.getClass().getSimpleName());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
// make sure the indexes are sequenced properly
|
||||||
|
AtomicInteger i = new AtomicInteger();
|
||||||
|
ret.forEach(obj -> {
|
||||||
|
obj.setIndex(i.getAndIncrement());
|
||||||
|
});
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -219,7 +219,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
throw new IllegalStateException(
|
throw new IllegalStateException(
|
||||||
"Invalid input for Convolution layer (layer name=\""
|
"Invalid input for Convolution layer (layer name=\""
|
||||||
+ getName()
|
+ getName()
|
||||||
+ "\"): Expected CNN input, got "
|
+ "\" at index '"+getIndex()+"') : Expected CNN input, got "
|
||||||
+ inputType);
|
+ inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -372,7 +372,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
* @param kernelSize kernel size
|
* @param kernelSize kernel size
|
||||||
*/
|
*/
|
||||||
public B kernelSize(int... kernelSize) {
|
public B kernelSize(int... kernelSize) {
|
||||||
this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize");
|
//this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize");
|
||||||
|
this.kernelSize$value = kernelSize;
|
||||||
this.kernelSize$set = true;
|
this.kernelSize$set = true;
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
@ -383,7 +384,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
* @param stride kernel size
|
* @param stride kernel size
|
||||||
*/
|
*/
|
||||||
public B stride(int... stride) {
|
public B stride(int... stride) {
|
||||||
this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride");
|
//this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride");
|
||||||
|
this.stride$value = stride;
|
||||||
this.stride$set = true;
|
this.stride$set = true;
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
@ -394,7 +396,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
* @param padding kernel size
|
* @param padding kernel size
|
||||||
*/
|
*/
|
||||||
public B padding(int... padding) {
|
public B padding(int... padding) {
|
||||||
this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding");
|
//this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding");
|
||||||
|
this.padding$value = padding;
|
||||||
this.padding$set = true;
|
this.padding$set = true;
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
@ -404,7 +407,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
* @param dilation kernel size
|
* @param dilation kernel size
|
||||||
*/
|
*/
|
||||||
public B dilation(int... dilation) {
|
public B dilation(int... dilation) {
|
||||||
this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation");
|
//this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation");
|
||||||
|
this.dilation$value = dilation;
|
||||||
this.dilation$set = true;
|
this.dilation$set = true;
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,14 +20,19 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.layers;
|
package org.deeplearning4j.nn.conf.layers;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import lombok.extern.jackson.Jacksonized;
|
import lombok.extern.jackson.Jacksonized;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer;
|
import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer;
|
||||||
|
@ -84,6 +89,8 @@ public class Deconvolution2D extends ConvolutionLayer {
|
||||||
boolean initializeParams,
|
boolean initializeParams,
|
||||||
DataType networkDataType) {
|
DataType networkDataType) {
|
||||||
setNetConfiguration(conf);
|
setNetConfiguration(conf);
|
||||||
|
|
||||||
|
|
||||||
LayerValidation.assertNInNOutSet("Deconvolution2D", getName(), layerIndex, getNIn(), getNOut());
|
LayerValidation.assertNInNOutSet("Deconvolution2D", getName(), layerIndex, getNIn(), getNOut());
|
||||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||||
runInheritance();
|
runInheritance();
|
||||||
|
@ -127,11 +134,25 @@ public class Deconvolution2D extends ConvolutionLayer {
|
||||||
getName(),
|
getName(),
|
||||||
Deconvolution2DLayer.class);
|
Deconvolution2DLayer.class);
|
||||||
}
|
}
|
||||||
|
@Slf4j
|
||||||
private static final class Deconvolution2DBuilderImpl
|
private static final class Deconvolution2DBuilderImpl
|
||||||
extends Deconvolution2DBuilder<Deconvolution2D, Deconvolution2DBuilderImpl> {
|
extends Deconvolution2DBuilder<Deconvolution2D, Deconvolution2DBuilderImpl> {
|
||||||
public Deconvolution2D build() {
|
public Deconvolution2D build() {
|
||||||
Deconvolution2D l = new Deconvolution2D(this);
|
Deconvolution2D l = new Deconvolution2D(this);
|
||||||
|
if( l.getConvolutionMode() == ConvolutionMode.Same
|
||||||
|
&& IntStream.of(l.getPadding()).sum() != 0) {
|
||||||
|
log.warn("Invalid input for layer '{}'. "
|
||||||
|
+ "You cannot have a padding of {} when Convolution Mode is set to 'Same'."
|
||||||
|
+ " Padding will be ignored."
|
||||||
|
, l.getName(), l.getPadding());
|
||||||
|
}
|
||||||
|
/* strides * (input_size-1) + kernel_size - 2*padding */
|
||||||
|
//TODO: This is wrong, also depends on convolutionMode, etc ...
|
||||||
|
/*l.nOut = l.getStride()[0] * (l.getNIn()-1)
|
||||||
|
+ IntStream.of(l.getKernelSize()).reduce(1, (a,b) -> a*b)
|
||||||
|
- 2L * IntStream.of(l.getPadding()).sum();
|
||||||
|
*/
|
||||||
|
//l.nOut =264;
|
||||||
l.initializeConstraints();
|
l.initializeConstraints();
|
||||||
return l;
|
return l;
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,6 +62,7 @@ public abstract class LayerConfiguration
|
||||||
implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration
|
implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration
|
||||||
|
|
||||||
@Getter @Setter protected String name;
|
@Getter @Setter protected String name;
|
||||||
|
@Getter @Setter private int index;
|
||||||
@Getter @Setter protected List<LayerConstraint> allParamConstraints;
|
@Getter @Setter protected List<LayerConstraint> allParamConstraints;
|
||||||
@Getter @Setter protected List<LayerConstraint> weightConstraints;
|
@Getter @Setter protected List<LayerConstraint> weightConstraints;
|
||||||
@Getter @Setter protected List<LayerConstraint> biasConstraints;
|
@Getter @Setter protected List<LayerConstraint> biasConstraints;
|
||||||
|
@ -72,6 +73,7 @@ public abstract class LayerConfiguration
|
||||||
/** The type of the layer, basically defines the base class and its properties */
|
/** The type of the layer, basically defines the base class and its properties */
|
||||||
@Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
|
@Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Number of parameters this layer has a result of its configuration
|
* Number of parameters this layer has a result of its configuration
|
||||||
* @return number or parameters
|
* @return number or parameters
|
||||||
|
@ -80,7 +82,6 @@ public abstract class LayerConfiguration
|
||||||
return initializer().numParams(this);
|
return initializer().numParams(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A reference to the neural net configuration. This field is excluded from json serialization as
|
* A reference to the neural net configuration. This field is excluded from json serialization as
|
||||||
* well as from equals check to avoid circular referenced.
|
* well as from equals check to avoid circular referenced.
|
||||||
|
|
|
@ -37,122 +37,166 @@ import org.nd4j.linalg.api.shape.Shape;
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(exclude = {"shape"})
|
@EqualsAndHashCode(exclude = {"shape"})
|
||||||
public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
|
public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
|
||||||
private long inputHeight;
|
private long inputHeight;
|
||||||
private long inputWidth;
|
private long inputWidth;
|
||||||
private long numChannels;
|
private long numChannels;
|
||||||
|
|
||||||
@Getter(AccessLevel.NONE)
|
@Getter(AccessLevel.NONE)
|
||||||
@Setter(AccessLevel.NONE)
|
@Setter(AccessLevel.NONE)
|
||||||
private long[] shape;
|
private long[] shape;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reshape to a channels x rows x columns tensor
|
* Reshape to a channels x rows x columns tensor
|
||||||
*
|
*
|
||||||
* @param inputHeight the columns
|
* @param inputHeight the columns
|
||||||
* @param inputWidth the rows
|
* @param inputWidth the rows
|
||||||
* @param numChannels the channels
|
* @param numChannels the channels
|
||||||
*/
|
*/
|
||||||
@JsonCreator
|
@JsonCreator
|
||||||
public FeedForwardToCnnPreProcessor(@JsonProperty("inputHeight") long inputHeight,
|
public FeedForwardToCnnPreProcessor(
|
||||||
@JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) {
|
@JsonProperty("inputHeight") long inputHeight,
|
||||||
this.inputHeight = inputHeight;
|
@JsonProperty("inputWidth") long inputWidth,
|
||||||
this.inputWidth = inputWidth;
|
@JsonProperty("numChannels") long numChannels) {
|
||||||
this.numChannels = numChannels;
|
this.inputHeight = inputHeight;
|
||||||
|
this.inputWidth = inputWidth;
|
||||||
|
this.numChannels = numChannels;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Reshape to a channels x rows x columns tensor
|
||||||
|
*
|
||||||
|
* @param inputHeight the columns
|
||||||
|
* @param inputWidth the rows
|
||||||
|
*/
|
||||||
|
public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) {
|
||||||
|
this.inputHeight = inputHeight;
|
||||||
|
this.inputWidth = inputWidth;
|
||||||
|
this.numChannels = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
this.shape = input.shape();
|
||||||
|
if (input.rank() == 4) return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
|
||||||
|
|
||||||
|
if (input.columns() != inputWidth * inputHeight * numChannels)
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Invalid input: expect output columns must be equal to rows "
|
||||||
|
+ inputHeight
|
||||||
|
+ " x columns "
|
||||||
|
+ inputWidth
|
||||||
|
+ " x channels "
|
||||||
|
+ numChannels
|
||||||
|
+ " but was instead "
|
||||||
|
+ Arrays.toString(input.shape()));
|
||||||
|
|
||||||
|
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
|
||||||
|
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
|
||||||
|
|
||||||
|
return workspaceMgr.leverageTo(
|
||||||
|
ArrayType.ACTIVATIONS,
|
||||||
|
input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
// return 4 dimensions
|
||||||
|
public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilons))
|
||||||
|
epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c');
|
||||||
|
|
||||||
|
if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) {
|
||||||
|
if (epsilons.rank() == 2)
|
||||||
|
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons); // should never happen
|
||||||
|
|
||||||
|
return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
|
||||||
}
|
}
|
||||||
|
|
||||||
public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) {
|
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape));
|
||||||
this.inputHeight = inputHeight;
|
}
|
||||||
this.inputWidth = inputWidth;
|
|
||||||
this.numChannels = 1;
|
@Override
|
||||||
|
public FeedForwardToCnnPreProcessor clone() {
|
||||||
|
try {
|
||||||
|
FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
|
||||||
|
if (clone.shape != null) clone.shape = clone.shape.clone();
|
||||||
|
return clone;
|
||||||
|
} catch (CloneNotSupportedException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
public InputType getOutputType(InputType inputType) {
|
||||||
this.shape = input.shape();
|
|
||||||
if (input.rank() == 4)
|
|
||||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
|
|
||||||
|
|
||||||
if (input.columns() != inputWidth * inputHeight * numChannels)
|
switch (inputType.getType()) {
|
||||||
throw new IllegalArgumentException("Invalid input: expect output columns must be equal to rows "
|
case FF:
|
||||||
+ inputHeight + " x columns " + inputWidth + " x channels " + numChannels
|
InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
|
||||||
+ " but was instead " + Arrays.toString(input.shape()));
|
val expSize = inputHeight * inputWidth * numChannels;
|
||||||
|
if (c.getSize() != expSize) {
|
||||||
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
|
throw new IllegalStateException(
|
||||||
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
|
"Invalid input: expected FeedForward input of size "
|
||||||
|
+ expSize
|
||||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS,
|
+ " = (d="
|
||||||
input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth));
|
+ numChannels
|
||||||
}
|
+ " * w="
|
||||||
|
+ inputWidth
|
||||||
@Override
|
+ " * h="
|
||||||
// return 4 dimensions
|
+ inputHeight
|
||||||
public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
+ "), got "
|
||||||
if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilons))
|
+ inputType);
|
||||||
epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c');
|
|
||||||
|
|
||||||
if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) {
|
|
||||||
if (epsilons.rank() == 2)
|
|
||||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons); //should never happen
|
|
||||||
|
|
||||||
return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
|
|
||||||
}
|
}
|
||||||
|
return InputType.convolutional(inputHeight, inputWidth, numChannels);
|
||||||
|
case CNN:
|
||||||
|
InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;
|
||||||
|
|
||||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape));
|
if (c2.getChannels() != numChannels
|
||||||
}
|
|| c2.getHeight() != inputHeight
|
||||||
|
|| c2.getWidth() != inputWidth) {
|
||||||
|
throw new IllegalStateException(
|
||||||
@Override
|
"Invalid input: Got CNN input type with (d,w,h)=("
|
||||||
public FeedForwardToCnnPreProcessor clone() {
|
+ c2.getChannels()
|
||||||
try {
|
+ ","
|
||||||
FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
|
+ c2.getWidth()
|
||||||
if (clone.shape != null)
|
+ ","
|
||||||
clone.shape = clone.shape.clone();
|
+ c2.getHeight()
|
||||||
return clone;
|
+ ") but expected ("
|
||||||
} catch (CloneNotSupportedException e) {
|
+ numChannels
|
||||||
throw new RuntimeException(e);
|
+ ","
|
||||||
|
+ inputHeight
|
||||||
|
+ ","
|
||||||
|
+ inputWidth
|
||||||
|
+ ")");
|
||||||
}
|
}
|
||||||
}
|
return c2;
|
||||||
|
case CNNFlat:
|
||||||
@Override
|
InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
|
||||||
public InputType getOutputType(InputType inputType) {
|
if (c3.getDepth() != numChannels
|
||||||
|
|| c3.getHeight() != inputHeight
|
||||||
switch (inputType.getType()) {
|
|| c3.getWidth() != inputWidth) {
|
||||||
case FF:
|
throw new IllegalStateException(
|
||||||
InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
|
"Invalid input: Got CNN input type with (d,w,h)=("
|
||||||
val expSize = inputHeight * inputWidth * numChannels;
|
+ c3.getDepth()
|
||||||
if (c.getSize() != expSize) {
|
+ ","
|
||||||
throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize
|
+ c3.getWidth()
|
||||||
+ " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got "
|
+ ","
|
||||||
+ inputType);
|
+ c3.getHeight()
|
||||||
}
|
+ ") but expected ("
|
||||||
return InputType.convolutional(inputHeight, inputWidth, numChannels);
|
+ numChannels
|
||||||
case CNN:
|
+ ","
|
||||||
InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;
|
+ inputHeight
|
||||||
|
+ ","
|
||||||
if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
|
+ inputWidth
|
||||||
throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getChannels()
|
+ ")");
|
||||||
+ "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels
|
|
||||||
+ "," + inputHeight + "," + inputWidth + ")");
|
|
||||||
}
|
|
||||||
return c2;
|
|
||||||
case CNNFlat:
|
|
||||||
InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
|
|
||||||
if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
|
|
||||||
throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c3.getDepth()
|
|
||||||
+ "," + c3.getWidth() + "," + c3.getHeight() + ") but expected (" + numChannels
|
|
||||||
+ "," + inputHeight + "," + inputWidth + ")");
|
|
||||||
}
|
|
||||||
return c3.getUnflattenedType();
|
|
||||||
default:
|
|
||||||
throw new IllegalStateException("Invalid input type: got " + inputType);
|
|
||||||
}
|
}
|
||||||
|
return c3.getUnflattenedType();
|
||||||
|
default:
|
||||||
|
throw new IllegalStateException("Invalid input type: got " + inputType);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
|
public Pair<INDArray, MaskState> feedForwardMaskArray(
|
||||||
int minibatchSize) {
|
INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
|
||||||
//Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example)
|
// Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example)
|
||||||
return new Pair<>(maskArray, currentMaskState);
|
return new Pair<>(maskArray, currentMaskState);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -369,7 +369,7 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl
|
||||||
|
|
||||||
protected String layerId() {
|
protected String layerId() {
|
||||||
String name = this.layerConfiguration.getName();
|
String name = this.layerConfiguration.getName();
|
||||||
return "(layer name: "
|
return "(network: " + getNetConfiguration().getName() + " layer name: "
|
||||||
+ (name == null ? "\"\"" : name)
|
+ (name == null ? "\"\"" : name)
|
||||||
+ ", layer index: "
|
+ ", layer index: "
|
||||||
+ index
|
+ index
|
||||||
|
|
|
@ -101,8 +101,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
|
||||||
|
|
||||||
int[] args = new int[] {
|
int[] args = new int[] {
|
||||||
(int)kH, (int)kW, strides[0], strides[1],
|
(int)kH, (int)kW, strides[0], strides[1],
|
||||||
pad[0], pad[1], dilation[0], dilation[1], sameMode,
|
pad[0], pad[1], dilation[0], dilation[1], sameMode //,
|
||||||
nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
|
//nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
|
||||||
};
|
};
|
||||||
|
|
||||||
INDArray delta;
|
INDArray delta;
|
||||||
|
@ -224,8 +224,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
|
||||||
|
|
||||||
int[] args = new int[] {
|
int[] args = new int[] {
|
||||||
kH, kW, strides[0], strides[1],
|
kH, kW, strides[0], strides[1],
|
||||||
pad[0], pad[1], dilation[0], dilation[1], sameMode,
|
pad[0], pad[1], dilation[0], dilation[1], sameMode //,
|
||||||
nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
|
//nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
|
||||||
};
|
};
|
||||||
|
|
||||||
//DL4J Deconv weights: [inputDepth, outputDepth, kH, kW]
|
//DL4J Deconv weights: [inputDepth, outputDepth, kH, kW]
|
||||||
|
@ -238,6 +238,20 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
|
||||||
} else {
|
} else {
|
||||||
opInputs = new INDArray[]{input, weights};
|
opInputs = new INDArray[]{input, weights};
|
||||||
}
|
}
|
||||||
|
/**
|
||||||
|
* 2D deconvolution implementation
|
||||||
|
*
|
||||||
|
* IntArgs:
|
||||||
|
* 0: kernel height
|
||||||
|
* 1: kernel width
|
||||||
|
* 2: stride height
|
||||||
|
* 3: stride width
|
||||||
|
* 4: padding height
|
||||||
|
* 5: padding width
|
||||||
|
* 6: dilation height
|
||||||
|
* 7: dilation width
|
||||||
|
* 8: same mode: 0 false, 1 true
|
||||||
|
*/
|
||||||
CustomOp op = DynamicCustomOp.builder("deconv2d")
|
CustomOp op = DynamicCustomOp.builder("deconv2d")
|
||||||
.addInputs(opInputs)
|
.addInputs(opInputs)
|
||||||
.addIntegerArguments(args)
|
.addIntegerArguments(args)
|
||||||
|
|
|
@ -773,7 +773,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork
|
||||||
LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i);
|
LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i);
|
||||||
layers[i] =
|
layers[i] =
|
||||||
lc.instantiate(
|
lc.instantiate(
|
||||||
lc.getNetConfiguration(),
|
this.getNetConfiguration(),
|
||||||
trainingListeners,
|
trainingListeners,
|
||||||
i,
|
i,
|
||||||
paramsView,
|
paramsView,
|
||||||
|
|
|
@ -101,8 +101,10 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer
|
||||||
|
|
||||||
params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
|
params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
|
||||||
conf.getNetConfiguration().addNetWideVariable(GAMMA);
|
conf.getNetConfiguration().addNetWideVariable(GAMMA);
|
||||||
|
conf.addVariable(GAMMA);
|
||||||
params.put(BETA, createBeta(conf, betaView, initializeParams));
|
params.put(BETA, createBeta(conf, betaView, initializeParams));
|
||||||
conf.getNetConfiguration().addNetWideVariable(BETA);
|
conf.getNetConfiguration().addNetWideVariable(BETA);
|
||||||
|
conf.addVariable(BETA);
|
||||||
|
|
||||||
meanOffset = 2 * nOut;
|
meanOffset = 2 * nOut;
|
||||||
}
|
}
|
||||||
|
@ -125,12 +127,15 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer
|
||||||
|
|
||||||
params.put(GLOBAL_MEAN, globalMeanView);
|
params.put(GLOBAL_MEAN, globalMeanView);
|
||||||
conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN);
|
conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN);
|
||||||
|
conf.addVariable(GLOBAL_MEAN);
|
||||||
if(layer.isUseLogStd()){
|
if(layer.isUseLogStd()){
|
||||||
params.put(GLOBAL_LOG_STD, globalVarView);
|
params.put(GLOBAL_LOG_STD, globalVarView);
|
||||||
conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD);
|
conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD);
|
||||||
|
conf.addVariable(GLOBAL_LOG_STD);
|
||||||
} else {
|
} else {
|
||||||
params.put(GLOBAL_VAR, globalVarView);
|
params.put(GLOBAL_VAR, globalVarView);
|
||||||
conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR);
|
conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR);
|
||||||
|
conf.addVariable(GLOBAL_VAR);
|
||||||
}
|
}
|
||||||
|
|
||||||
return params;
|
return params;
|
||||||
|
|
|
@ -114,11 +114,13 @@ public class ConvolutionParamInitializer extends AbstractParamInitializer {
|
||||||
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
|
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
|
||||||
conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
|
conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
|
||||||
conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
|
conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
|
||||||
conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
|
conf.addVariable(WEIGHT_KEY);
|
||||||
|
conf.addVariable(BIAS_KEY);
|
||||||
} else {
|
} else {
|
||||||
INDArray weightView = paramsView;
|
INDArray weightView = paramsView;
|
||||||
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
|
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
|
||||||
conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
|
conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
|
||||||
|
conf.addVariable(WEIGHT_KEY);
|
||||||
}
|
}
|
||||||
|
|
||||||
return params;
|
return params;
|
||||||
|
|
|
@ -34,7 +34,7 @@ systemProp.org.gradle.internal.publish.checksums.insecure=true
|
||||||
|
|
||||||
#for whatever reason we had to add MaxMetaspace and file encoding = utf8, gradle crashed otherwise
|
#for whatever reason we had to add MaxMetaspace and file encoding = utf8, gradle crashed otherwise
|
||||||
org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=768m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 -XX:ErrorFile=/var/log/java/hs_err_pid%p.log
|
org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=768m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 -XX:ErrorFile=/var/log/java/hs_err_pid%p.log
|
||||||
|
#-DsocksProxyHost=sshtunnel -DsocksProxyPort=8888 -Djava.net.preferIPv4Stack=true
|
||||||
|
|
||||||
# When configured, Gradle will run in incubating parallel mode.
|
# When configured, Gradle will run in incubating parallel mode.
|
||||||
# This option should only be used with decoupled projects. More details, visit
|
# This option should only be used with decoupled projects. More details, visit
|
||||||
|
|
Binary file not shown.
|
@ -69,18 +69,18 @@ app_path=$0
|
||||||
|
|
||||||
# Need this for daisy-chained symlinks.
|
# Need this for daisy-chained symlinks.
|
||||||
while
|
while
|
||||||
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
|
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
|
||||||
[ -h "$app_path" ]
|
[ -h "$app_path" ]
|
||||||
do
|
do
|
||||||
ls=$(ls -ld "$app_path")
|
ls=$( ls -ld "$app_path" )
|
||||||
link=${ls#*' -> '}
|
link=${ls#*' -> '}
|
||||||
case $link in #(
|
case $link in #(
|
||||||
/*) app_path=$link ;; #(
|
/*) app_path=$link ;; #(
|
||||||
*) app_path=$APP_HOME$link ;;
|
*) app_path=$APP_HOME$link ;;
|
||||||
esac
|
esac
|
||||||
done
|
done
|
||||||
|
|
||||||
APP_HOME=$(cd "${APP_HOME:-./}" && pwd -P) || exit
|
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
|
||||||
|
|
||||||
APP_NAME="Gradle"
|
APP_NAME="Gradle"
|
||||||
APP_BASE_NAME=${0##*/}
|
APP_BASE_NAME=${0##*/}
|
||||||
|
@ -91,15 +91,15 @@ DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
|
||||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||||
MAX_FD=maximum
|
MAX_FD=maximum
|
||||||
|
|
||||||
warn() {
|
warn () {
|
||||||
echo "$*"
|
echo "$*"
|
||||||
} >&2
|
} >&2
|
||||||
|
|
||||||
die() {
|
die () {
|
||||||
echo
|
echo
|
||||||
echo "$*"
|
echo "$*"
|
||||||
echo
|
echo
|
||||||
exit 1
|
exit 1
|
||||||
} >&2
|
} >&2
|
||||||
|
|
||||||
# OS specific support (must be 'true' or 'false').
|
# OS specific support (must be 'true' or 'false').
|
||||||
|
@ -107,52 +107,51 @@ cygwin=false
|
||||||
msys=false
|
msys=false
|
||||||
darwin=false
|
darwin=false
|
||||||
nonstop=false
|
nonstop=false
|
||||||
case "$(uname)" in #(
|
case "$( uname )" in #(
|
||||||
CYGWIN*) cygwin=true ;; #(
|
CYGWIN* ) cygwin=true ;; #(
|
||||||
Darwin*) darwin=true ;; #(
|
Darwin* ) darwin=true ;; #(
|
||||||
MSYS* | MINGW*) msys=true ;; #(
|
MSYS* | MINGW* ) msys=true ;; #(
|
||||||
NONSTOP*) nonstop=true ;;
|
NONSTOP* ) nonstop=true ;;
|
||||||
esac
|
esac
|
||||||
|
|
||||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||||
|
|
||||||
|
|
||||||
# Determine the Java command to use to start the JVM.
|
# Determine the Java command to use to start the JVM.
|
||||||
if [ -n "$JAVA_HOME" ]; then
|
if [ -n "$JAVA_HOME" ] ; then
|
||||||
if [ -x "$JAVA_HOME/jre/sh/java" ]; then
|
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
||||||
# IBM's JDK on AIX uses strange locations for the executables
|
# IBM's JDK on AIX uses strange locations for the executables
|
||||||
JAVACMD=$JAVA_HOME/jre/sh/java
|
JAVACMD=$JAVA_HOME/jre/sh/java
|
||||||
else
|
else
|
||||||
JAVACMD=$JAVA_HOME/bin/java
|
JAVACMD=$JAVA_HOME/bin/java
|
||||||
fi
|
fi
|
||||||
if [ ! -x "$JAVACMD" ]; then
|
if [ ! -x "$JAVACMD" ] ; then
|
||||||
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
||||||
|
|
||||||
Please set the JAVA_HOME variable in your environment to match the
|
Please set the JAVA_HOME variable in your environment to match the
|
||||||
location of your Java installation."
|
location of your Java installation."
|
||||||
fi
|
fi
|
||||||
else
|
else
|
||||||
JAVACMD=java
|
JAVACMD=java
|
||||||
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||||
|
|
||||||
Please set the JAVA_HOME variable in your environment to match the
|
Please set the JAVA_HOME variable in your environment to match the
|
||||||
location of your Java installation."
|
location of your Java installation."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Increase the maximum file descriptors if we can.
|
# Increase the maximum file descriptors if we can.
|
||||||
if ! "$cygwin" && ! "$darwin" && ! "$nonstop"; then
|
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
|
||||||
case $MAX_FD in #(
|
case $MAX_FD in #(
|
||||||
max*)
|
max*)
|
||||||
MAX_FD=$(ulimit -H -n) ||
|
MAX_FD=$( ulimit -H -n ) ||
|
||||||
warn "Could not query maximum file descriptor limit"
|
warn "Could not query maximum file descriptor limit"
|
||||||
;;
|
esac
|
||||||
esac
|
case $MAX_FD in #(
|
||||||
case $MAX_FD in #(
|
'' | soft) :;; #(
|
||||||
'' | soft) : ;; #(
|
*)
|
||||||
*)
|
ulimit -n "$MAX_FD" ||
|
||||||
ulimit -n "$MAX_FD" ||
|
warn "Could not set maximum file descriptor limit to $MAX_FD"
|
||||||
warn "Could not set maximum file descriptor limit to $MAX_FD"
|
esac
|
||||||
;;
|
|
||||||
esac
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Collect all arguments for the java command, stacking in reverse order:
|
# Collect all arguments for the java command, stacking in reverse order:
|
||||||
|
@ -164,36 +163,34 @@ fi
|
||||||
# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
|
# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
|
||||||
|
|
||||||
# For Cygwin or MSYS, switch paths to Windows format before running java
|
# For Cygwin or MSYS, switch paths to Windows format before running java
|
||||||
if "$cygwin" || "$msys"; then
|
if "$cygwin" || "$msys" ; then
|
||||||
APP_HOME=$(cygpath --path --mixed "$APP_HOME")
|
APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
|
||||||
CLASSPATH=$(cygpath --path --mixed "$CLASSPATH")
|
CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
|
||||||
|
|
||||||
JAVACMD=$(cygpath --unix "$JAVACMD")
|
JAVACMD=$( cygpath --unix "$JAVACMD" )
|
||||||
|
|
||||||
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||||
for arg; do
|
for arg do
|
||||||
if
|
if
|
||||||
case $arg in #(
|
case $arg in #(
|
||||||
-*) false ;; # don't mess with options #(
|
-*) false ;; # don't mess with options #(
|
||||||
/?*)
|
/?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
|
||||||
t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
|
[ -e "$t" ] ;; #(
|
||||||
[ -e "$t" ]
|
*) false ;;
|
||||||
;; #(
|
esac
|
||||||
*) false ;;
|
then
|
||||||
esac
|
arg=$( cygpath --path --ignore --mixed "$arg" )
|
||||||
then
|
fi
|
||||||
arg=$(cygpath --path --ignore --mixed "$arg")
|
# Roll the args list around exactly as many times as the number of
|
||||||
fi
|
# args, so each arg winds up back in the position where it started, but
|
||||||
# Roll the args list around exactly as many times as the number of
|
# possibly modified.
|
||||||
# args, so each arg winds up back in the position where it started, but
|
#
|
||||||
# possibly modified.
|
# NB: a `for` loop captures its iteration list before it begins, so
|
||||||
#
|
# changing the positional parameters here affects neither the number of
|
||||||
# NB: a `for` loop captures its iteration list before it begins, so
|
# iterations, nor the values presented in `arg`.
|
||||||
# changing the positional parameters here affects neither the number of
|
shift # remove old arg
|
||||||
# iterations, nor the values presented in `arg`.
|
set -- "$@" "$arg" # push replacement arg
|
||||||
shift # remove old arg
|
done
|
||||||
set -- "$@" "$arg" # push replacement arg
|
|
||||||
done
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Collect all arguments for the java command;
|
# Collect all arguments for the java command;
|
||||||
|
@ -203,15 +200,10 @@ fi
|
||||||
# * put everything else in single quotes, so that it's not re-expanded.
|
# * put everything else in single quotes, so that it's not re-expanded.
|
||||||
|
|
||||||
set -- \
|
set -- \
|
||||||
"-Dorg.gradle.appname=$APP_BASE_NAME" \
|
"-Dorg.gradle.appname=$APP_BASE_NAME" \
|
||||||
-classpath "$CLASSPATH" \
|
-classpath "$CLASSPATH" \
|
||||||
org.gradle.wrapper.GradleWrapperMain \
|
org.gradle.wrapper.GradleWrapperMain \
|
||||||
"$@"
|
"$@"
|
||||||
|
|
||||||
# Stop when "xargs" is not available.
|
|
||||||
if ! command -v xargs >/dev/null 2>&1; then
|
|
||||||
die "xargs is not available"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Use "xargs" to parse quoted args.
|
# Use "xargs" to parse quoted args.
|
||||||
#
|
#
|
||||||
|
@ -233,10 +225,10 @@ fi
|
||||||
#
|
#
|
||||||
|
|
||||||
eval "set -- $(
|
eval "set -- $(
|
||||||
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
|
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
|
||||||
xargs -n1 |
|
xargs -n1 |
|
||||||
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
|
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
|
||||||
tr '\n' ' '
|
tr '\n' ' '
|
||||||
)" '"$@"'
|
)" '"$@"'
|
||||||
|
|
||||||
exec "$JAVACMD" "$@"
|
exec "$JAVACMD" "$@"
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
@rem limitations under the License.
|
@rem limitations under the License.
|
||||||
@rem
|
@rem
|
||||||
|
|
||||||
@if "%DEBUG%"=="" @echo off
|
@if "%DEBUG%" == "" @echo off
|
||||||
@rem ##########################################################################
|
@rem ##########################################################################
|
||||||
@rem
|
@rem
|
||||||
@rem Gradle startup script for Windows
|
@rem Gradle startup script for Windows
|
||||||
|
@ -25,7 +25,7 @@
|
||||||
if "%OS%"=="Windows_NT" setlocal
|
if "%OS%"=="Windows_NT" setlocal
|
||||||
|
|
||||||
set DIRNAME=%~dp0
|
set DIRNAME=%~dp0
|
||||||
if "%DIRNAME%"=="" set DIRNAME=.
|
if "%DIRNAME%" == "" set DIRNAME=.
|
||||||
set APP_BASE_NAME=%~n0
|
set APP_BASE_NAME=%~n0
|
||||||
set APP_HOME=%DIRNAME%
|
set APP_HOME=%DIRNAME%
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
|
||||||
|
|
||||||
set JAVA_EXE=java.exe
|
set JAVA_EXE=java.exe
|
||||||
%JAVA_EXE% -version >NUL 2>&1
|
%JAVA_EXE% -version >NUL 2>&1
|
||||||
if %ERRORLEVEL% equ 0 goto execute
|
if "%ERRORLEVEL%" == "0" goto execute
|
||||||
|
|
||||||
echo.
|
echo.
|
||||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||||
|
@ -75,15 +75,13 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
||||||
|
|
||||||
:end
|
:end
|
||||||
@rem End local scope for the variables with windows NT shell
|
@rem End local scope for the variables with windows NT shell
|
||||||
if %ERRORLEVEL% equ 0 goto mainEnd
|
if "%ERRORLEVEL%"=="0" goto mainEnd
|
||||||
|
|
||||||
:fail
|
:fail
|
||||||
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
||||||
rem the _cmd.exe /c_ return code!
|
rem the _cmd.exe /c_ return code!
|
||||||
set EXIT_CODE=%ERRORLEVEL%
|
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
|
||||||
if %EXIT_CODE% equ 0 set EXIT_CODE=1
|
exit /b 1
|
||||||
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
|
|
||||||
exit /b %EXIT_CODE%
|
|
||||||
|
|
||||||
:mainEnd
|
:mainEnd
|
||||||
if "%OS%"=="Windows_NT" endlocal
|
if "%OS%"=="Windows_NT" endlocal
|
||||||
|
|
Loading…
Reference in New Issue