diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java index 696f9241b..dc5898413 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java @@ -1,473 +1,261 @@ -/* - * - * ****************************************************************************** - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - * - */ - package net.brutex.gan; +import static net.brutex.ai.dnn.api.NN.dense; + import java.awt.*; import java.awt.image.BufferedImage; import java.io.File; -import java.io.IOException; import java.util.Arrays; -import java.util.Random; -import java.util.UUID; -import javax.imageio.ImageIO; -import javax.swing.ImageIcon; -import javax.swing.JFrame; -import javax.swing.JLabel; -import javax.swing.JPanel; -import javax.swing.WindowConstants; -import lombok.extern.slf4j.Slf4j; +import javax.swing.*; import org.apache.commons.lang3.ArrayUtils; -import org.datavec.api.split.FileSplit; -import org.datavec.image.loader.NativeImageLoader; -import org.datavec.image.recordreader.ImageRecordReader; -import org.datavec.image.transform.ColorConversionTransform; -import org.datavec.image.transform.ImageTransform; -import org.datavec.image.transform.PipelineImageTransform; -import org.datavec.image.transform.ResizeImageTransform; -import org.datavec.image.transform.ShowImageTransform; -import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.distribution.Distribution; -import org.deeplearning4j.nn.conf.distribution.NormalDistribution; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.ActivationLayer; -import org.deeplearning4j.nn.conf.layers.DenseLayer; -import org.deeplearning4j.nn.conf.layers.DropoutLayer; -import org.deeplearning4j.nn.conf.layers.LayerConfiguration; -import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; -import org.deeplearning4j.nn.conf.weightnoise.WeightNoise; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.optimize.listeners.PerformanceListener; -import org.deeplearning4j.optimize.listeners.ScoreToChartListener; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationLReLU; -import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; +import org.nd4j.linalg.lossfunctions.LossFunctions; -@Slf4j public class App { - private static final double LEARNING_RATE = 0.000002; - private static final double GRADIENT_THRESHOLD = 100.0; + private static final double LEARNING_RATE = 0.002; + private static final double GRADIENT_THRESHOLD = 100.0; + private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build(); + private static final int BATCHSIZE = 128; + private static JFrame frame; + private static JPanel panel; - private static final int X_DIM = 20 ; - private static final int Y_DIM = 20; - private static final int CHANNELS = 1; - private static final int batchSize = 1; - private static final int INPUT = 10; + private static LayerConfiguration[] genLayers() { + return new LayerConfiguration[] { + dense().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(), + ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + dense().nIn(256).nOut(512).build(), + ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + dense().nIn(512).nOut(1024).build(), + ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + dense().nIn(1024).nOut(784).activation(Activation.TANH).build() + }; + } - private static final int OUTPUT_PER_PANEL = 16; + /** + * Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image. + * + * @return config + */ + private static NeuralNetConfiguration generator() { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .seed(42) + .updater(UPDATER) + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(GRADIENT_THRESHOLD) + .weightInit(WeightInit.XAVIER) + .activation(Activation.IDENTITY) + .layersFromArray(genLayers()) + .name("generator") + .build(); - private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS; - private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build(); + return conf; + } - private static JFrame frame; - private static JFrame frame2; - private static JPanel panel; - private static JPanel panel2; + private static LayerConfiguration[] disLayers() { + return new LayerConfiguration[]{ + dense().nIn(784).nOut(1024).build(), + ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + DropoutLayer.builder(1 - 0.5).build(), + dense().nIn(1024).nOut(512).build(), + ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + DropoutLayer.builder(1 - 0.5).build(), + dense().nIn(512).nOut(256).build(), + ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + DropoutLayer.builder(1 - 0.5).build(), + OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build() + }; + } - private static final String OUTPUT_DIR = "C:/temp/output/"; + private static NeuralNetConfiguration discriminator() { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .seed(42) + .updater(UPDATER) + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(GRADIENT_THRESHOLD) + .weightInit(WeightInit.XAVIER) + .activation(Activation.IDENTITY) + .layersFromArray(disLayers()) + .name("discriminator") + .build(); - private static LayerConfiguration[] genLayers() { - return new LayerConfiguration[] { - DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(), - ActivationLayer.builder(Activation.LEAKYRELU).build(), + return conf; + } - DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), - ActivationLayer.builder(new ActivationLReLU(0.2)).build(), - DropoutLayer.builder(1 - 0.5).build(), - DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(), - ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + private static NeuralNetConfiguration gan() { + LayerConfiguration[] genLayers = genLayers(); + LayerConfiguration[] disLayers = discriminator().getFlattenedLayerConfigurations().stream() + .map((layer) -> { + if (layer instanceof DenseLayer || layer instanceof OutputLayer) { + return FrozenLayerWithBackprop.builder(layer).build(); + } else { + return layer; + } + }).toArray(LayerConfiguration[]::new); + LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers); - DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build() - }; - } + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .seed(42) + .updater(UPDATER) + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(GRADIENT_THRESHOLD) + .weightInit(WeightInit.XAVIER) + .activation(Activation.IDENTITY) + .layersFromArray(layers) + .name("GAN") + .build(); - /** - * Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image. - * - * @return config - */ - private static NeuralNetConfiguration generator() { - NeuralNetConfiguration conf = NeuralNetConfiguration.builder() - .seed(42) - .updater(UPDATER) - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) - .gradientNormalizationThreshold(GRADIENT_THRESHOLD) - //.weightInit(WeightInit.XAVIER) - .weightInit(WeightInit.XAVIER) - .activation(Activation.IDENTITY) - .layersFromArray(genLayers()) - .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) - // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS)) - .build(); - ((NeuralNetConfiguration) conf).init(); + return conf; + } - return conf; - } + @Test + public void runTest() throws Exception { + App.main(null); + } + public static void main(String... args) throws Exception { + Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000); - private static LayerConfiguration[] disLayers() { - return new LayerConfiguration[]{ - DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network - ActivationLayer.builder(new ActivationLReLU(0.2)).build(), - DropoutLayer.builder(1 - 0.5).build(), - DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC - ActivationLayer.builder(new ActivationLReLU(0.2)).build(), - DropoutLayer.builder(1 - 0.5).build(), - DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(), - ActivationLayer.builder(new ActivationLReLU(0.2)).build(), - DropoutLayer.builder(1 - 0.5).build(), - DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), - ActivationLayer.builder(new ActivationLReLU(0.2)).build(), - DropoutLayer.builder(1 - 0.5).build(), + MnistDataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42); - OutputLayer.builder().name("dis-output").lossFunction(LossFunction.MCXENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build() - }; - } + MultiLayerNetwork gen = new MultiLayerNetwork(generator()); + MultiLayerNetwork dis = new MultiLayerNetwork(discriminator()); + MultiLayerNetwork gan = new MultiLayerNetwork(gan()); + gen.init(); + dis.init(); + gan.init(); - private static NeuralNetConfiguration discriminator() { + copyParams(gen, dis, gan); - NeuralNetConfiguration conf = - NeuralNetConfiguration.builder() - .seed(42) - .updater(UPDATER) - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) - .gradientNormalizationThreshold(GRADIENT_THRESHOLD) - .weightInit(WeightInit.XAVIER) - //.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5))) - .weightNoise(null) - // .weightInitFn(new WeightInitXavier()) - // .activationFn(new ActivationIdentity()) - .activation(Activation.IDENTITY) - .layersFromArray(disLayers()) - .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) - .build(); - ((NeuralNetConfiguration) conf).init(); + gen.addTrainingListeners(new PerformanceListener(10, true)); + dis.addTrainingListeners(new PerformanceListener(10, true)); + gan.addTrainingListeners(new PerformanceListener(10, true)); - return conf; - } + trainData.reset(); - private static NeuralNetConfiguration gan() { - LayerConfiguration[] genLayers = genLayers(); - LayerConfiguration[] disLayers = Arrays.stream(disLayers()) - .map((layer) -> { - if (layer instanceof DenseLayer || layer instanceof OutputLayer) { - return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build(); - } else { - return layer; - } - }).toArray(LayerConfiguration[]::new); - LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers); + int j = 0; + for (int i = 0; i < 50; i++) { + while (trainData.hasNext()) { + j++; - NeuralNetConfiguration conf = NeuralNetConfiguration.builder() - .seed(42) - .updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() ) - .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer) - .gradientNormalizationThreshold( 100 ) - //.weightInitFn( new WeightInitXavier() ) //this is internal - .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5))) - .weightInit( WeightInit.XAVIER) - //.activationFn( new ActivationIdentity()) //this is internal - .activation( Activation.IDENTITY ) - .layersFromArray( layers ) - .inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) - .dataType(DataType.FLOAT) - .build(); -((NeuralNetConfiguration) conf).init(); - return conf; - } + // generate data + INDArray real = trainData.next().getFeatures().muli(2).subi(1); + int batchSize = (int) real.shape()[0]; + + INDArray fakeIn = Nd4j.rand(batchSize, 100); + INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn); + + DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1)); + DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1)); + + DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet)); + + dis.fit(data); + dis.fit(data); + + // Update the discriminator in the GAN network + updateGan(gen, dis, gan); + + gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1))); - @Test - public void runTest() throws Exception { - if(! log.isDebugEnabled()) { - log.info("Logging is not set to DEBUG"); - } - else { - log.info("Logging is set to DEBUG"); - } - main(); - } + if (j % 10 == 1) { + System.out.println("Epoch " + i +" Iteration " + j + " Visualizing..."); + INDArray[] samples = new INDArray[9]; + DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1)); - public static void main(String... args) throws Exception { + for (int k = 0; k < 9; k++) { + INDArray input = fakeSet2.get(k).getFeatures(); + //samples[k] = gen.output(input, false); + samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input); - log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m "); - Nd4j.getMemoryManager().setAutoGcWindow(500); - - //MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45); - //FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS()); - FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS()); - - - ImageTransform transform = new ColorConversionTransform(new Random(42), 7 ); - - ImageTransform transform2 = new ShowImageTransform("Tester", 30); - ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM); - - ImageTransform tr = new PipelineImageTransform.Builder() - //.addImageTransform(transform) //convert to GREY SCALE - .addImageTransform(transform3) - //.addImageTransform(transform2) - .build(); - - ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS); - imageRecordReader.initialize(fileSplit, tr); - DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize ); - - MultiLayerNetwork gen = new MultiLayerNetwork(generator()); - MultiLayerNetwork dis = new MultiLayerNetwork(discriminator()); - MultiLayerNetwork gan = new MultiLayerNetwork(gan()); - gen.init(); log.debug("Generator network: {}", gen); - dis.init(); log.debug("Discriminator network: {}", dis); - gan.init(); log.info("Complete GAN network: {}", gan); - - - copyParams(gen, dis, gan); - - //gen.addTrainingListeners(new PerformanceListener(15, true, "GEN")); - dis.addTrainingListeners(new PerformanceListener(10, true, "DIS")); - gan.addTrainingListeners(new PerformanceListener(10, true, "GAN")); - //gan.addTrainingListeners(new ScoreToChartListener("gan")); - //dis.setListeners(new ScoreToChartListener("dis")); - - //System.out.println(gan.toString()); - //gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)); - - //gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1))); - //trainData.reset(); - - int j = 0; - for (int i = 0; i < 51; i++) { //epoch - while (trainData.hasNext()) { - j++; - - DataSet next = trainData.next(); - // generate data - INDArray real = next.getFeatures();//.div(255f); - - //start next round if there are not enough images left to have a full batchsize dataset - if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) { - log.warn("Your total number of input images is not a multiple of {}, " - + "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE); - break; + } + visualize(samples); + } + } + trainData.reset(); + // Copy the GANs generator to gen. + //updateGen(gen, gan); } - //if(i%20 == 0) { - frame2 = visualize(new INDArray[]{real}, batchSize, - frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images - //} - real.divi(255f); - -// int batchSize = (int) real.shape()[0]; - - INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM); - //INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise - - - INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn); - fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM); - - //log.info("real has {} items.", real.length()); - DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1)); - DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1)); - - - DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet)); - - dis.fit(data); - //dis.fit(data); - - // Update the discriminator in the GAN network - updateGan(gen, dis, gan); - - //gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1))); - gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1))); - - //Visualize and reporting - if (j % 10 == 1) { - System.out.println("Epoch " + i + " Iteration " + j + " Visualizing..."); - INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize]; - - - for (int k = 0; k < samples.length; k++) { - //INDArray input = fakeSet2.get(k).getFeatures(); - DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1)); - INDArray input = fakeSet2.get(k).getFeatures(); - input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here - - //samples[k] = gen.output(input, false); - samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input); - samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM); - //samples[k] = - samples[k].addi(1f).divi(2f).muli(255f); - - } - frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1 - } - } - if (trainData.resetSupported()) { - trainData.reset(); - } else { - log.error("Trainingdata {} does not support reset.", trainData.toString()); - } // Copy the GANs generator to gen. updateGen(gen, gan); + gen.save(new File("mnist-mlp-generator.dlj")); } - - - - } - - private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) { - int genLayerCount = gen.getLayers().length; - for (int i = 0; i < gan.getLayers().length; i++) { - if (i < genLayerCount) { - if(gan.getLayer(i).getParams() != null) - gan.getLayer(i).setParams(gen.getLayer(i).getParams()); - } else { - if(gan.getLayer(i).getParams() != null) - gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams()); - } - } - } - - private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) { - for (int i = 0; i < gen.getLayers().length; i++) { - gen.getLayer(i).setParams(gan.getLayer(i).getParams()); - } - } - - private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) { - int genLayerCount = gen.getLayers().length; - for (int i = genLayerCount; i < gan.getLayers().length; i++) { - gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams()); - } - } - - private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) { - if (isOrig) { - frame.setTitle("Viz Original"); - } else { - frame.setTitle("Generated"); - } - - frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); - frame.setLayout(new BorderLayout()); - - JPanel panelx = new JPanel(); - - panelx.setLayout(new GridLayout(4, 4, 8, 8)); - for (INDArray sample : samples) { - for(int i = 0; i1) { - bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels - } else { - bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels - } - final int imageSize = X_DIM * Y_DIM; - final int offset = batchElement * imageSize; - int pxl = offset * CHANNELS; //where to start in the INDArray - - //Image in NCHW - channels first format - for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel - for (int y = 0; y < Y_DIM; y++) { // step through the columns x - for (int x = 0; x < X_DIM; x++) { //step through the rows y - if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl)); - bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl)); - pxl++; //next item in INDArray + private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) { + int genLayerCount = gen.getLayers().length; + for (int i = 0; i < gan.getLayers().length; i++) { + if (i < genLayerCount) { + gen.getLayer(i).setParams(gan.getLayer(i).getParams()); + } else { + dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams()); + } } - } } - ImageIcon orig = new ImageIcon(bi); - Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT); - ImageIcon scaled = new ImageIcon(imageScaled); - if(! isOrig) saveImage(imageScaled, batchElement, isOrig); - return new JLabel(scaled); - } - - private static void saveImage(Image image, int batchElement, boolean isOrig) { - String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved - - try { - // Save the images to disk - saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png"); - - log.debug("Images saved successfully."); - } catch (IOException e) { - log.error("Error saving the images: {}", e.getMessage()); + private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) { + for (int i = 0; i < gen.getLayers().length; i++) { + gen.getLayer(i).setParams(gan.getLayer(i).getParams()); + } } -} - private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException { - File directory = new File(outputDirectory); - if (!directory.exists()) { - directory.mkdir(); + + private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) { + int genLayerCount = gen.getLayers().length; + for (int i = genLayerCount; i < gan.getLayers().length; i++) { + gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams()); + } + } + + private static void visualize(INDArray[] samples) { + if (frame == null) { + frame = new JFrame(); + frame.setTitle("Viz"); + frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); + frame.setLayout(new BorderLayout()); + + panel = new JPanel(); + + panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8)); + frame.add(panel, BorderLayout.CENTER); + frame.setVisible(true); } - File outputFile = new File(directory, fileName); - ImageIO.write(imageToBufferedImage(image), "png", outputFile); - } + panel.removeAll(); - public static BufferedImage imageToBufferedImage(Image image) { - if (image instanceof BufferedImage) { - return (BufferedImage) image; + for (INDArray sample : samples) { + panel.add(getImage(sample)); } - // Create a buffered image with the same dimensions and transparency as the original image - BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB); - - // Draw the original image onto the buffered image - Graphics2D g2d = bufferedImage.createGraphics(); - g2d.drawImage(image, 0, 0, null); - g2d.dispose(); - - return bufferedImage; + frame.revalidate(); + frame.pack(); } + private static JLabel getImage(INDArray tensor) { + BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY); + for (int i = 0; i < 784; i++) { + int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255); + bi.getRaster().setSample(i % 28, i / 28, 0, pixel); + } + ImageIcon orig = new ImageIcon(bi); + Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE); + + ImageIcon scaled = new ImageIcon(imageScaled); + + return new JLabel(scaled); + } } \ No newline at end of file diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App2.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App2.java new file mode 100644 index 000000000..f9f7aba58 --- /dev/null +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App2.java @@ -0,0 +1,343 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.gan; + + + +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.File; +import java.io.IOException; +import java.util.*; +import java.util.List; +import javax.imageio.ImageIO; +import javax.swing.*; +import lombok.extern.slf4j.Slf4j; +import org.datavec.api.split.FileSplit; +import org.datavec.image.loader.NativeImageLoader; +import org.datavec.image.recordreader.ImageRecordReader; +import org.datavec.image.transform.*; +import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; +import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; +import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.optimize.listeners.PerformanceListener; +import org.junit.jupiter.api.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.factory.Nd4j; + +@Slf4j +public class App2 { + + final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS; + static final float COLORSPACE = 255f; + static final int DIMENSIONS = 28; + static final int CHANNELS = 1; + final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS; + final int OUTPUT_PER_PANEL = 10; + + final boolean BIAS = true; + + static final int BATCHSIZE=128; + + private JFrame frame2, frame; + static final String OUTPUT_DIR = "d:/out/"; + + final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1); + final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1); + + @Test + void runTest() throws IOException { + Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000); + + MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200); + FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS()); + ImageTransform transform = new ColorConversionTransform(new Random(42), 7 ); + ImageTransform transform2 = new ShowImageTransform("Tester", 30); + ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS); + + ImageTransform tr = new PipelineImageTransform.Builder() + .addImageTransform(transform) //convert to GREY SCALE + .addImageTransform(transform3) + //.addImageTransform(transform2) + .build(); + + ImageRecordReader imageRecordReader = new ImageRecordReader(DIMENSIONS, DIMENSIONS, CHANNELS); + imageRecordReader.initialize(fileSplit, tr); + DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, BATCHSIZE ); + trainData = new MnistDataSetIterator(BATCHSIZE, true, 42); + + MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator()); + MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator()); + + LayerConfiguration[] disLayers = App2Config.discriminator().getFlattenedLayerConfigurations().stream() + .map((layer) -> { + if (layer instanceof DenseLayer || layer instanceof OutputLayer) { + return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build(); + } else { + return layer; + } + }).toArray(LayerConfiguration[]::new); + + NeuralNetConfiguration netConfiguration = + NeuralNetConfiguration.builder() + .name("GAN") + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(100) + .updater(App2Config.UPDATER) + .innerConfigurations(new ArrayList<>(List.of(App2Config.generator()))) + .layersFromList(new ArrayList<>(Arrays.asList(disLayers))) + // .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS)) + // .inputPreProcessor(4, new CnnToFeedForwardPreProcessor()) + //.inputPreProcessor(0, new CnnToFeedForwardPreProcessor()) + // .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS)) + //.inputPreProcessor(2, new CnnToFeedForwardPreProcessor()) + //.dataType(DataType.FLOAT) + .build(); + + MultiLayerNetwork gan = new MultiLayerNetwork(netConfiguration ); + + dis.init(); log.debug("Discriminator network: {}", dis); + gen.init(); log.debug("Generator network: {}", gen); + gan.init(); log.debug("GAN network: {}", gan); + + + log.info("Generator Summary:\n{}", gen.summary()); + log.info("GAN Summary:\n{}", gan.summary()); + dis.addTrainingListeners(new PerformanceListener(10, true, "DIS")); + gen.addTrainingListeners(new PerformanceListener(10, true, "GEN")); + gan.addTrainingListeners(new PerformanceListener(10, true, "GAN")); + + int j = 0; + for (int i = 0; i < 51; i++) { //epoch + while (trainData.hasNext()) { + j++; + DataSet next = trainData.next(); + // generate data + INDArray real = next.getFeatures(); //.muli(2).subi(1);;//.div(255f); + + //start next round if there are not enough images left to have a full batchsize dataset + if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) { + log.warn("Your total number of input images is not a multiple of {}, " + + "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE); + break; + } + + //if(i%20 == 0) { + + // frame2 = visualize(new INDArray[]{real}, BATCHSIZE, + // frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images + //} + //real.divi(255f); + +// int batchSize = (int) real.shape()[0]; + + //INDArray fakeIn = Nd4j.rand(BATCHSIZE, CHANNELS, DIMENSIONS, DIMENSIONS); + //INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise + INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT); + + INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn); + // when generator has TANH as activation - value range is -1 to 1 + // when generator has SIGMOID, then range is 0 to 1 + fake.addi(1f).divi(2f); + + DataSet realSet = new DataSet(real, label_real); + DataSet fakeSet = new DataSet(fake, label_fake); + + DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet)); + + dis.fit(data); + dis.fit(data); + // Update the discriminator in the GAN network + updateGan(gen, dis, gan); + + gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_fake)); + + //Visualize and reporting + if (j % 10 == 1) { + System.out.println("Epoch " + i + " Iteration " + j + " Visualizing..."); + INDArray[] samples = BATCHSIZE > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[BATCHSIZE]; + + + for (int k = 0; k < samples.length; k++) { + DataSet fakeSet2 = new DataSet(fakeIn, label_fake); + INDArray input = fakeSet2.get(k).getFeatures(); + + //input = input.reshape(1,CHANNELS, DIMENSIONS, DIMENSIONS); //batch size will be 1 here for images + input = input.reshape(1, App2Config.INPUT); + + //samples[k] = gen.output(input, false); + samples[k] = gen.activateSelectedLayers(0, gen.getLayers().length - 1, input); + samples[k] = samples[k].reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS); + //samples[k] = + //samples[k].muli(255f); + + } + frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1 + } + } + + if (trainData.resetSupported()) { + trainData.reset(); + } else { + log.error("Trainingdata {} does not support reset.", trainData.toString()); + } + // Copy the GANs generator to gen. + updateGen(gen, gan); + log.info("Updated GAN's generator from gen."); + gen.save(new File("mnist-mlp-generator.dlj")); + } + } + + + + + + + + + + + + private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) { + if (isOrig) { + frame.setTitle("Viz Original"); + } else { + frame.setTitle("Generated"); + } + + frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); + frame.setLayout(new BorderLayout()); + + JPanel panelx = new JPanel(); + + panelx.setLayout(new GridLayout(4, 4, 8, 8)); + for (INDArray sample : samples) { + for(int i = 0; i1) { + bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels + } else { + bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels + } + final int imageSize = DIMENSIONS * DIMENSIONS; + final int offset = batchElement * imageSize; + int pxl = offset * CHANNELS; //where to start in the INDArray + + //Image in NCHW - channels first format + for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel + for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x + for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y + float f_pxl = tensor.getFloat(pxl) * COLORSPACE; + if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl); + bi.getRaster().setSample(x, y, c, f_pxl); + pxl++; //next item in INDArray + } + } + } + ImageIcon orig = new ImageIcon(bi); + Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT); + ImageIcon scaled = new ImageIcon(imageScaled); + if(! isOrig) saveImage(imageScaled, batchElement, isOrig); + return new JLabel(scaled); + + } + + private static void saveImage(Image image, int batchElement, boolean isOrig) { + String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved + + try { + // Save the images to disk + saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png"); + + log.debug("Images saved successfully."); + } catch (IOException e) { + log.error("Error saving the images: {}", e.getMessage()); + } + } + private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException { + File directory = new File(outputDirectory); + if (!directory.exists()) { + directory.mkdir(); + } + + File outputFile = new File(directory, fileName); + ImageIO.write(imageToBufferedImage(image), "png", outputFile); + } + + public static BufferedImage imageToBufferedImage(Image image) { + if (image instanceof BufferedImage) { + return (BufferedImage) image; + } + + // Create a buffered image with the same dimensions and transparency as the original image + BufferedImage bufferedImage; + if (CHANNELS > 1) { + bufferedImage = + new BufferedImage( + image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB); + } else { + bufferedImage = + new BufferedImage( + image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY); + } + + // Draw the original image onto the buffered image + Graphics2D g2d = bufferedImage.createGraphics(); + g2d.drawImage(image, 0, 0, null); + g2d.dispose(); + + return bufferedImage; + } + private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) { + for (int i = 0; i < gen.getLayers().length; i++) { + gen.getLayer(i).setParams(gan.getLayer(i).getParams()); + } + } + + private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) { + int genLayerCount = gen.getLayers().length; + for (int i = genLayerCount; i < gan.getLayers().length; i++) { + gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams()); + } + } + +} diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App2Config.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App2Config.java new file mode 100644 index 000000000..607acd74b --- /dev/null +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App2Config.java @@ -0,0 +1,176 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.gan; + +import static net.brutex.ai.dnn.api.NN.*; + +import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.activations.impl.ActivationLReLU; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.lossfunctions.LossFunctions; + +public class App2Config { + + public static final int INPUT = 100; + public static final int X_DIM = 28; + public static final int y_DIM = 28; + public static final int CHANNELS = 1; + public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build(); + + static LayerConfiguration[] genLayerConfig() { + return new LayerConfiguration[] { + /* + DenseLayer.builder().name("L-0").nIn(INPUT).nOut(INPUT + (INPUT / 2)).activation(Activation.RELU).build(), + ActivationLayer.builder().activation(Activation.RELU).build(), /* + Deconvolution2D.builder().name("L-Deconv-01").nIn(CHANNELS).nOut(CHANNELS) + .kernelSize(2,2) + .stride(1,1) + .padding(0,0) + .convolutionMode(ConvolutionMode.Truncate) + .activation(Activation.RELU) + .hasBias(BIAS).build(), + //BatchNormalization.builder().nOut(CHANNELS).build(), + Deconvolution2D.builder().name("L-Deconv-02").nIn(CHANNELS).nOut(CHANNELS) + .kernelSize(2,2) + .stride(2,2) + .padding(0,0) + .convolutionMode(ConvolutionMode.Truncate) + .activation(Activation.RELU) + .hasBias(BIAS).build(), + //BatchNormalization.builder().name("L-batch").nOut(CHANNELS).build(), + + + DenseLayer.builder().name("L-x").nIn(INPUT + (INPUT / 2)).nOut(2 * INPUT).build(), + ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(), + DenseLayer.builder().name("L-x").nIn(2 * INPUT).nOut(3 * INPUT).build(), + ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(), + DenseLayer.builder().name("L-x").nIn(3 * INPUT).nOut(2 * INPUT).build(), + ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(), + // DropoutLayer.builder(0.001).build(), + DenseLayer.builder().nIn(2 * INPUT).nOut(INPUT).activation(Activation.TANH).build() */ + + dense().nIn(INPUT).nOut(256).weightInit(WeightInit.NORMAL).build(), + ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + dense().nIn(256).nOut(512).build(), + ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + dense().nIn(512).nOut(1024).build(), + ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + dense().nIn(1024).nOut(784).activation(Activation.TANH).build(), + + }; + } + + static LayerConfiguration[] disLayerConfig() { + return new LayerConfiguration[] {/* + Convolution2D.builder().nIn(CHANNELS).kernelSize(2,2).padding(1,1).stride(1,1).nOut(CHANNELS) + .build(), + Convolution2D.builder().nIn(CHANNELS).kernelSize(3,3).padding(1,1).stride(2,2).nOut(CHANNELS) + .build(), + ActivationLayer.builder().activation(Activation.LEAKYRELU).build(), + BatchNormalization.builder().build(), + OutputLayer.builder().nOut(1).lossFunction(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SIGMOID) + .build() + + + dense().name("L-dense").nIn(INPUT).nOut(INPUT).build(), + ActivationLayer.builder().activation(Activation.RELU).build(), + DropoutLayer.builder(0.5).build(), + + DenseLayer.builder().nIn(INPUT).nOut(INPUT/2).build(), + ActivationLayer.builder().activation(Activation.RELU).build(), + DropoutLayer.builder(0.5).build(), + + DenseLayer.builder().nIn(INPUT/2).nOut(INPUT/4).build(), + ActivationLayer.builder().activation(Activation.RELU).build(), + DropoutLayer.builder(0.5).build(), + + OutputLayer.builder().nIn(INPUT/4).nOut(1).lossFunction(LossFunctions.LossFunction.XENT) + .activation(Activation.SIGMOID) + .build() */ + dense().nIn(784).nOut(1024).hasBias(true).build(), + ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + DropoutLayer.builder(1 - 0.5).build(), + dense().nIn(1024).nOut(512).hasBias(true).build(), + ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + DropoutLayer.builder(1 - 0.5).build(), + dense().nIn(512).nOut(256).hasBias(true).build(), + ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + DropoutLayer.builder(1 - 0.5).build(), + OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build() + }; + } + + + static NeuralNetConfiguration generator() { + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() + .name("generator") + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(100) + .seed(42) + .updater(UPDATER) + .weightInit(WeightInit.XAVIER) + //.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5))) + .weightNoise(null) + // .weightInitFn(new WeightInitXavier()) + // .activationFn(new ActivationIdentity()) + .activation(Activation.IDENTITY) + .layersFromArray(App2Config.genLayerConfig()) + // .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS)) + //.inputPreProcessor(0, new CnnToFeedForwardPreProcessor()) + //.inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS)) + //.inputPreProcessor(4, new CnnToFeedForwardPreProcessor()) + + .build(); + conf.init(); + return conf; + } + + static NeuralNetConfiguration discriminator() { + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() + .name("discriminator") + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(100) + .seed(42) + .updater(UPDATER) + .weightInit(WeightInit.XAVIER) + // .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5))) + .weightNoise(null) + // .weightInitFn(new WeightInitXavier()) + // .activationFn(new ActivationIdentity()) + .activation(Activation.IDENTITY) + .layersFromArray(disLayerConfig()) + //.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS)) + //.inputPreProcessor(0, new CnnToFeedForwardPreProcessor()) + //.dataType(DataType.FLOAT) + .build(); + conf.init(); + return conf; + } +} diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.java index af22b0a1b..a9a885b74 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.java @@ -43,223 +43,255 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils.impo @Slf4j public class KerasSequentialModel extends KerasModel { - /** - * (Recommended) Builder-pattern constructor for Sequential model. - * - * @param modelBuilder builder object - * @throws IOException I/O exception - * @throws InvalidKerasConfigurationException Invalid Keras configuration - * @throws UnsupportedKerasConfigurationException Unsupported Keras configuration - */ - public KerasSequentialModel(KerasModelBuilder modelBuilder) - throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException { - this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(), - modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(), - modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape()); + /** + * (Recommended) Builder-pattern constructor for Sequential model. + * + * @param modelBuilder builder object + * @throws IOException I/O exception + * @throws InvalidKerasConfigurationException Invalid Keras configuration + * @throws UnsupportedKerasConfigurationException Unsupported Keras configuration + */ + public KerasSequentialModel(KerasModelBuilder modelBuilder) + throws UnsupportedKerasConfigurationException, + IOException, + InvalidKerasConfigurationException { + this( + modelBuilder.getModelJson(), + modelBuilder.getModelYaml(), + modelBuilder.getWeightsArchive(), + modelBuilder.getWeightsRoot(), + modelBuilder.getTrainingJson(), + modelBuilder.getTrainingArchive(), + modelBuilder.isEnforceTrainingConfig(), + modelBuilder.getInputShape()); + } + + /** + * (Not recommended) Constructor for Sequential model from model configuration (JSON or YAML), + * training configuration (JSON), weights, and "training mode" boolean indicator. When built in + * training mode, certain unsupported configurations (e.g., unknown regularizers) will throw + * Exceptions. When enforceTrainingConfig=false, these will generate warnings but will be + * otherwise ignored. + * + * @param modelJson model configuration JSON string + * @param modelYaml model configuration YAML string + * @param trainingJson training configuration JSON string + * @throws IOException I/O exception + */ + public KerasSequentialModel( + String modelJson, + String modelYaml, + Hdf5Archive weightsArchive, + String weightsRoot, + String trainingJson, + Hdf5Archive trainingArchive, + boolean enforceTrainingConfig, + int[] inputShape) + throws IOException, + InvalidKerasConfigurationException, + UnsupportedKerasConfigurationException { + + Map modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml); + this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config); + this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config); + this.enforceTrainingConfig = enforceTrainingConfig; + + /* Determine model configuration type. */ + if (!modelConfig.containsKey(config.getFieldClassName())) + throw new InvalidKerasConfigurationException( + "Could not determine Keras model class (no " + + config.getFieldClassName() + + " field found)"); + this.className = (String) modelConfig.get(config.getFieldClassName()); + if (!this.className.equals(config.getFieldClassNameSequential())) + throw new InvalidKerasConfigurationException( + "Model class name must be " + + config.getFieldClassNameSequential() + + " (found " + + this.className + + ")"); + + /* Process layer configurations. */ + if (!modelConfig.containsKey(config.getModelFieldConfig())) + throw new InvalidKerasConfigurationException( + "Could not find layer configurations (no " + + config.getModelFieldConfig() + + " field found)"); + + // Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations. + // For consistency + // "config" is now an object containing a "name" and "layers", the latter contain the same data + // as before. + // This change only affects Sequential models. + List layerList; + try { + layerList = (List) modelConfig.get(config.getModelFieldConfig()); + } catch (Exception e) { + HashMap layerMap = (HashMap) modelConfig.get(config.getModelFieldConfig()); + layerList = (List) layerMap.get("layers"); } - /** - * (Not recommended) Constructor for Sequential model from model configuration - * (JSON or YAML), training configuration (JSON), weights, and "training mode" - * boolean indicator. When built in training mode, certain unsupported configurations - * (e.g., unknown regularizers) will throw Exceptions. When enforceTrainingConfig=false, these - * will generate warnings but will be otherwise ignored. - * - * @param modelJson model configuration JSON string - * @param modelYaml model configuration YAML string - * @param trainingJson training configuration JSON string - * @throws IOException I/O exception - */ - public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot, - String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig, - int[] inputShape) - throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + Pair, List> layerPair = prepareLayers(layerList); + this.layers = layerPair.getFirst(); + this.layersOrdered = layerPair.getSecond(); - Map modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml); - this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config); - this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config); - this.enforceTrainingConfig = enforceTrainingConfig; + KerasLayer inputLayer; + if (this.layersOrdered.get(0) instanceof KerasInput) { + inputLayer = this.layersOrdered.get(0); + } else { + /* Add placeholder input layer and update lists of input and output layers. */ + int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape(); + Preconditions.checkState( + ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!"); + inputLayer = new KerasInput("input1", firstLayerInputShape); + inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder()); + this.layers.put(inputLayer.getName(), inputLayer); + this.layersOrdered.add(0, inputLayer); + } + this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName())); + this.outputLayerNames = + new ArrayList<>( + Collections.singletonList( + this.layersOrdered.get(this.layersOrdered.size() - 1).getName())); - /* Determine model configuration type. */ - if (!modelConfig.containsKey(config.getFieldClassName())) - throw new InvalidKerasConfigurationException( - "Could not determine Keras model class (no " + config.getFieldClassName() + " field found)"); - this.className = (String) modelConfig.get(config.getFieldClassName()); - if (!this.className.equals(config.getFieldClassNameSequential())) - throw new InvalidKerasConfigurationException("Model class name must be " + config.getFieldClassNameSequential() - + " (found " + this.className + ")"); - - /* Process layer configurations. */ - if (!modelConfig.containsKey(config.getModelFieldConfig())) - throw new InvalidKerasConfigurationException( - "Could not find layer configurations (no " + config.getModelFieldConfig() + " field found)"); - - // Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations. For consistency - // "config" is now an object containing a "name" and "layers", the latter contain the same data as before. - // This change only affects Sequential models. - List layerList; - try { - layerList = (List) modelConfig.get(config.getModelFieldConfig()); - } catch (Exception e) { - HashMap layerMap = (HashMap) modelConfig.get(config.getModelFieldConfig()); - layerList = (List) layerMap.get("layers"); - } - - Pair, List> layerPair = - prepareLayers(layerList); - this.layers = layerPair.getFirst(); - this.layersOrdered = layerPair.getSecond(); - - KerasLayer inputLayer; - if (this.layersOrdered.get(0) instanceof KerasInput) { - inputLayer = this.layersOrdered.get(0); - } else { - /* Add placeholder input layer and update lists of input and output layers. */ - int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape(); - Preconditions.checkState(ArrayUtil.prod(firstLayerInputShape) > 0,"Input shape must not be zero!"); - inputLayer = new KerasInput("input1", firstLayerInputShape); - inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder()); - this.layers.put(inputLayer.getName(), inputLayer); - this.layersOrdered.add(0, inputLayer); - } - this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName())); - this.outputLayerNames = new ArrayList<>( - Collections.singletonList(this.layersOrdered.get(this.layersOrdered.size() - 1).getName())); - - /* Update each layer's inbound layer list to include (only) previous layer. */ - KerasLayer prevLayer = null; - for (KerasLayer layer : this.layersOrdered) { - if (prevLayer != null) - layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName())); - prevLayer = layer; - } - - /* Import training configuration. */ - if (enforceTrainingConfig) { - if (trainingJson != null) - importTrainingConfiguration(trainingJson); - else log.warn("If enforceTrainingConfig is true, a training " + - "configuration object has to be provided. Usually the only practical way to do this is to store" + - " your keras model with `model.save('model_path.h5'. If you store model config and weights" + - " separately no training configuration is attached."); - } - - this.outputTypes = inferOutputTypes(inputShape); - - if (weightsArchive != null) - importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend); + /* Update each layer's inbound layer list to include (only) previous layer. */ + KerasLayer prevLayer = null; + for (KerasLayer layer : this.layersOrdered) { + if (prevLayer != null) + layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName())); + prevLayer = layer; } - /** - * Default constructor - */ - public KerasSequentialModel() { - super(); + /* Import training configuration. */ + if (enforceTrainingConfig) { + if (trainingJson != null) importTrainingConfiguration(trainingJson); + else + log.warn( + "If enforceTrainingConfig is true, a training " + + "configuration object has to be provided. Usually the only practical way to do this is to store" + + " your keras model with `model.save('model_path.h5'. If you store model config and weights" + + " separately no training configuration is attached."); } - /** - * Configure a NeuralNetConfiguration from this Keras Sequential model configuration. - * - * @return NeuralNetConfiguration - */ - public NeuralNetConfiguration getNeuralNetConfiguration() - throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { - if (!this.className.equals(config.getFieldClassNameSequential())) - throw new InvalidKerasConfigurationException( - "Keras model class name " + this.className + " incompatible with MultiLayerNetwork"); - if (this.inputLayerNames.size() != 1) - throw new InvalidKerasConfigurationException( - "MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")"); - if (this.outputLayerNames.size() != 1) - throw new InvalidKerasConfigurationException( - "MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")"); + this.outputTypes = inferOutputTypes(inputShape); - NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder(); + if (weightsArchive != null) + importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend); + } - if (optimizer != null) { - modelBuilder.updater(optimizer); + /** Default constructor */ + public KerasSequentialModel() { + super(); + } + + /** + * Configure a NeuralNetConfiguration from this Keras Sequential model configuration. + * + * @return NeuralNetConfiguration + */ + public NeuralNetConfiguration getNeuralNetConfiguration() + throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + if (!this.className.equals(config.getFieldClassNameSequential())) + throw new InvalidKerasConfigurationException( + "Keras model class name " + this.className + " incompatible with MultiLayerNetwork"); + if (this.inputLayerNames.size() != 1) + throw new InvalidKerasConfigurationException( + "MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")"); + if (this.outputLayerNames.size() != 1) + throw new InvalidKerasConfigurationException( + "MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")"); + + NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = + NeuralNetConfiguration.builder(); + + if (optimizer != null) { + modelBuilder.updater(optimizer); + } + + // don't forcibly override for keras import + modelBuilder.overrideNinUponBuild(false); + /* Add layers one at a time. */ + KerasLayer prevLayer = null; + int layerIndex = 0; + for (KerasLayer layer : this.layersOrdered) { + if (layer.isLayer()) { + int nbInbound = layer.getInboundLayerNames().size(); + if (nbInbound != 1) + throw new InvalidKerasConfigurationException( + "Layers in NeuralNetConfiguration must have exactly one inbound layer (found " + + nbInbound + + " for layer " + + layer.getName() + + ")"); + if (prevLayer != null) { + InputType[] inputTypes = new InputType[1]; + InputPreProcessor preprocessor; + if (prevLayer.isInputPreProcessor()) { + inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0)); + preprocessor = prevLayer.getInputPreprocessor(inputTypes); + InputType outputType = preprocessor.getOutputType(inputTypes[0]); + layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild()); + } else { + inputTypes[0] = this.outputTypes.get(prevLayer.getName()); + preprocessor = layer.getInputPreprocessor(inputTypes); + if (preprocessor != null) { + InputType outputType = preprocessor.getOutputType(inputTypes[0]); + layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild()); + } else layer.getLayer().setNIn(inputTypes[0], modelBuilder.isOverrideNinUponBuild()); + } + if (preprocessor != null) { + + Map map = new HashMap<>(); + map.put(layerIndex, preprocessor); + modelBuilder.inputPreProcessors(map); + } } - - //don't forcibly override for keras import - modelBuilder.overrideNinUponBuild(false); - /* Add layers one at a time. */ - KerasLayer prevLayer = null; - int layerIndex = 0; - for (KerasLayer layer : this.layersOrdered) { - if (layer.isLayer()) { - int nbInbound = layer.getInboundLayerNames().size(); - if (nbInbound != 1) - throw new InvalidKerasConfigurationException( - "Layers in NeuralNetConfiguration must have exactly one inbound layer (found " - + nbInbound + " for layer " + layer.getName() + ")"); - if (prevLayer != null) { - InputType[] inputTypes = new InputType[1]; - InputPreProcessor preprocessor; - if (prevLayer.isInputPreProcessor()) { - inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0)); - preprocessor = prevLayer.getInputPreprocessor(inputTypes); - InputType outputType = preprocessor.getOutputType(inputTypes[0]); - layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild()); - } else { - inputTypes[0] = this.outputTypes.get(prevLayer.getName()); - preprocessor = layer.getInputPreprocessor(inputTypes); - if(preprocessor != null) { - InputType outputType = preprocessor.getOutputType(inputTypes[0]); - layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild()); - } - else - layer.getLayer().setNIn(inputTypes[0],modelBuilder.isOverrideNinUponBuild()); - - } - if (preprocessor != null) - modelBuilder.inputPreProcessor(layerIndex, preprocessor); - - - } - - modelBuilder.layer(layerIndex++, layer.getLayer()); - } else if (layer.getVertex() != null) - throw new InvalidKerasConfigurationException("Cannot add vertex to NeuralNetConfiguration (class name " - + layer.getClassName() + ", layer name " + layer.getName() + ")"); - prevLayer = layer; - } - - /* Whether to use standard backprop (or BPTT) or truncated BPTT. */ - if (this.useTruncatedBPTT && this.truncatedBPTT > 0) - modelBuilder.backpropType(BackpropType.TruncatedBPTT) - .tbpttFwdLength(truncatedBPTT) - .tbpttBackLength(truncatedBPTT); - else - modelBuilder.backpropType(BackpropType.Standard); - - NeuralNetConfiguration build = modelBuilder.build(); - - - return build; + modelBuilder.layer(layerIndex++, layer.getLayer()); + } else if (layer.getVertex() != null) + throw new InvalidKerasConfigurationException( + "Cannot add vertex to NeuralNetConfiguration (class name " + + layer.getClassName() + + ", layer name " + + layer.getName() + + ")"); + prevLayer = layer; } - /** - * Build a MultiLayerNetwork from this Keras Sequential model configuration. - * - * @return MultiLayerNetwork - */ - public MultiLayerNetwork getMultiLayerNetwork() - throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { - return getMultiLayerNetwork(true); - } + /* Whether to use standard backprop (or BPTT) or truncated BPTT. */ + if (this.useTruncatedBPTT && this.truncatedBPTT > 0) + modelBuilder + .backpropType(BackpropType.TruncatedBPTT) + .tbpttFwdLength(truncatedBPTT) + .tbpttBackLength(truncatedBPTT); + else modelBuilder.backpropType(BackpropType.Standard); - /** - * Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights. - * - * @return MultiLayerNetwork - */ - public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights) - throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { - MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration()); - model.init(); - if (importWeights) - model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers); - return model; - } + NeuralNetConfiguration build = modelBuilder.build(); + + return build; + } + + /** + * Build a MultiLayerNetwork from this Keras Sequential model configuration. + * + * @return MultiLayerNetwork + */ + public MultiLayerNetwork getMultiLayerNetwork() + throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + return getMultiLayerNetwork(true); + } + + /** + * Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights. + * + * @return MultiLayerNetwork + */ + public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights) + throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration()); + model.init(); + if (importWeights) + model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers); + return model; + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/NN.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/NN.java index 06e73dcf9..683a83083 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/NN.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/NN.java @@ -23,6 +23,7 @@ package net.brutex.ai.dnn.api; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; +import org.deeplearning4j.nn.conf.layers.DenseLayer; /** * A fluent API to configure and create artificial neural networks @@ -30,9 +31,11 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationB public class NN { - public static NeuralNetConfigurationBuilder net() { + public static NeuralNetConfigurationBuilder nn() { return NeuralNetConfiguration.builder(); } + public static DenseLayer.DenseLayerBuilder dense() { return DenseLayer.builder(); } + } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java index b69eb174f..ae4a8604b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java @@ -152,7 +152,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor @Getter @Setter @NonNull @lombok.Builder.Default protected BackpropType backpropType = BackpropType.Standard; - @Getter @lombok.Builder.Default + @Getter @Setter @Singular protected Map inputPreProcessors = new HashMap<>(); /** * When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated) @@ -524,12 +524,11 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor * @param processor what to use to preProcess the data. * @return builder pattern */ - public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) { - if(inputPreProcessors$value==null) inputPreProcessors$value=new LinkedHashMap<>(); - inputPreProcessors$value.put(layer, processor); - inputPreProcessors$set = true; - return self(); - } + //public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) { + // inputPreProcessors$value.put(layer, processor); + // inputPreProcessors$set = true; + // return self(); + // } /** * Set layer at index diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java index fcdb56125..48dd1f370 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java @@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.*; import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import lombok.*; import lombok.experimental.SuperBuilder; @@ -317,6 +318,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { @NonNull InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType); if (inputPreProcessor != null) { + inputPreProcessors = new HashMap<>(inputPreProcessors); inputPreProcessors.put(i, inputPreProcessor); } } @@ -538,6 +540,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { obj.getClass().getSimpleName()); } }); + // make sure the indexes are sequenced properly + AtomicInteger i = new AtomicInteger(); + ret.forEach(obj -> { + obj.setIndex(i.getAndIncrement()); + }); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index 0e76f5776..5908bc966 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -219,7 +219,7 @@ public class ConvolutionLayer extends FeedForwardLayer { throw new IllegalStateException( "Invalid input for Convolution layer (layer name=\"" + getName() - + "\"): Expected CNN input, got " + + "\" at index '"+getIndex()+"') : Expected CNN input, got " + inputType); } @@ -372,7 +372,8 @@ public class ConvolutionLayer extends FeedForwardLayer { * @param kernelSize kernel size */ public B kernelSize(int... kernelSize) { - this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize"); + //this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize"); + this.kernelSize$value = kernelSize; this.kernelSize$set = true; return self(); } @@ -383,7 +384,8 @@ public class ConvolutionLayer extends FeedForwardLayer { * @param stride kernel size */ public B stride(int... stride) { - this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride"); + //this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride"); + this.stride$value = stride; this.stride$set = true; return self(); } @@ -394,7 +396,8 @@ public class ConvolutionLayer extends FeedForwardLayer { * @param padding kernel size */ public B padding(int... padding) { - this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding"); + //this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding"); + this.padding$value = padding; this.padding$set = true; return self(); } @@ -404,7 +407,8 @@ public class ConvolutionLayer extends FeedForwardLayer { * @param dilation kernel size */ public B dilation(int... dilation) { - this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation"); + //this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation"); + this.dilation$value = dilation; this.dilation$set = true; return self(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java index 62fab4f7f..c8374c646 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java @@ -20,14 +20,19 @@ package org.deeplearning4j.nn.conf.layers; +import java.util.Arrays; import java.util.Collection; import java.util.Map; +import java.util.stream.IntStream; + import lombok.*; import lombok.experimental.SuperBuilder; import lombok.extern.jackson.Jacksonized; +import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer; @@ -84,6 +89,8 @@ public class Deconvolution2D extends ConvolutionLayer { boolean initializeParams, DataType networkDataType) { setNetConfiguration(conf); + + LayerValidation.assertNInNOutSet("Deconvolution2D", getName(), layerIndex, getNIn(), getNOut()); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); runInheritance(); @@ -127,11 +134,25 @@ public class Deconvolution2D extends ConvolutionLayer { getName(), Deconvolution2DLayer.class); } - +@Slf4j private static final class Deconvolution2DBuilderImpl extends Deconvolution2DBuilder { public Deconvolution2D build() { Deconvolution2D l = new Deconvolution2D(this); + if( l.getConvolutionMode() == ConvolutionMode.Same + && IntStream.of(l.getPadding()).sum() != 0) { + log.warn("Invalid input for layer '{}'. " + + "You cannot have a padding of {} when Convolution Mode is set to 'Same'." + + " Padding will be ignored." + , l.getName(), l.getPadding()); + } + /* strides * (input_size-1) + kernel_size - 2*padding */ + //TODO: This is wrong, also depends on convolutionMode, etc ... + /*l.nOut = l.getStride()[0] * (l.getNIn()-1) + + IntStream.of(l.getKernelSize()).reduce(1, (a,b) -> a*b) + - 2L * IntStream.of(l.getPadding()).sum(); + */ + //l.nOut =264; l.initializeConstraints(); return l; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java index 5c04fa32c..11a1b6f30 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java @@ -62,6 +62,7 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration @Getter @Setter protected String name; + @Getter @Setter private int index; @Getter @Setter protected List allParamConstraints; @Getter @Setter protected List weightConstraints; @Getter @Setter protected List biasConstraints; @@ -72,6 +73,7 @@ public abstract class LayerConfiguration /** The type of the layer, basically defines the base class and its properties */ @Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN; + /** * Number of parameters this layer has a result of its configuration * @return number or parameters @@ -80,7 +82,6 @@ public abstract class LayerConfiguration return initializer().numParams(this); } - /** * A reference to the neural net configuration. This field is excluded from json serialization as * well as from equals check to avoid circular referenced. diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java index 810ae513a..1bf3c6392 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java @@ -37,122 +37,166 @@ import org.nd4j.linalg.api.shape.Shape; @Data @EqualsAndHashCode(exclude = {"shape"}) public class FeedForwardToCnnPreProcessor implements InputPreProcessor { - private long inputHeight; - private long inputWidth; - private long numChannels; + private long inputHeight; + private long inputWidth; + private long numChannels; - @Getter(AccessLevel.NONE) - @Setter(AccessLevel.NONE) - private long[] shape; + @Getter(AccessLevel.NONE) + @Setter(AccessLevel.NONE) + private long[] shape; - /** - * Reshape to a channels x rows x columns tensor - * - * @param inputHeight the columns - * @param inputWidth the rows - * @param numChannels the channels - */ - @JsonCreator - public FeedForwardToCnnPreProcessor(@JsonProperty("inputHeight") long inputHeight, - @JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) { - this.inputHeight = inputHeight; - this.inputWidth = inputWidth; - this.numChannels = numChannels; + /** + * Reshape to a channels x rows x columns tensor + * + * @param inputHeight the columns + * @param inputWidth the rows + * @param numChannels the channels + */ + @JsonCreator + public FeedForwardToCnnPreProcessor( + @JsonProperty("inputHeight") long inputHeight, + @JsonProperty("inputWidth") long inputWidth, + @JsonProperty("numChannels") long numChannels) { + this.inputHeight = inputHeight; + this.inputWidth = inputWidth; + this.numChannels = numChannels; + } + /** + * Reshape to a channels x rows x columns tensor + * + * @param inputHeight the columns + * @param inputWidth the rows + */ + public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) { + this.inputHeight = inputHeight; + this.inputWidth = inputWidth; + this.numChannels = 1; + } + + @Override + public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { + this.shape = input.shape(); + if (input.rank() == 4) return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input); + + if (input.columns() != inputWidth * inputHeight * numChannels) + throw new IllegalArgumentException( + "Invalid input: expect output columns must be equal to rows " + + inputHeight + + " x columns " + + inputWidth + + " x channels " + + numChannels + + " but was instead " + + Arrays.toString(input.shape())); + + if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) + input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); + + return workspaceMgr.leverageTo( + ArrayType.ACTIVATIONS, + input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth)); + } + + @Override + // return 4 dimensions + public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { + if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilons)) + epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c'); + + if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) { + if (epsilons.rank() == 2) + return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons); // should never happen + + return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth); } - public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) { - this.inputHeight = inputHeight; - this.inputWidth = inputWidth; - this.numChannels = 1; + return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape)); + } + + @Override + public FeedForwardToCnnPreProcessor clone() { + try { + FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone(); + if (clone.shape != null) clone.shape = clone.shape.clone(); + return clone; + } catch (CloneNotSupportedException e) { + throw new RuntimeException(e); } + } - @Override - public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { - this.shape = input.shape(); - if (input.rank() == 4) - return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input); + @Override + public InputType getOutputType(InputType inputType) { - if (input.columns() != inputWidth * inputHeight * numChannels) - throw new IllegalArgumentException("Invalid input: expect output columns must be equal to rows " - + inputHeight + " x columns " + inputWidth + " x channels " + numChannels - + " but was instead " + Arrays.toString(input.shape())); - - if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) - input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); - - return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, - input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth)); - } - - @Override - // return 4 dimensions - public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { - if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilons)) - epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c'); - - if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) { - if (epsilons.rank() == 2) - return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons); //should never happen - - return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth); + switch (inputType.getType()) { + case FF: + InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType; + val expSize = inputHeight * inputWidth * numChannels; + if (c.getSize() != expSize) { + throw new IllegalStateException( + "Invalid input: expected FeedForward input of size " + + expSize + + " = (d=" + + numChannels + + " * w=" + + inputWidth + + " * h=" + + inputHeight + + "), got " + + inputType); } + return InputType.convolutional(inputHeight, inputWidth, numChannels); + case CNN: + InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType; - return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape)); - } - - - @Override - public FeedForwardToCnnPreProcessor clone() { - try { - FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone(); - if (clone.shape != null) - clone.shape = clone.shape.clone(); - return clone; - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); + if (c2.getChannels() != numChannels + || c2.getHeight() != inputHeight + || c2.getWidth() != inputWidth) { + throw new IllegalStateException( + "Invalid input: Got CNN input type with (d,w,h)=(" + + c2.getChannels() + + "," + + c2.getWidth() + + "," + + c2.getHeight() + + ") but expected (" + + numChannels + + "," + + inputHeight + + "," + + inputWidth + + ")"); } - } - - @Override - public InputType getOutputType(InputType inputType) { - - switch (inputType.getType()) { - case FF: - InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType; - val expSize = inputHeight * inputWidth * numChannels; - if (c.getSize() != expSize) { - throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize - + " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got " - + inputType); - } - return InputType.convolutional(inputHeight, inputWidth, numChannels); - case CNN: - InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType; - - if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) { - throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getChannels() - + "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels - + "," + inputHeight + "," + inputWidth + ")"); - } - return c2; - case CNNFlat: - InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType; - if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) { - throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c3.getDepth() - + "," + c3.getWidth() + "," + c3.getHeight() + ") but expected (" + numChannels - + "," + inputHeight + "," + inputWidth + ")"); - } - return c3.getUnflattenedType(); - default: - throw new IllegalStateException("Invalid input type: got " + inputType); + return c2; + case CNNFlat: + InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType; + if (c3.getDepth() != numChannels + || c3.getHeight() != inputHeight + || c3.getWidth() != inputWidth) { + throw new IllegalStateException( + "Invalid input: Got CNN input type with (d,w,h)=(" + + c3.getDepth() + + "," + + c3.getWidth() + + "," + + c3.getHeight() + + ") but expected (" + + numChannels + + "," + + inputHeight + + "," + + inputWidth + + ")"); } + return c3.getUnflattenedType(); + default: + throw new IllegalStateException("Invalid input type: got " + inputType); } + } - @Override - public Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, - int minibatchSize) { - //Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example) - return new Pair<>(maskArray, currentMaskState); - } - + @Override + public Pair feedForwardMaskArray( + INDArray maskArray, MaskState currentMaskState, int minibatchSize) { + // Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example) + return new Pair<>(maskArray, currentMaskState); + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java index 006b87250..707d995d9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java @@ -369,7 +369,7 @@ public abstract class AbstractLayer impl protected String layerId() { String name = this.layerConfiguration.getName(); - return "(layer name: " + return "(network: " + getNetConfiguration().getName() + " layer name: " + (name == null ? "\"\"" : name) + ", layer index: " + index diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java index 0761a1189..3b3032151 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java @@ -101,8 +101,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer { int[] args = new int[] { (int)kH, (int)kW, strides[0], strides[1], - pad[0], pad[1], dilation[0], dilation[1], sameMode, - nchw ? 0 : 1 //0 = NCHW; 1 = NHWC + pad[0], pad[1], dilation[0], dilation[1], sameMode //, + //nchw ? 0 : 1 //0 = NCHW; 1 = NHWC }; INDArray delta; @@ -224,8 +224,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer { int[] args = new int[] { kH, kW, strides[0], strides[1], - pad[0], pad[1], dilation[0], dilation[1], sameMode, - nchw ? 0 : 1 //0 = NCHW; 1 = NHWC + pad[0], pad[1], dilation[0], dilation[1], sameMode //, + //nchw ? 0 : 1 //0 = NCHW; 1 = NHWC }; //DL4J Deconv weights: [inputDepth, outputDepth, kH, kW] @@ -238,6 +238,20 @@ public class Deconvolution2DLayer extends ConvolutionLayer { } else { opInputs = new INDArray[]{input, weights}; } + /** + * 2D deconvolution implementation + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ CustomOp op = DynamicCustomOp.builder("deconv2d") .addInputs(opInputs) .addIntegerArguments(args) diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 7cc08a62d..9ce05fb9f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -773,7 +773,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i); layers[i] = lc.instantiate( - lc.getNetConfiguration(), + this.getNetConfiguration(), trainingListeners, i, paramsView, diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java index 3538960d1..8769b3e39 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java @@ -101,8 +101,10 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer params.put(GAMMA, createGamma(conf, gammaView, initializeParams)); conf.getNetConfiguration().addNetWideVariable(GAMMA); + conf.addVariable(GAMMA); params.put(BETA, createBeta(conf, betaView, initializeParams)); conf.getNetConfiguration().addNetWideVariable(BETA); + conf.addVariable(BETA); meanOffset = 2 * nOut; } @@ -125,12 +127,15 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer params.put(GLOBAL_MEAN, globalMeanView); conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN); + conf.addVariable(GLOBAL_MEAN); if(layer.isUseLogStd()){ params.put(GLOBAL_LOG_STD, globalVarView); conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD); + conf.addVariable(GLOBAL_LOG_STD); } else { params.put(GLOBAL_VAR, globalVarView); conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR); + conf.addVariable(GLOBAL_VAR); } return params; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java index a0d9bea82..9ad41df57 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java @@ -114,11 +114,13 @@ public class ConvolutionParamInitializer extends AbstractParamInitializer { params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY); conf.getNetConfiguration().addNetWideVariable(BIAS_KEY); - conf.getNetConfiguration().addNetWideVariable(BIAS_KEY); + conf.addVariable(WEIGHT_KEY); + conf.addVariable(BIAS_KEY); } else { INDArray weightView = paramsView; params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY); + conf.addVariable(WEIGHT_KEY); } return params; diff --git a/gradle.properties b/gradle.properties index ef0384eee..c6ceae50d 100644 --- a/gradle.properties +++ b/gradle.properties @@ -34,7 +34,7 @@ systemProp.org.gradle.internal.publish.checksums.insecure=true #for whatever reason we had to add MaxMetaspace and file encoding = utf8, gradle crashed otherwise org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=768m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 -XX:ErrorFile=/var/log/java/hs_err_pid%p.log - +#-DsocksProxyHost=sshtunnel -DsocksProxyPort=8888 -Djava.net.preferIPv4Stack=true # When configured, Gradle will run in incubating parallel mode. # This option should only be used with decoupled projects. More details, visit diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 249e5832f..41d9927a4 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradlew b/gradlew index 64c63a782..1b6c78733 100644 --- a/gradlew +++ b/gradlew @@ -69,18 +69,18 @@ app_path=$0 # Need this for daisy-chained symlinks. while - APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path - [ -h "$app_path" ] + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] do - ls=$(ls -ld "$app_path") - link=${ls#*' -> '} - case $link in #( - /*) app_path=$link ;; #( - *) app_path=$APP_HOME$link ;; - esac + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac done -APP_HOME=$(cd "${APP_HOME:-./}" && pwd -P) || exit +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit APP_NAME="Gradle" APP_BASE_NAME=${0##*/} @@ -91,15 +91,15 @@ DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD=maximum -warn() { - echo "$*" +warn () { + echo "$*" } >&2 -die() { - echo - echo "$*" - echo - exit 1 +die () { + echo + echo "$*" + echo + exit 1 } >&2 # OS specific support (must be 'true' or 'false'). @@ -107,52 +107,51 @@ cygwin=false msys=false darwin=false nonstop=false -case "$(uname)" in #( -CYGWIN*) cygwin=true ;; #( -Darwin*) darwin=true ;; #( -MSYS* | MINGW*) msys=true ;; #( -NONSTOP*) nonstop=true ;; +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + # Determine the Java command to use to start the JVM. -if [ -n "$JAVA_HOME" ]; then - if [ -x "$JAVA_HOME/jre/sh/java" ]; then - # IBM's JDK on AIX uses strange locations for the executables - JAVACMD=$JAVA_HOME/jre/sh/java - else - JAVACMD=$JAVA_HOME/bin/java - fi - if [ ! -x "$JAVACMD" ]; then - die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME Please set the JAVA_HOME variable in your environment to match the location of your Java installation." - fi + fi else - JAVACMD=java - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + JAVACMD=java + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi # Increase the maximum file descriptors if we can. -if ! "$cygwin" && ! "$darwin" && ! "$nonstop"; then - case $MAX_FD in #( - max*) - MAX_FD=$(ulimit -H -n) || - warn "Could not query maximum file descriptor limit" - ;; - esac - case $MAX_FD in #( - '' | soft) : ;; #( - *) - ulimit -n "$MAX_FD" || - warn "Could not set maximum file descriptor limit to $MAX_FD" - ;; - esac +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac fi # Collect all arguments for the java command, stacking in reverse order: @@ -164,36 +163,34 @@ fi # * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. # For Cygwin or MSYS, switch paths to Windows format before running java -if "$cygwin" || "$msys"; then - APP_HOME=$(cygpath --path --mixed "$APP_HOME") - CLASSPATH=$(cygpath --path --mixed "$CLASSPATH") +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) - JAVACMD=$(cygpath --unix "$JAVACMD") + JAVACMD=$( cygpath --unix "$JAVACMD" ) - # Now convert the arguments - kludge to limit ourselves to /bin/sh - for arg; do - if - case $arg in #( - -*) false ;; # don't mess with options #( - /?*) - t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath - [ -e "$t" ] - ;; #( - *) false ;; - esac - then - arg=$(cygpath --path --ignore --mixed "$arg") - fi - # Roll the args list around exactly as many times as the number of - # args, so each arg winds up back in the position where it started, but - # possibly modified. - # - # NB: a `for` loop captures its iteration list before it begins, so - # changing the positional parameters here affects neither the number of - # iterations, nor the values presented in `arg`. - shift # remove old arg - set -- "$@" "$arg" # push replacement arg - done + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done fi # Collect all arguments for the java command; @@ -203,15 +200,10 @@ fi # * put everything else in single quotes, so that it's not re-expanded. set -- \ - "-Dorg.gradle.appname=$APP_BASE_NAME" \ - -classpath "$CLASSPATH" \ - org.gradle.wrapper.GradleWrapperMain \ - "$@" - -# Stop when "xargs" is not available. -if ! command -v xargs >/dev/null 2>&1; then - die "xargs is not available" -fi + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" # Use "xargs" to parse quoted args. # @@ -233,10 +225,10 @@ fi # eval "set -- $( - printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | - xargs -n1 | - sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | - tr '\n' ' ' -)" '"$@"' + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index f127cfd49..107acd32c 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -14,7 +14,7 @@ @rem limitations under the License. @rem -@if "%DEBUG%"=="" @echo off +@if "%DEBUG%" == "" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @@ -25,7 +25,7 @@ if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 -if "%DIRNAME%"=="" set DIRNAME=. +if "%DIRNAME%" == "" set DIRNAME=. set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 -if %ERRORLEVEL% equ 0 goto execute +if "%ERRORLEVEL%" == "0" goto execute echo. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. @@ -75,15 +75,13 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar :end @rem End local scope for the variables with windows NT shell -if %ERRORLEVEL% equ 0 goto mainEnd +if "%ERRORLEVEL%"=="0" goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! -set EXIT_CODE=%ERRORLEVEL% -if %EXIT_CODE% equ 0 set EXIT_CODE=1 -if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% -exit /b %EXIT_CODE% +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 :mainEnd if "%OS%"=="Windows_NT" endlocal