gan example

Signed-off-by: brian <brian@brutex.de>
enhance-build-infrastructure
Brian Rosenberger 2023-08-07 10:32:39 +02:00
parent 1b3338f809
commit dd151aec3f
20 changed files with 1267 additions and 838 deletions

View File

@ -1,115 +1,48 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan; package net.brutex.gan;
import static net.brutex.ai.dnn.api.NN.dense;
import java.awt.*; import java.awt.*;
import java.awt.image.BufferedImage; import java.awt.image.BufferedImage;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Random; import javax.swing.*;
import java.util.UUID;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.WindowConstants;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.datavec.image.transform.ShowImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.deeplearning4j.optimize.listeners.PerformanceListener; import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions;
@Slf4j
public class App { public class App {
private static final double LEARNING_RATE = 0.000002; private static final double LEARNING_RATE = 0.002;
private static final double GRADIENT_THRESHOLD = 100.0; private static final double GRADIENT_THRESHOLD = 100.0;
private static final int X_DIM = 20 ;
private static final int Y_DIM = 20;
private static final int CHANNELS = 1;
private static final int batchSize = 1;
private static final int INPUT = 10;
private static final int OUTPUT_PER_PANEL = 16;
private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS;
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build(); private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
private static final int BATCHSIZE = 128;
private static JFrame frame; private static JFrame frame;
private static JFrame frame2;
private static JPanel panel; private static JPanel panel;
private static JPanel panel2;
private static final String OUTPUT_DIR = "C:/temp/output/";
private static LayerConfiguration[] genLayers() { private static LayerConfiguration[] genLayers() {
return new LayerConfiguration[] { return new LayerConfiguration[] {
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(), dense().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
ActivationLayer.builder(Activation.LEAKYRELU).build(),
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(), ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(), dense().nIn(256).nOut(512).build(),
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(), ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(512).nOut(1024).build(),
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build() ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(1024).nOut(784).activation(Activation.TANH).build()
}; };
} }
@ -124,65 +57,51 @@ public class App {
.updater(UPDATER) .updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD) .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
//.weightInit(WeightInit.XAVIER)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY) .activation(Activation.IDENTITY)
.layersFromArray(genLayers()) .layersFromArray(genLayers())
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) .name("generator")
// .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
.build(); .build();
((NeuralNetConfiguration) conf).init();
return conf; return conf;
} }
private static LayerConfiguration[] disLayers() { private static LayerConfiguration[] disLayers() {
return new LayerConfiguration[]{ return new LayerConfiguration[]{
DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network dense().nIn(784).nOut(1024).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(), ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(), DropoutLayer.builder(1 - 0.5).build(),
DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC dense().nIn(1024).nOut(512).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(), ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(), DropoutLayer.builder(1 - 0.5).build(),
DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(), dense().nIn(512).nOut(256).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(), ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(), DropoutLayer.builder(1 - 0.5).build(),
DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
OutputLayer.builder().name("dis-output").lossFunction(LossFunction.MCXENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
}; };
} }
private static NeuralNetConfiguration discriminator() { private static NeuralNetConfiguration discriminator() {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.seed(42) .seed(42)
.updater(UPDATER) .updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD) .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightNoise(null)
// .weightInitFn(new WeightInitXavier())
// .activationFn(new ActivationIdentity())
.activation(Activation.IDENTITY) .activation(Activation.IDENTITY)
.layersFromArray(disLayers()) .layersFromArray(disLayers())
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) .name("discriminator")
.build(); .build();
((NeuralNetConfiguration) conf).init();
return conf; return conf;
} }
private static NeuralNetConfiguration gan() { private static NeuralNetConfiguration gan() {
LayerConfiguration[] genLayers = genLayers(); LayerConfiguration[] genLayers = genLayers();
LayerConfiguration[] disLayers = Arrays.stream(disLayers()) LayerConfiguration[] disLayers = discriminator().getFlattenedLayerConfigurations().stream()
.map((layer) -> { .map((layer) -> {
if (layer instanceof DenseLayer || layer instanceof OutputLayer) { if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build(); return FrozenLayerWithBackprop.builder(layer).build();
} else { } else {
return layer; return layer;
} }
@ -191,174 +110,100 @@ public class App {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.seed(42) .seed(42)
.updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() ) .updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold( 100 ) .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
//.weightInitFn( new WeightInitXavier() ) //this is internal
.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
//.activationFn( new ActivationIdentity()) //this is internal
.activation(Activation.IDENTITY) .activation(Activation.IDENTITY)
.layersFromArray(layers) .layersFromArray(layers)
.inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) .name("GAN")
.dataType(DataType.FLOAT)
.build(); .build();
((NeuralNetConfiguration) conf).init();
return conf; return conf;
} }
@Test @Test
public void runTest() throws Exception { public void runTest() throws Exception {
if(! log.isDebugEnabled()) { App.main(null);
log.info("Logging is not set to DEBUG");
} }
else {
log.info("Logging is set to DEBUG");
}
main();
}
public static void main(String... args) throws Exception { public static void main(String... args) throws Exception {
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m "); MnistDataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
Nd4j.getMemoryManager().setAutoGcWindow(500);
//MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
//FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS());
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);
ImageTransform tr = new PipelineImageTransform.Builder()
//.addImageTransform(transform) //convert to GREY SCALE
.addImageTransform(transform3)
//.addImageTransform(transform2)
.build();
ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
imageRecordReader.initialize(fileSplit, tr);
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
MultiLayerNetwork gen = new MultiLayerNetwork(generator()); MultiLayerNetwork gen = new MultiLayerNetwork(generator());
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator()); MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
MultiLayerNetwork gan = new MultiLayerNetwork(gan()); MultiLayerNetwork gan = new MultiLayerNetwork(gan());
gen.init(); log.debug("Generator network: {}", gen); gen.init();
dis.init(); log.debug("Discriminator network: {}", dis); dis.init();
gan.init(); log.info("Complete GAN network: {}", gan); gan.init();
copyParams(gen, dis, gan); copyParams(gen, dis, gan);
//gen.addTrainingListeners(new PerformanceListener(15, true, "GEN")); gen.addTrainingListeners(new PerformanceListener(10, true));
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS")); dis.addTrainingListeners(new PerformanceListener(10, true));
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN")); gan.addTrainingListeners(new PerformanceListener(10, true));
//gan.addTrainingListeners(new ScoreToChartListener("gan"));
//dis.setListeners(new ScoreToChartListener("dis"));
//System.out.println(gan.toString()); trainData.reset();
//gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
//gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
//trainData.reset();
int j = 0; int j = 0;
for (int i = 0; i < 51; i++) { //epoch for (int i = 0; i < 50; i++) {
while (trainData.hasNext()) { while (trainData.hasNext()) {
j++; j++;
DataSet next = trainData.next();
// generate data // generate data
INDArray real = next.getFeatures();//.div(255f); INDArray real = trainData.next().getFeatures().muli(2).subi(1);
int batchSize = (int) real.shape()[0];
//start next round if there are not enough images left to have a full batchsize dataset
if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
log.warn("Your total number of input images is not a multiple of {}, "
+ "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
break;
}
//if(i%20 == 0) {
frame2 = visualize(new INDArray[]{real}, batchSize,
frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
//}
real.divi(255f);
// int batchSize = (int) real.shape()[0];
INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
INDArray fakeIn = Nd4j.rand(batchSize, 100);
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn); INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
//log.info("real has {} items.", real.length());
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1)); DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1)); DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet)); DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
dis.fit(data); dis.fit(data);
//dis.fit(data); dis.fit(data);
// Update the discriminator in the GAN network // Update the discriminator in the GAN network
updateGan(gen, dis, gan); updateGan(gen, dis, gan);
//gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1))); gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1)));
//Visualize and reporting
if (j % 10 == 1) { if (j % 10 == 1) {
System.out.println("Epoch " + i +" Iteration " + j + " Visualizing..."); System.out.println("Epoch " + i +" Iteration " + j + " Visualizing...");
INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize]; INDArray[] samples = new INDArray[9];
for (int k = 0; k < samples.length; k++) {
//INDArray input = fakeSet2.get(k).getFeatures();
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1)); DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
INDArray input = fakeSet2.get(k).getFeatures();
input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here
for (int k = 0; k < 9; k++) {
INDArray input = fakeSet2.get(k).getFeatures();
//samples[k] = gen.output(input, false); //samples[k] = gen.output(input, false);
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input); samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
//samples[k] =
samples[k].addi(1f).divi(2f).muli(255f);
} }
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1 visualize(samples);
} }
} }
if (trainData.resetSupported()) {
trainData.reset(); trainData.reset();
} else { // Copy the GANs generator to gen.
log.error("Trainingdata {} does not support reset.", trainData.toString()); //updateGen(gen, gan);
} }
// Copy the GANs generator to gen. // Copy the GANs generator to gen.
updateGen(gen, gan); updateGen(gen, gan);
gen.save(new File("mnist-mlp-generator.dlj")); gen.save(new File("mnist-mlp-generator.dlj"));
} }
}
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) { private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length; int genLayerCount = gen.getLayers().length;
for (int i = 0; i < gan.getLayers().length; i++) { for (int i = 0; i < gan.getLayers().length; i++) {
if (i < genLayerCount) { if (i < genLayerCount) {
if(gan.getLayer(i).getParams() != null) gen.getLayer(i).setParams(gan.getLayer(i).getParams());
gan.getLayer(i).setParams(gen.getLayer(i).getParams());
} else { } else {
if(gan.getLayer(i).getParams() != null) dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams());
} }
} }
} }
@ -376,98 +221,41 @@ public class App {
} }
} }
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) { private static void visualize(INDArray[] samples) {
if (isOrig) { if (frame == null) {
frame.setTitle("Viz Original"); frame = new JFrame();
} else { frame.setTitle("Viz");
frame.setTitle("Generated");
}
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout()); frame.setLayout(new BorderLayout());
JPanel panelx = new JPanel(); panel = new JPanel();
panelx.setLayout(new GridLayout(4, 4, 8, 8)); panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
for (INDArray sample : samples) { frame.add(panel, BorderLayout.CENTER);
for(int i = 0; i<batchElements; i++) {
panelx.add(getImage(sample, i, isOrig));
}
}
frame.add(panelx, BorderLayout.CENTER);
frame.setVisible(true); frame.setVisible(true);
}
panel.removeAll();
for (INDArray sample : samples) {
panel.add(getImage(sample));
}
frame.revalidate(); frame.revalidate();
frame.setMinimumSize(new Dimension(300, 20));
frame.pack(); frame.pack();
return frame;
} }
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) { private static JLabel getImage(INDArray tensor) {
final BufferedImage bi; BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
if(CHANNELS>1) { for (int i = 0; i < 784; i++) {
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
} else { bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
}
final int imageSize = X_DIM * Y_DIM;
final int offset = batchElement * imageSize;
int pxl = offset * CHANNELS; //where to start in the INDArray
//Image in NCHW - channels first format
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
for (int y = 0; y < Y_DIM; y++) { // step through the columns x
for (int x = 0; x < X_DIM; x++) { //step through the rows y
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
pxl++; //next item in INDArray
}
}
} }
ImageIcon orig = new ImageIcon(bi); ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT); Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
ImageIcon scaled = new ImageIcon(imageScaled); ImageIcon scaled = new ImageIcon(imageScaled);
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
return new JLabel(scaled); return new JLabel(scaled);
}
private static void saveImage(Image image, int batchElement, boolean isOrig) {
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
try {
// Save the images to disk
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
log.debug("Images saved successfully.");
} catch (IOException e) {
log.error("Error saving the images: {}", e.getMessage());
} }
}
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
File directory = new File(outputDirectory);
if (!directory.exists()) {
directory.mkdir();
}
File outputFile = new File(directory, fileName);
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
}
public static BufferedImage imageToBufferedImage(Image image) {
if (image instanceof BufferedImage) {
return (BufferedImage) image;
}
// Create a buffered image with the same dimensions and transparency as the original image
BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
// Draw the original image onto the buffered image
Graphics2D g2d = bufferedImage.createGraphics();
g2d.drawImage(image, 0, 0, null);
g2d.dispose();
return bufferedImage;
}
} }

View File

@ -0,0 +1,343 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.List;
import javax.imageio.ImageIO;
import javax.swing.*;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.*;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
@Slf4j
public class App2 {
final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
static final float COLORSPACE = 255f;
static final int DIMENSIONS = 28;
static final int CHANNELS = 1;
final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
final int OUTPUT_PER_PANEL = 10;
final boolean BIAS = true;
static final int BATCHSIZE=128;
private JFrame frame2, frame;
static final String OUTPUT_DIR = "d:/out/";
final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);
@Test
void runTest() throws IOException {
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
ImageTransform tr = new PipelineImageTransform.Builder()
.addImageTransform(transform) //convert to GREY SCALE
.addImageTransform(transform3)
//.addImageTransform(transform2)
.build();
ImageRecordReader imageRecordReader = new ImageRecordReader(DIMENSIONS, DIMENSIONS, CHANNELS);
imageRecordReader.initialize(fileSplit, tr);
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, BATCHSIZE );
trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());
LayerConfiguration[] disLayers = App2Config.discriminator().getFlattenedLayerConfigurations().stream()
.map((layer) -> {
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
} else {
return layer;
}
}).toArray(LayerConfiguration[]::new);
NeuralNetConfiguration netConfiguration =
NeuralNetConfiguration.builder()
.name("GAN")
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.updater(App2Config.UPDATER)
.innerConfigurations(new ArrayList<>(List.of(App2Config.generator())))
.layersFromList(new ArrayList<>(Arrays.asList(disLayers)))
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
// .inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
// .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(2, new CnnToFeedForwardPreProcessor())
//.dataType(DataType.FLOAT)
.build();
MultiLayerNetwork gan = new MultiLayerNetwork(netConfiguration );
dis.init(); log.debug("Discriminator network: {}", dis);
gen.init(); log.debug("Generator network: {}", gen);
gan.init(); log.debug("GAN network: {}", gan);
log.info("Generator Summary:\n{}", gen.summary());
log.info("GAN Summary:\n{}", gan.summary());
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
int j = 0;
for (int i = 0; i < 51; i++) { //epoch
while (trainData.hasNext()) {
j++;
DataSet next = trainData.next();
// generate data
INDArray real = next.getFeatures(); //.muli(2).subi(1);;//.div(255f);
//start next round if there are not enough images left to have a full batchsize dataset
if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
log.warn("Your total number of input images is not a multiple of {}, "
+ "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
break;
}
//if(i%20 == 0) {
// frame2 = visualize(new INDArray[]{real}, BATCHSIZE,
// frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
//}
//real.divi(255f);
// int batchSize = (int) real.shape()[0];
//INDArray fakeIn = Nd4j.rand(BATCHSIZE, CHANNELS, DIMENSIONS, DIMENSIONS);
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
// when generator has TANH as activation - value range is -1 to 1
// when generator has SIGMOID, then range is 0 to 1
fake.addi(1f).divi(2f);
DataSet realSet = new DataSet(real, label_real);
DataSet fakeSet = new DataSet(fake, label_fake);
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
dis.fit(data);
dis.fit(data);
// Update the discriminator in the GAN network
updateGan(gen, dis, gan);
gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_fake));
//Visualize and reporting
if (j % 10 == 1) {
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
INDArray[] samples = BATCHSIZE > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[BATCHSIZE];
for (int k = 0; k < samples.length; k++) {
DataSet fakeSet2 = new DataSet(fakeIn, label_fake);
INDArray input = fakeSet2.get(k).getFeatures();
//input = input.reshape(1,CHANNELS, DIMENSIONS, DIMENSIONS); //batch size will be 1 here for images
input = input.reshape(1, App2Config.INPUT);
//samples[k] = gen.output(input, false);
samples[k] = gen.activateSelectedLayers(0, gen.getLayers().length - 1, input);
samples[k] = samples[k].reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS);
//samples[k] =
//samples[k].muli(255f);
}
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
}
}
if (trainData.resetSupported()) {
trainData.reset();
} else {
log.error("Trainingdata {} does not support reset.", trainData.toString());
}
// Copy the GANs generator to gen.
updateGen(gen, gan);
log.info("Updated GAN's generator from gen.");
gen.save(new File("mnist-mlp-generator.dlj"));
}
}
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
if (isOrig) {
frame.setTitle("Viz Original");
} else {
frame.setTitle("Generated");
}
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout());
JPanel panelx = new JPanel();
panelx.setLayout(new GridLayout(4, 4, 8, 8));
for (INDArray sample : samples) {
for(int i = 0; i<batchElements; i++) {
panelx.add(getImage(sample, i, isOrig));
}
}
frame.add(panelx, BorderLayout.CENTER);
frame.setVisible(true);
frame.revalidate();
frame.setMinimumSize(new Dimension(300, 20));
frame.pack();
return frame;
}
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
final BufferedImage bi;
if(CHANNELS >1) {
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
} else {
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
}
final int imageSize = DIMENSIONS * DIMENSIONS;
final int offset = batchElement * imageSize;
int pxl = offset * CHANNELS; //where to start in the INDArray
//Image in NCHW - channels first format
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x
for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y
float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl);
bi.getRaster().setSample(x, y, c, f_pxl);
pxl++; //next item in INDArray
}
}
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT);
ImageIcon scaled = new ImageIcon(imageScaled);
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
return new JLabel(scaled);
}
private static void saveImage(Image image, int batchElement, boolean isOrig) {
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
try {
// Save the images to disk
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
log.debug("Images saved successfully.");
} catch (IOException e) {
log.error("Error saving the images: {}", e.getMessage());
}
}
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
File directory = new File(outputDirectory);
if (!directory.exists()) {
directory.mkdir();
}
File outputFile = new File(directory, fileName);
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
}
public static BufferedImage imageToBufferedImage(Image image) {
if (image instanceof BufferedImage) {
return (BufferedImage) image;
}
// Create a buffered image with the same dimensions and transparency as the original image
BufferedImage bufferedImage;
if (CHANNELS > 1) {
bufferedImage =
new BufferedImage(
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
} else {
bufferedImage =
new BufferedImage(
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
}
// Draw the original image onto the buffered image
Graphics2D g2d = bufferedImage.createGraphics();
g2d.drawImage(image, 0, 0, null);
g2d.dispose();
return bufferedImage;
}
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
for (int i = 0; i < gen.getLayers().length; i++) {
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
}
}
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length;
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
}
}
}

View File

@ -0,0 +1,176 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import static net.brutex.ai.dnn.api.NN.*;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class App2Config {
public static final int INPUT = 100;
public static final int X_DIM = 28;
public static final int y_DIM = 28;
public static final int CHANNELS = 1;
public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
static LayerConfiguration[] genLayerConfig() {
return new LayerConfiguration[] {
/*
DenseLayer.builder().name("L-0").nIn(INPUT).nOut(INPUT + (INPUT / 2)).activation(Activation.RELU).build(),
ActivationLayer.builder().activation(Activation.RELU).build(), /*
Deconvolution2D.builder().name("L-Deconv-01").nIn(CHANNELS).nOut(CHANNELS)
.kernelSize(2,2)
.stride(1,1)
.padding(0,0)
.convolutionMode(ConvolutionMode.Truncate)
.activation(Activation.RELU)
.hasBias(BIAS).build(),
//BatchNormalization.builder().nOut(CHANNELS).build(),
Deconvolution2D.builder().name("L-Deconv-02").nIn(CHANNELS).nOut(CHANNELS)
.kernelSize(2,2)
.stride(2,2)
.padding(0,0)
.convolutionMode(ConvolutionMode.Truncate)
.activation(Activation.RELU)
.hasBias(BIAS).build(),
//BatchNormalization.builder().name("L-batch").nOut(CHANNELS).build(),
DenseLayer.builder().name("L-x").nIn(INPUT + (INPUT / 2)).nOut(2 * INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
DenseLayer.builder().name("L-x").nIn(2 * INPUT).nOut(3 * INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
DenseLayer.builder().name("L-x").nIn(3 * INPUT).nOut(2 * INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
// DropoutLayer.builder(0.001).build(),
DenseLayer.builder().nIn(2 * INPUT).nOut(INPUT).activation(Activation.TANH).build() */
dense().nIn(INPUT).nOut(256).weightInit(WeightInit.NORMAL).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(256).nOut(512).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(512).nOut(1024).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(1024).nOut(784).activation(Activation.TANH).build(),
};
}
static LayerConfiguration[] disLayerConfig() {
return new LayerConfiguration[] {/*
Convolution2D.builder().nIn(CHANNELS).kernelSize(2,2).padding(1,1).stride(1,1).nOut(CHANNELS)
.build(),
Convolution2D.builder().nIn(CHANNELS).kernelSize(3,3).padding(1,1).stride(2,2).nOut(CHANNELS)
.build(),
ActivationLayer.builder().activation(Activation.LEAKYRELU).build(),
BatchNormalization.builder().build(),
OutputLayer.builder().nOut(1).lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SIGMOID)
.build()
dense().name("L-dense").nIn(INPUT).nOut(INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).build(),
DropoutLayer.builder(0.5).build(),
DenseLayer.builder().nIn(INPUT).nOut(INPUT/2).build(),
ActivationLayer.builder().activation(Activation.RELU).build(),
DropoutLayer.builder(0.5).build(),
DenseLayer.builder().nIn(INPUT/2).nOut(INPUT/4).build(),
ActivationLayer.builder().activation(Activation.RELU).build(),
DropoutLayer.builder(0.5).build(),
OutputLayer.builder().nIn(INPUT/4).nOut(1).lossFunction(LossFunctions.LossFunction.XENT)
.activation(Activation.SIGMOID)
.build() */
dense().nIn(784).nOut(1024).hasBias(true).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
dense().nIn(1024).nOut(512).hasBias(true).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
dense().nIn(512).nOut(256).hasBias(true).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
};
}
static NeuralNetConfiguration generator() {
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.name("generator")
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.seed(42)
.updater(UPDATER)
.weightInit(WeightInit.XAVIER)
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightNoise(null)
// .weightInitFn(new WeightInitXavier())
// .activationFn(new ActivationIdentity())
.activation(Activation.IDENTITY)
.layersFromArray(App2Config.genLayerConfig())
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
//.inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
.build();
conf.init();
return conf;
}
static NeuralNetConfiguration discriminator() {
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.name("discriminator")
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.seed(42)
.updater(UPDATER)
.weightInit(WeightInit.XAVIER)
// .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightNoise(null)
// .weightInitFn(new WeightInitXavier())
// .activationFn(new ActivationIdentity())
.activation(Activation.IDENTITY)
.layersFromArray(disLayerConfig())
//.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
//.dataType(DataType.FLOAT)
.build();
conf.init();
return conf;
}
}

View File

@ -52,28 +52,44 @@ public class KerasSequentialModel extends KerasModel {
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration * @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
*/ */
public KerasSequentialModel(KerasModelBuilder modelBuilder) public KerasSequentialModel(KerasModelBuilder modelBuilder)
throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException { throws UnsupportedKerasConfigurationException,
this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(), IOException,
modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(), InvalidKerasConfigurationException {
modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape()); this(
modelBuilder.getModelJson(),
modelBuilder.getModelYaml(),
modelBuilder.getWeightsArchive(),
modelBuilder.getWeightsRoot(),
modelBuilder.getTrainingJson(),
modelBuilder.getTrainingArchive(),
modelBuilder.isEnforceTrainingConfig(),
modelBuilder.getInputShape());
} }
/** /**
* (Not recommended) Constructor for Sequential model from model configuration * (Not recommended) Constructor for Sequential model from model configuration (JSON or YAML),
* (JSON or YAML), training configuration (JSON), weights, and "training mode" * training configuration (JSON), weights, and "training mode" boolean indicator. When built in
* boolean indicator. When built in training mode, certain unsupported configurations * training mode, certain unsupported configurations (e.g., unknown regularizers) will throw
* (e.g., unknown regularizers) will throw Exceptions. When enforceTrainingConfig=false, these * Exceptions. When enforceTrainingConfig=false, these will generate warnings but will be
* will generate warnings but will be otherwise ignored. * otherwise ignored.
* *
* @param modelJson model configuration JSON string * @param modelJson model configuration JSON string
* @param modelYaml model configuration YAML string * @param modelYaml model configuration YAML string
* @param trainingJson training configuration JSON string * @param trainingJson training configuration JSON string
* @throws IOException I/O exception * @throws IOException I/O exception
*/ */
public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot, public KerasSequentialModel(
String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig, String modelJson,
String modelYaml,
Hdf5Archive weightsArchive,
String weightsRoot,
String trainingJson,
Hdf5Archive trainingArchive,
boolean enforceTrainingConfig,
int[] inputShape) int[] inputShape)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { throws IOException,
InvalidKerasConfigurationException,
UnsupportedKerasConfigurationException {
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml); Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config); this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
@ -83,19 +99,29 @@ public class KerasSequentialModel extends KerasModel {
/* Determine model configuration type. */ /* Determine model configuration type. */
if (!modelConfig.containsKey(config.getFieldClassName())) if (!modelConfig.containsKey(config.getFieldClassName()))
throw new InvalidKerasConfigurationException( throw new InvalidKerasConfigurationException(
"Could not determine Keras model class (no " + config.getFieldClassName() + " field found)"); "Could not determine Keras model class (no "
+ config.getFieldClassName()
+ " field found)");
this.className = (String) modelConfig.get(config.getFieldClassName()); this.className = (String) modelConfig.get(config.getFieldClassName());
if (!this.className.equals(config.getFieldClassNameSequential())) if (!this.className.equals(config.getFieldClassNameSequential()))
throw new InvalidKerasConfigurationException("Model class name must be " + config.getFieldClassNameSequential() throw new InvalidKerasConfigurationException(
+ " (found " + this.className + ")"); "Model class name must be "
+ config.getFieldClassNameSequential()
+ " (found "
+ this.className
+ ")");
/* Process layer configurations. */ /* Process layer configurations. */
if (!modelConfig.containsKey(config.getModelFieldConfig())) if (!modelConfig.containsKey(config.getModelFieldConfig()))
throw new InvalidKerasConfigurationException( throw new InvalidKerasConfigurationException(
"Could not find layer configurations (no " + config.getModelFieldConfig() + " field found)"); "Could not find layer configurations (no "
+ config.getModelFieldConfig()
+ " field found)");
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations. For consistency // Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations.
// "config" is now an object containing a "name" and "layers", the latter contain the same data as before. // For consistency
// "config" is now an object containing a "name" and "layers", the latter contain the same data
// as before.
// This change only affects Sequential models. // This change only affects Sequential models.
List<Object> layerList; List<Object> layerList;
try { try {
@ -105,8 +131,7 @@ public class KerasSequentialModel extends KerasModel {
layerList = (List<Object>) layerMap.get("layers"); layerList = (List<Object>) layerMap.get("layers");
} }
Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = prepareLayers(layerList);
prepareLayers(layerList);
this.layers = layerPair.getFirst(); this.layers = layerPair.getFirst();
this.layersOrdered = layerPair.getSecond(); this.layersOrdered = layerPair.getSecond();
@ -116,15 +141,18 @@ public class KerasSequentialModel extends KerasModel {
} else { } else {
/* Add placeholder input layer and update lists of input and output layers. */ /* Add placeholder input layer and update lists of input and output layers. */
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape(); int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
Preconditions.checkState(ArrayUtil.prod(firstLayerInputShape) > 0,"Input shape must not be zero!"); Preconditions.checkState(
ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!");
inputLayer = new KerasInput("input1", firstLayerInputShape); inputLayer = new KerasInput("input1", firstLayerInputShape);
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder()); inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
this.layers.put(inputLayer.getName(), inputLayer); this.layers.put(inputLayer.getName(), inputLayer);
this.layersOrdered.add(0, inputLayer); this.layersOrdered.add(0, inputLayer);
} }
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName())); this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
this.outputLayerNames = new ArrayList<>( this.outputLayerNames =
Collections.singletonList(this.layersOrdered.get(this.layersOrdered.size() - 1).getName())); new ArrayList<>(
Collections.singletonList(
this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
/* Update each layer's inbound layer list to include (only) previous layer. */ /* Update each layer's inbound layer list to include (only) previous layer. */
KerasLayer prevLayer = null; KerasLayer prevLayer = null;
@ -136,12 +164,13 @@ public class KerasSequentialModel extends KerasModel {
/* Import training configuration. */ /* Import training configuration. */
if (enforceTrainingConfig) { if (enforceTrainingConfig) {
if (trainingJson != null) if (trainingJson != null) importTrainingConfiguration(trainingJson);
importTrainingConfiguration(trainingJson); else
else log.warn("If enforceTrainingConfig is true, a training " + log.warn(
"configuration object has to be provided. Usually the only practical way to do this is to store" + "If enforceTrainingConfig is true, a training "
" your keras model with `model.save('model_path.h5'. If you store model config and weights" + + "configuration object has to be provided. Usually the only practical way to do this is to store"
" separately no training configuration is attached."); + " your keras model with `model.save('model_path.h5'. If you store model config and weights"
+ " separately no training configuration is attached.");
} }
this.outputTypes = inferOutputTypes(inputShape); this.outputTypes = inferOutputTypes(inputShape);
@ -150,9 +179,7 @@ public class KerasSequentialModel extends KerasModel {
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend); importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
} }
/** /** Default constructor */
* Default constructor
*/
public KerasSequentialModel() { public KerasSequentialModel() {
super(); super();
} }
@ -174,13 +201,13 @@ public class KerasSequentialModel extends KerasModel {
throw new InvalidKerasConfigurationException( throw new InvalidKerasConfigurationException(
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")"); "MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder(); NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder =
NeuralNetConfiguration.builder();
if (optimizer != null) { if (optimizer != null) {
modelBuilder.updater(optimizer); modelBuilder.updater(optimizer);
} }
// don't forcibly override for keras import // don't forcibly override for keras import
modelBuilder.overrideNinUponBuild(false); modelBuilder.overrideNinUponBuild(false);
/* Add layers one at a time. */ /* Add layers one at a time. */
@ -192,7 +219,10 @@ public class KerasSequentialModel extends KerasModel {
if (nbInbound != 1) if (nbInbound != 1)
throw new InvalidKerasConfigurationException( throw new InvalidKerasConfigurationException(
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found " "Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
+ nbInbound + " for layer " + layer.getName() + ")"); + nbInbound
+ " for layer "
+ layer.getName()
+ ")");
if (prevLayer != null) { if (prevLayer != null) {
InputType[] inputTypes = new InputType[1]; InputType[] inputTypes = new InputType[1];
InputPreProcessor preprocessor; InputPreProcessor preprocessor;
@ -207,35 +237,37 @@ public class KerasSequentialModel extends KerasModel {
if (preprocessor != null) { if (preprocessor != null) {
InputType outputType = preprocessor.getOutputType(inputTypes[0]); InputType outputType = preprocessor.getOutputType(inputTypes[0]);
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild()); layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
} else layer.getLayer().setNIn(inputTypes[0], modelBuilder.isOverrideNinUponBuild());
} }
else if (preprocessor != null) {
layer.getLayer().setNIn(inputTypes[0],modelBuilder.isOverrideNinUponBuild());
Map<Integer, InputPreProcessor> map = new HashMap<>();
map.put(layerIndex, preprocessor);
modelBuilder.inputPreProcessors(map);
} }
if (preprocessor != null)
modelBuilder.inputPreProcessor(layerIndex, preprocessor);
} }
modelBuilder.layer(layerIndex++, layer.getLayer()); modelBuilder.layer(layerIndex++, layer.getLayer());
} else if (layer.getVertex() != null) } else if (layer.getVertex() != null)
throw new InvalidKerasConfigurationException("Cannot add vertex to NeuralNetConfiguration (class name " throw new InvalidKerasConfigurationException(
+ layer.getClassName() + ", layer name " + layer.getName() + ")"); "Cannot add vertex to NeuralNetConfiguration (class name "
+ layer.getClassName()
+ ", layer name "
+ layer.getName()
+ ")");
prevLayer = layer; prevLayer = layer;
} }
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */ /* Whether to use standard backprop (or BPTT) or truncated BPTT. */
if (this.useTruncatedBPTT && this.truncatedBPTT > 0) if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
modelBuilder.backpropType(BackpropType.TruncatedBPTT) modelBuilder
.backpropType(BackpropType.TruncatedBPTT)
.tbpttFwdLength(truncatedBPTT) .tbpttFwdLength(truncatedBPTT)
.tbpttBackLength(truncatedBPTT); .tbpttBackLength(truncatedBPTT);
else else modelBuilder.backpropType(BackpropType.Standard);
modelBuilder.backpropType(BackpropType.Standard);
NeuralNetConfiguration build = modelBuilder.build(); NeuralNetConfiguration build = modelBuilder.build();
return build; return build;
} }

View File

@ -23,6 +23,7 @@ package net.brutex.ai.dnn.api;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
/** /**
* A fluent API to configure and create artificial neural networks * A fluent API to configure and create artificial neural networks
@ -30,9 +31,11 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationB
public class NN { public class NN {
public static NeuralNetConfigurationBuilder<?, ?> net() { public static NeuralNetConfigurationBuilder<?, ?> nn() {
return NeuralNetConfiguration.builder(); return NeuralNetConfiguration.builder();
} }
public static DenseLayer.DenseLayerBuilder<?,?> dense() { return DenseLayer.builder(); }
} }

View File

@ -152,7 +152,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
@Getter @Setter @NonNull @lombok.Builder.Default @Getter @Setter @NonNull @lombok.Builder.Default
protected BackpropType backpropType = BackpropType.Standard; protected BackpropType backpropType = BackpropType.Standard;
@Getter @lombok.Builder.Default @Getter @Setter @Singular
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>(); protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
/** /**
* When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated) * When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated)
@ -524,12 +524,11 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
* @param processor what to use to preProcess the data. * @param processor what to use to preProcess the data.
* @return builder pattern * @return builder pattern
*/ */
public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) { //public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
if(inputPreProcessors$value==null) inputPreProcessors$value=new LinkedHashMap<>(); // inputPreProcessors$value.put(layer, processor);
inputPreProcessors$value.put(layer, processor); // inputPreProcessors$set = true;
inputPreProcessors$set = true; // return self();
return self(); // }
}
/** /**
* Set layer at index * Set layer at index

View File

@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.*; import com.fasterxml.jackson.databind.*;
import java.util.*; import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
@ -317,6 +318,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
@NonNull @NonNull
InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType); InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType);
if (inputPreProcessor != null) { if (inputPreProcessor != null) {
inputPreProcessors = new HashMap<>(inputPreProcessors);
inputPreProcessors.put(i, inputPreProcessor); inputPreProcessors.put(i, inputPreProcessor);
} }
} }
@ -538,6 +540,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
obj.getClass().getSimpleName()); obj.getClass().getSimpleName());
} }
}); });
// make sure the indexes are sequenced properly
AtomicInteger i = new AtomicInteger();
ret.forEach(obj -> {
obj.setIndex(i.getAndIncrement());
});
return ret; return ret;
} }

View File

@ -219,7 +219,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
throw new IllegalStateException( throw new IllegalStateException(
"Invalid input for Convolution layer (layer name=\"" "Invalid input for Convolution layer (layer name=\""
+ getName() + getName()
+ "\"): Expected CNN input, got " + "\" at index '"+getIndex()+"') : Expected CNN input, got "
+ inputType); + inputType);
} }
@ -372,7 +372,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
* @param kernelSize kernel size * @param kernelSize kernel size
*/ */
public B kernelSize(int... kernelSize) { public B kernelSize(int... kernelSize) {
this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize"); //this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize");
this.kernelSize$value = kernelSize;
this.kernelSize$set = true; this.kernelSize$set = true;
return self(); return self();
} }
@ -383,7 +384,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
* @param stride kernel size * @param stride kernel size
*/ */
public B stride(int... stride) { public B stride(int... stride) {
this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride"); //this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride");
this.stride$value = stride;
this.stride$set = true; this.stride$set = true;
return self(); return self();
} }
@ -394,7 +396,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
* @param padding kernel size * @param padding kernel size
*/ */
public B padding(int... padding) { public B padding(int... padding) {
this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding"); //this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding");
this.padding$value = padding;
this.padding$set = true; this.padding$set = true;
return self(); return self();
} }
@ -404,7 +407,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
* @param dilation kernel size * @param dilation kernel size
*/ */
public B dilation(int... dilation) { public B dilation(int... dilation) {
this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation"); //this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation");
this.dilation$value = dilation;
this.dilation$set = true; this.dilation$set = true;
return self(); return self();
} }

View File

@ -20,14 +20,19 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Map; import java.util.Map;
import java.util.stream.IntStream;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import lombok.extern.jackson.Jacksonized; import lombok.extern.jackson.Jacksonized;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer; import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer;
@ -84,6 +89,8 @@ public class Deconvolution2D extends ConvolutionLayer {
boolean initializeParams, boolean initializeParams,
DataType networkDataType) { DataType networkDataType) {
setNetConfiguration(conf); setNetConfiguration(conf);
LayerValidation.assertNInNOutSet("Deconvolution2D", getName(), layerIndex, getNIn(), getNOut()); LayerValidation.assertNInNOutSet("Deconvolution2D", getName(), layerIndex, getNIn(), getNOut());
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
runInheritance(); runInheritance();
@ -127,11 +134,25 @@ public class Deconvolution2D extends ConvolutionLayer {
getName(), getName(),
Deconvolution2DLayer.class); Deconvolution2DLayer.class);
} }
@Slf4j
private static final class Deconvolution2DBuilderImpl private static final class Deconvolution2DBuilderImpl
extends Deconvolution2DBuilder<Deconvolution2D, Deconvolution2DBuilderImpl> { extends Deconvolution2DBuilder<Deconvolution2D, Deconvolution2DBuilderImpl> {
public Deconvolution2D build() { public Deconvolution2D build() {
Deconvolution2D l = new Deconvolution2D(this); Deconvolution2D l = new Deconvolution2D(this);
if( l.getConvolutionMode() == ConvolutionMode.Same
&& IntStream.of(l.getPadding()).sum() != 0) {
log.warn("Invalid input for layer '{}'. "
+ "You cannot have a padding of {} when Convolution Mode is set to 'Same'."
+ " Padding will be ignored."
, l.getName(), l.getPadding());
}
/* strides * (input_size-1) + kernel_size - 2*padding */
//TODO: This is wrong, also depends on convolutionMode, etc ...
/*l.nOut = l.getStride()[0] * (l.getNIn()-1)
+ IntStream.of(l.getKernelSize()).reduce(1, (a,b) -> a*b)
- 2L * IntStream.of(l.getPadding()).sum();
*/
//l.nOut =264;
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }

View File

@ -62,6 +62,7 @@ public abstract class LayerConfiguration
implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration
@Getter @Setter protected String name; @Getter @Setter protected String name;
@Getter @Setter private int index;
@Getter @Setter protected List<LayerConstraint> allParamConstraints; @Getter @Setter protected List<LayerConstraint> allParamConstraints;
@Getter @Setter protected List<LayerConstraint> weightConstraints; @Getter @Setter protected List<LayerConstraint> weightConstraints;
@Getter @Setter protected List<LayerConstraint> biasConstraints; @Getter @Setter protected List<LayerConstraint> biasConstraints;
@ -72,6 +73,7 @@ public abstract class LayerConfiguration
/** The type of the layer, basically defines the base class and its properties */ /** The type of the layer, basically defines the base class and its properties */
@Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN; @Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
/** /**
* Number of parameters this layer has a result of its configuration * Number of parameters this layer has a result of its configuration
* @return number or parameters * @return number or parameters
@ -80,7 +82,6 @@ public abstract class LayerConfiguration
return initializer().numParams(this); return initializer().numParams(this);
} }
/** /**
* A reference to the neural net configuration. This field is excluded from json serialization as * A reference to the neural net configuration. This field is excluded from json serialization as
* well as from equals check to avoid circular referenced. * well as from equals check to avoid circular referenced.

View File

@ -53,13 +53,20 @@ public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
* @param numChannels the channels * @param numChannels the channels
*/ */
@JsonCreator @JsonCreator
public FeedForwardToCnnPreProcessor(@JsonProperty("inputHeight") long inputHeight, public FeedForwardToCnnPreProcessor(
@JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) { @JsonProperty("inputHeight") long inputHeight,
@JsonProperty("inputWidth") long inputWidth,
@JsonProperty("numChannels") long numChannels) {
this.inputHeight = inputHeight; this.inputHeight = inputHeight;
this.inputWidth = inputWidth; this.inputWidth = inputWidth;
this.numChannels = numChannels; this.numChannels = numChannels;
} }
/**
* Reshape to a channels x rows x columns tensor
*
* @param inputHeight the columns
* @param inputWidth the rows
*/
public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) { public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) {
this.inputHeight = inputHeight; this.inputHeight = inputHeight;
this.inputWidth = inputWidth; this.inputWidth = inputWidth;
@ -69,18 +76,24 @@ public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
@Override @Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
this.shape = input.shape(); this.shape = input.shape();
if (input.rank() == 4) if (input.rank() == 4) return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
if (input.columns() != inputWidth * inputHeight * numChannels) if (input.columns() != inputWidth * inputHeight * numChannels)
throw new IllegalArgumentException("Invalid input: expect output columns must be equal to rows " throw new IllegalArgumentException(
+ inputHeight + " x columns " + inputWidth + " x channels " + numChannels "Invalid input: expect output columns must be equal to rows "
+ " but was instead " + Arrays.toString(input.shape())); + inputHeight
+ " x columns "
+ inputWidth
+ " x channels "
+ numChannels
+ " but was instead "
+ Arrays.toString(input.shape()));
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, return workspaceMgr.leverageTo(
ArrayType.ACTIVATIONS,
input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth)); input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth));
} }
@ -100,13 +113,11 @@ public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape)); return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape));
} }
@Override @Override
public FeedForwardToCnnPreProcessor clone() { public FeedForwardToCnnPreProcessor clone() {
try { try {
FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone(); FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
if (clone.shape != null) if (clone.shape != null) clone.shape = clone.shape.clone();
clone.shape = clone.shape.clone();
return clone; return clone;
} catch (CloneNotSupportedException e) { } catch (CloneNotSupportedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
@ -121,26 +132,60 @@ public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType; InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
val expSize = inputHeight * inputWidth * numChannels; val expSize = inputHeight * inputWidth * numChannels;
if (c.getSize() != expSize) { if (c.getSize() != expSize) {
throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize throw new IllegalStateException(
+ " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got " "Invalid input: expected FeedForward input of size "
+ expSize
+ " = (d="
+ numChannels
+ " * w="
+ inputWidth
+ " * h="
+ inputHeight
+ "), got "
+ inputType); + inputType);
} }
return InputType.convolutional(inputHeight, inputWidth, numChannels); return InputType.convolutional(inputHeight, inputWidth, numChannels);
case CNN: case CNN:
InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType; InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;
if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) { if (c2.getChannels() != numChannels
throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getChannels() || c2.getHeight() != inputHeight
+ "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels || c2.getWidth() != inputWidth) {
+ "," + inputHeight + "," + inputWidth + ")"); throw new IllegalStateException(
"Invalid input: Got CNN input type with (d,w,h)=("
+ c2.getChannels()
+ ","
+ c2.getWidth()
+ ","
+ c2.getHeight()
+ ") but expected ("
+ numChannels
+ ","
+ inputHeight
+ ","
+ inputWidth
+ ")");
} }
return c2; return c2;
case CNNFlat: case CNNFlat:
InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType; InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) { if (c3.getDepth() != numChannels
throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c3.getDepth() || c3.getHeight() != inputHeight
+ "," + c3.getWidth() + "," + c3.getHeight() + ") but expected (" + numChannels || c3.getWidth() != inputWidth) {
+ "," + inputHeight + "," + inputWidth + ")"); throw new IllegalStateException(
"Invalid input: Got CNN input type with (d,w,h)=("
+ c3.getDepth()
+ ","
+ c3.getWidth()
+ ","
+ c3.getHeight()
+ ") but expected ("
+ numChannels
+ ","
+ inputHeight
+ ","
+ inputWidth
+ ")");
} }
return c3.getUnflattenedType(); return c3.getUnflattenedType();
default: default:
@ -149,10 +194,9 @@ public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
} }
@Override @Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, public Pair<INDArray, MaskState> feedForwardMaskArray(
int minibatchSize) { INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
// Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example) // Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example)
return new Pair<>(maskArray, currentMaskState); return new Pair<>(maskArray, currentMaskState);
} }
} }

View File

@ -369,7 +369,7 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl
protected String layerId() { protected String layerId() {
String name = this.layerConfiguration.getName(); String name = this.layerConfiguration.getName();
return "(layer name: " return "(network: " + getNetConfiguration().getName() + " layer name: "
+ (name == null ? "\"\"" : name) + (name == null ? "\"\"" : name)
+ ", layer index: " + ", layer index: "
+ index + index

View File

@ -101,8 +101,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
int[] args = new int[] { int[] args = new int[] {
(int)kH, (int)kW, strides[0], strides[1], (int)kH, (int)kW, strides[0], strides[1],
pad[0], pad[1], dilation[0], dilation[1], sameMode, pad[0], pad[1], dilation[0], dilation[1], sameMode //,
nchw ? 0 : 1 //0 = NCHW; 1 = NHWC //nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
}; };
INDArray delta; INDArray delta;
@ -224,8 +224,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
int[] args = new int[] { int[] args = new int[] {
kH, kW, strides[0], strides[1], kH, kW, strides[0], strides[1],
pad[0], pad[1], dilation[0], dilation[1], sameMode, pad[0], pad[1], dilation[0], dilation[1], sameMode //,
nchw ? 0 : 1 //0 = NCHW; 1 = NHWC //nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
}; };
//DL4J Deconv weights: [inputDepth, outputDepth, kH, kW] //DL4J Deconv weights: [inputDepth, outputDepth, kH, kW]
@ -238,6 +238,20 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
} else { } else {
opInputs = new INDArray[]{input, weights}; opInputs = new INDArray[]{input, weights};
} }
/**
* 2D deconvolution implementation
*
* IntArgs:
* 0: kernel height
* 1: kernel width
* 2: stride height
* 3: stride width
* 4: padding height
* 5: padding width
* 6: dilation height
* 7: dilation width
* 8: same mode: 0 false, 1 true
*/
CustomOp op = DynamicCustomOp.builder("deconv2d") CustomOp op = DynamicCustomOp.builder("deconv2d")
.addInputs(opInputs) .addInputs(opInputs)
.addIntegerArguments(args) .addIntegerArguments(args)

View File

@ -773,7 +773,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork
LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i); LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i);
layers[i] = layers[i] =
lc.instantiate( lc.instantiate(
lc.getNetConfiguration(), this.getNetConfiguration(),
trainingListeners, trainingListeners,
i, i,
paramsView, paramsView,

View File

@ -101,8 +101,10 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer
params.put(GAMMA, createGamma(conf, gammaView, initializeParams)); params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(GAMMA); conf.getNetConfiguration().addNetWideVariable(GAMMA);
conf.addVariable(GAMMA);
params.put(BETA, createBeta(conf, betaView, initializeParams)); params.put(BETA, createBeta(conf, betaView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(BETA); conf.getNetConfiguration().addNetWideVariable(BETA);
conf.addVariable(BETA);
meanOffset = 2 * nOut; meanOffset = 2 * nOut;
} }
@ -125,12 +127,15 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer
params.put(GLOBAL_MEAN, globalMeanView); params.put(GLOBAL_MEAN, globalMeanView);
conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN); conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN);
conf.addVariable(GLOBAL_MEAN);
if(layer.isUseLogStd()){ if(layer.isUseLogStd()){
params.put(GLOBAL_LOG_STD, globalVarView); params.put(GLOBAL_LOG_STD, globalVarView);
conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD); conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD);
conf.addVariable(GLOBAL_LOG_STD);
} else { } else {
params.put(GLOBAL_VAR, globalVarView); params.put(GLOBAL_VAR, globalVarView);
conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR); conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR);
conf.addVariable(GLOBAL_VAR);
} }
return params; return params;

View File

@ -114,11 +114,13 @@ public class ConvolutionParamInitializer extends AbstractParamInitializer {
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY); conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
conf.getNetConfiguration().addNetWideVariable(BIAS_KEY); conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
conf.getNetConfiguration().addNetWideVariable(BIAS_KEY); conf.addVariable(WEIGHT_KEY);
conf.addVariable(BIAS_KEY);
} else { } else {
INDArray weightView = paramsView; INDArray weightView = paramsView;
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY); conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
conf.addVariable(WEIGHT_KEY);
} }
return params; return params;

View File

@ -34,7 +34,7 @@ systemProp.org.gradle.internal.publish.checksums.insecure=true
#for whatever reason we had to add MaxMetaspace and file encoding = utf8, gradle crashed otherwise #for whatever reason we had to add MaxMetaspace and file encoding = utf8, gradle crashed otherwise
org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=768m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 -XX:ErrorFile=/var/log/java/hs_err_pid%p.log org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=768m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 -XX:ErrorFile=/var/log/java/hs_err_pid%p.log
#-DsocksProxyHost=sshtunnel -DsocksProxyPort=8888 -Djava.net.preferIPv4Stack=true
# When configured, Gradle will run in incubating parallel mode. # When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit # This option should only be used with decoupled projects. More details, visit

Binary file not shown.

16
gradlew vendored
View File

@ -116,6 +116,7 @@ esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM. # Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
@ -144,14 +145,12 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop"; then
max*) max*)
MAX_FD=$( ulimit -H -n ) || MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit" warn "Could not query maximum file descriptor limit"
;;
esac esac
case $MAX_FD in #( case $MAX_FD in #(
'' | soft) :;; #( '' | soft) :;; #(
*) *)
ulimit -n "$MAX_FD" || ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD" warn "Could not set maximum file descriptor limit to $MAX_FD"
;;
esac esac
fi fi
@ -171,14 +170,12 @@ if "$cygwin" || "$msys"; then
JAVACMD=$( cygpath --unix "$JAVACMD" ) JAVACMD=$( cygpath --unix "$JAVACMD" )
# Now convert the arguments - kludge to limit ourselves to /bin/sh # Now convert the arguments - kludge to limit ourselves to /bin/sh
for arg; do for arg do
if if
case $arg in #( case $arg in #(
-*) false ;; # don't mess with options #( -*) false ;; # don't mess with options #(
/?*) /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath [ -e "$t" ] ;; #(
[ -e "$t" ]
;; #(
*) false ;; *) false ;;
esac esac
then then
@ -208,11 +205,6 @@ set -- \
org.gradle.wrapper.GradleWrapperMain \ org.gradle.wrapper.GradleWrapperMain \
"$@" "$@"
# Stop when "xargs" is not available.
if ! command -v xargs >/dev/null 2>&1; then
die "xargs is not available"
fi
# Use "xargs" to parse quoted args. # Use "xargs" to parse quoted args.
# #
# With -n1 it outputs one arg per line, with the quotes and backslashes removed. # With -n1 it outputs one arg per line, with the quotes and backslashes removed.

10
gradlew.bat vendored
View File

@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1 %JAVA_EXE% -version >NUL 2>&1
if %ERRORLEVEL% equ 0 goto execute if "%ERRORLEVEL%" == "0" goto execute
echo. echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
@ -75,15 +75,13 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
:end :end
@rem End local scope for the variables with windows NT shell @rem End local scope for the variables with windows NT shell
if %ERRORLEVEL% equ 0 goto mainEnd if "%ERRORLEVEL%"=="0" goto mainEnd
:fail :fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code! rem the _cmd.exe /c_ return code!
set EXIT_CODE=%ERRORLEVEL% if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
if %EXIT_CODE% equ 0 set EXIT_CODE=1 exit /b 1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%
:mainEnd :mainEnd
if "%OS%"=="Windows_NT" endlocal if "%OS%"=="Windows_NT" endlocal