gan example

Signed-off-by: brian <brian@brutex.de>

parent 1b3338f809
commit dd151aec3f
@@ -1,115 +1,48 @@
/*
 *
 *    ******************************************************************************
 *    *
 *    * This program and the accompanying materials are made available under the
 *    * terms of the Apache License, Version 2.0 which is available at
 *    * https://www.apache.org/licenses/LICENSE-2.0.
 *    *
 *    *  See the NOTICE file distributed with this work for additional
 *    *  information regarding copyright ownership.
 *    * Unless required by applicable law or agreed to in writing, software
 *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *    * License for the specific language governing permissions and limitations
 *    * under the License.
 *    *
 *    * SPDX-License-Identifier: Apache-2.0
 *    *****************************************************************************
 *
 */

package net.brutex.gan;

import static net.brutex.ai.dnn.api.NN.dense;

import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import java.util.UUID;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.WindowConstants;
import lombok.extern.slf4j.Slf4j;
import javax.swing.*;
import org.apache.commons.lang3.ArrayUtils;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.datavec.image.transform.ShowImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;

@Slf4j
public class App {
  private static final double LEARNING_RATE = 0.000002;
    private static final double LEARNING_RATE = 0.002;
    private static final double GRADIENT_THRESHOLD = 100.0;

  private static final int X_DIM = 20 ;
  private static final int Y_DIM = 20;
  private static final int CHANNELS = 1;
  private static final int batchSize = 1;
  private static final int INPUT = 10;

  private static final int OUTPUT_PER_PANEL = 16;

  private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS;
    private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();

    private static final int BATCHSIZE = 128;
    private static JFrame frame;
  private static  JFrame frame2;
    private static JPanel panel;
  private static JPanel panel2;

  private static final String OUTPUT_DIR = "C:/temp/output/";

    private static LayerConfiguration[] genLayers() {
        return new LayerConfiguration[] {
        DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
        ActivationLayer.builder(Activation.LEAKYRELU).build(),

        DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
                dense().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
            DropoutLayer.builder(1 - 0.5).build(),
        DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
                dense().nIn(256).nOut(512).build(),
                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),

        DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build()
                dense().nIn(512).nOut(1024).build(),
                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
                dense().nIn(1024).nOut(784).activation(Activation.TANH).build()
        };
    }
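
Note: the replacement generator is the classic MNIST GAN MLP: a 100-dimensional noise vector is widened 100 -> 256 -> 512 -> 1024 -> 784, with LeakyReLU(0.2) after each hidden dense layer and a TANH output, where 784 = 28 * 28 flattened pixels. A minimal sanity check of that dimension chaining (plain Java, values copied from the dense() calls above):

    int[] dims = {100, 256, 512, 1024, 784};
    for (int i = 1; i < dims.length; i++) {
      // each dense layer's nIn must equal the previous layer's nOut
      System.out.printf("dense %d: nIn=%d, nOut=%d%n", i, dims[i - 1], dims[i]);
    }
    assert dims[4] == 28 * 28; // one flattened 28x28 MNIST image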
@@ -124,65 +57,51 @@ public class App {
                .updater(UPDATER)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
        //.weightInit(WeightInit.XAVIER)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.IDENTITY)
                .layersFromArray(genLayers())
        .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
       // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
                .name("generator")
                .build();
    ((NeuralNetConfiguration) conf).init();

        return conf;
    }

    private static LayerConfiguration[] disLayers() {
        return new LayerConfiguration[]{
        DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network
                dense().nIn(784).nOut(1024).build(),
                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
                DropoutLayer.builder(1 - 0.5).build(),
        DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
                dense().nIn(1024).nOut(512).build(),
                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
                DropoutLayer.builder(1 - 0.5).build(),
        DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
                dense().nIn(512).nOut(256).build(),
                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
                DropoutLayer.builder(1 - 0.5).build(),
        DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
        DropoutLayer.builder(1 - 0.5).build(),

        OutputLayer.builder().name("dis-output").lossFunction(LossFunction.MCXENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
                OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
        };
    }

    private static NeuralNetConfiguration discriminator() {

    NeuralNetConfiguration conf =
        NeuralNetConfiguration.builder()
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                .seed(42)
                .updater(UPDATER)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
                .weightInit(WeightInit.XAVIER)
            //.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
                .weightNoise(null)
            // .weightInitFn(new WeightInitXavier())
            // .activationFn(new ActivationIdentity())
                .activation(Activation.IDENTITY)
                .layersFromArray(disLayers())
            .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
                .name("discriminator")
                .build();
    ((NeuralNetConfiguration) conf).init();

        return conf;
    }

    private static NeuralNetConfiguration gan() {
        LayerConfiguration[] genLayers = genLayers();
    LayerConfiguration[] disLayers = Arrays.stream(disLayers())
        LayerConfiguration[] disLayers = discriminator().getFlattenedLayerConfigurations().stream()
                .map((layer) -> {
                    if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
          return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
                        return FrozenLayerWithBackprop.builder(layer).build();
                    } else {
                        return layer;
                    }
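
Note: gan() builds the combined network as generator layers followed by discriminator layers, wrapping every parameterized discriminator layer in FrozenLayerWithBackprop. Gradients still flow through the frozen copy during gan.fit(), but only the generator weights are updated. A condensed sketch of that mapping, assuming the fork's API exactly as shown in this hunk:

    LayerConfiguration[] frozenDis = Arrays.stream(disLayers())
        .map(layer -> (layer instanceof DenseLayer || layer instanceof OutputLayer)
            ? FrozenLayerWithBackprop.builder(layer).build() // freeze trainable layers
            : layer) // activation/dropout layers carry no parameters
        .toArray(LayerConfiguration[]::new);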
@@ -191,174 +110,100 @@ public class App {

        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                .seed(42)
        .updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() )
                .updater(UPDATER)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
        .gradientNormalizationThreshold( 100 )
        //.weightInitFn( new WeightInitXavier() ) //this is internal
            .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
                .weightInit(WeightInit.XAVIER)
        //.activationFn( new ActivationIdentity()) //this is internal
                .activation(Activation.IDENTITY)
                .layersFromArray(layers)
        .inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
            .dataType(DataType.FLOAT)
                .name("GAN")
                .build();
((NeuralNetConfiguration) conf).init();

        return conf;
    }


    @Test
    public void runTest() throws Exception {
      if(! log.isDebugEnabled()) {
          log.info("Logging is not set to DEBUG");
        App.main(null);
    }
      else {
          log.info("Logging is set to DEBUG");
      }
    main();
  }

    public static void main(String... args) throws Exception {
        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);

    log.info("\u001B[32m  Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m   ");
    Nd4j.getMemoryManager().setAutoGcWindow(500);

   //MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
   //FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
   FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS());

    ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );

    ImageTransform transform2 = new ShowImageTransform("Tester", 30);
    ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);

    ImageTransform tr = new PipelineImageTransform.Builder()
        //.addImageTransform(transform) //convert to GREY SCALE
        .addImageTransform(transform3)
        //.addImageTransform(transform2)
        .build();

    ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
    imageRecordReader.initialize(fileSplit, tr);
    DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
        MnistDataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);

        MultiLayerNetwork gen = new MultiLayerNetwork(generator());
        MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
        MultiLayerNetwork gan = new MultiLayerNetwork(gan());
    gen.init(); log.debug("Generator network: {}", gen);
    dis.init(); log.debug("Discriminator network: {}", dis);
    gan.init(); log.info("Complete GAN network: {}", gan);

        gen.init();
        dis.init();
        gan.init();

        copyParams(gen, dis, gan);

    //gen.addTrainingListeners(new PerformanceListener(15, true, "GEN"));
    dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
    gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
    //gan.addTrainingListeners(new ScoreToChartListener("gan"));
    //dis.setListeners(new ScoreToChartListener("dis"));
        gen.addTrainingListeners(new PerformanceListener(10, true));
        dis.addTrainingListeners(new PerformanceListener(10, true));
        gan.addTrainingListeners(new PerformanceListener(10, true));

    //System.out.println(gan.toString());
    //gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));

    //gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
    //trainData.reset();
        trainData.reset();

        int j = 0;
    for (int i = 0; i < 51; i++) { //epoch
        for (int i = 0; i < 50; i++) {
            while (trainData.hasNext()) {
                j++;

        DataSet next = trainData.next();
                // generate data
        INDArray real = next.getFeatures();//.div(255f);

        //start next round if there are not enough images left to have a full batchsize dataset
        if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
          log.warn("Your total number of input images is not a multiple of {}, "
              + "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
        break;
        }

        //if(i%20 == 0) {
         frame2 = visualize(new INDArray[]{real}, batchSize,
         frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
        //}
       real.divi(255f);

//        int batchSize = (int) real.shape()[0];

        INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
        //INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise

                INDArray real = trainData.next().getFeatures().muli(2).subi(1);
                int batchSize = (int) real.shape()[0];

                INDArray fakeIn = Nd4j.rand(batchSize, 100);
                INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
        fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);

        //log.info("real has {} items.", real.length());
                DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
                DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));

                DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));

                dis.fit(data);
        //dis.fit(data);
                dis.fit(data);

                // Update the discriminator in the GAN network
                updateGan(gen, dis, gan);

        //gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
        gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1)));
                gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));


        //Visualize and reporting
                if (j % 10 == 1) {
                    System.out.println("Epoch " + i +" Iteration " + j + " Visualizing...");
          INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];

          for (int k = 0; k < samples.length; k++) {
            //INDArray input = fakeSet2.get(k).getFeatures();
                    INDArray[] samples = new INDArray[9];
                    DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
            INDArray input = fakeSet2.get(k).getFeatures();
            input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here

                    for (int k = 0; k < 9; k++) {
                        INDArray input = fakeSet2.get(k).getFeatures();
                        //samples[k] = gen.output(input, false);
                        samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
            samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
            //samples[k] =
            samples[k].addi(1f).divi(2f).muli(255f);

                    }
          frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
                    visualize(samples);
                }
            }
      if (trainData.resetSupported()) {
            trainData.reset();
      } else {
          log.error("Trainingdata {} does not support reset.", trainData.toString());
            // Copy the GANs generator to gen.
            //updateGen(gen, gan);
        }

        // Copy the GANs generator to gen.
        updateGen(gen, gan);

        gen.save(new File("mnist-mlp-generator.dlj"));
    }

  }
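
Note: each pass through the inner loop performs three updates. A condensed sketch of one iteration, assuming gen, dis, gan and updateGan(...) as defined in this file and the new code path's label convention (0 = real, 1 = fake):

    static void trainStep(MultiLayerNetwork gen, MultiLayerNetwork dis,
                          MultiLayerNetwork gan, INDArray real, int batchSize) {
      // generator forward pass only: layers 0 .. genCount-1 of the combined net
      INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1,
                                                 Nd4j.rand(batchSize, 100));
      // 1) discriminator update on a merged real/fake batch
      dis.fit(DataSet.merge(Arrays.asList(
          new DataSet(real, Nd4j.zeros(batchSize, 1)),
          new DataSet(fake, Nd4j.ones(batchSize, 1)))));
      // 2) copy the refreshed discriminator weights into gan's frozen tail
      updateGan(gen, dis, gan);
      // 3) generator update: fresh noise, target "real", discriminator frozen
      gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
    }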

    private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
        int genLayerCount = gen.getLayers().length;
        for (int i = 0; i < gan.getLayers().length; i++) {
            if (i < genLayerCount) {
        if(gan.getLayer(i).getParams() != null)
         gan.getLayer(i).setParams(gen.getLayer(i).getParams());
                gen.getLayer(i).setParams(gan.getLayer(i).getParams());
            } else {
        if(gan.getLayer(i).getParams() != null)
        gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams());
                dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
            }
        }
    }
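
Note: copyParams() runs once before training to put the three networks in sync, and the diff flips the copy direction: the removed lines pushed gen/dis weights into gan (guarded against parameterless layers), while the added lines distribute gan's freshly initialized weights out to gen and dis. A sketch of the new direction:

    // after gan.init(): share the initial weights with the two halves
    for (int i = 0; i < gan.getLayers().length; i++) {
      if (i < genLayerCount) gen.getLayer(i).setParams(gan.getLayer(i).getParams());
      else dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
    }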
@@ -376,98 +221,41 @@ public class App {
        }
    }

  private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
    if (isOrig) {
      frame.setTitle("Viz Original");
    } else {
      frame.setTitle("Generated");
    }

    private static void visualize(INDArray[] samples) {
        if (frame == null) {
            frame = new JFrame();
            frame.setTitle("Viz");
            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            frame.setLayout(new BorderLayout());

    JPanel panelx = new JPanel();
            panel = new JPanel();

    panelx.setLayout(new GridLayout(4, 4, 8, 8));
    for (INDArray sample : samples) {
      for(int i = 0; i<batchElements; i++) {
        panelx.add(getImage(sample, i, isOrig));
      }
    }
    frame.add(panelx, BorderLayout.CENTER);
            panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
            frame.add(panel, BorderLayout.CENTER);
            frame.setVisible(true);
        }

        panel.removeAll();

        for (INDArray sample : samples) {
            panel.add(getImage(sample));
        }

        frame.revalidate();
    frame.setMinimumSize(new Dimension(300, 20));
        frame.pack();
    return frame;
    }

  private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
    final BufferedImage bi;
    if(CHANNELS>1) {
        bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
    } else {
        bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
    }
    final int imageSize = X_DIM * Y_DIM;
    final int offset = batchElement * imageSize;
    int pxl = offset * CHANNELS; //where to start in the INDArray

    //Image in NCHW - channels first format
    for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
      for (int y = 0; y < Y_DIM; y++) { // step through the columns x
        for (int x = 0; x < X_DIM; x++) { //step through the rows y
          if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
          bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
          pxl++; //next item in INDArray
        }
      }
    private static JLabel getImage(INDArray tensor) {
        BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
        for (int i = 0; i < 784; i++) {
            int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
            bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
        }
        ImageIcon orig = new ImageIcon(bi);
    Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
        Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);

        ImageIcon scaled = new ImageIcon(imageScaled);
    if(! isOrig)  saveImage(imageScaled, batchElement, isOrig);

        return new JLabel(scaled);

  }

  private static void saveImage(Image image, int batchElement, boolean isOrig) {
    String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved

    try {
      // Save the images to disk
      saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");

     log.debug("Images saved successfully.");
    } catch (IOException e) {
      log.error("Error saving the images: {}", e.getMessage());
    }
}
    private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
        File directory = new File(outputDirectory);
        if (!directory.exists()) {
            directory.mkdir();
        }

        File outputFile = new File(directory, fileName);
        ImageIO.write(imageToBufferedImage(image), "png", outputFile);
    }

    public static BufferedImage imageToBufferedImage(Image image) {
        if (image instanceof BufferedImage) {
            return (BufferedImage) image;
        }

        // Create a buffered image with the same dimensions and transparency as the original image
        BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);

        // Draw the original image onto the buffered image
        Graphics2D g2d = bufferedImage.createGraphics();
        g2d.drawImage(image, 0, 0, null);
        g2d.dispose();

        return bufferedImage;
    }

}
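
Note: the removed getImage(INDArray, int, boolean) walks the tensor with a running pxl counter in NCHW (batch, channel, height, width) order. For reference, that counter equals the closed-form row-major index below (helper name is illustrative, not part of the diff):

    // flat index of element (n, c, y, x) in a row-major NCHW tensor
    static int nchwIndex(int n, int c, int y, int x, int channels, int height, int width) {
      return ((n * channels + c) * height + y) * width + x; // matches the pxl++ walk
    }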

brutex-extended-tests/src/test/java/net/brutex/gan/App2.java (new file, 343 lines)
@@ -0,0 +1,343 @@
/*
 *
 *    ******************************************************************************
 *    *
 *    * This program and the accompanying materials are made available under the
 *    * terms of the Apache License, Version 2.0 which is available at
 *    * https://www.apache.org/licenses/LICENSE-2.0.
 *    *
 *    *  See the NOTICE file distributed with this work for additional
 *    *  information regarding copyright ownership.
 *    * Unless required by applicable law or agreed to in writing, software
 *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *    * License for the specific language governing permissions and limitations
 *    * under the License.
 *    *
 *    * SPDX-License-Identifier: Apache-2.0
 *    *****************************************************************************
 *
 */

package net.brutex.gan;


import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.List;
import javax.imageio.ImageIO;
import javax.swing.*;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.*;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

@Slf4j
public class App2 {

    final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
    static final float COLORSPACE = 255f;
    static final int DIMENSIONS = 28;
    static final int CHANNELS = 1;
    final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
    final int OUTPUT_PER_PANEL = 10;

    final boolean BIAS = true;

    static final int BATCHSIZE=128;

    private JFrame frame2, frame;
    static final String OUTPUT_DIR = "d:/out/";

    final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
    final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);

    @Test
    void runTest() throws IOException {
        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);

        MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
        FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
        ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
        ImageTransform transform2 = new ShowImageTransform("Tester", 30);
        ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);

        ImageTransform tr = new PipelineImageTransform.Builder()
               .addImageTransform(transform) //convert to GREY SCALE
                .addImageTransform(transform3)
                //.addImageTransform(transform2)
                .build();

        ImageRecordReader imageRecordReader = new ImageRecordReader(DIMENSIONS, DIMENSIONS, CHANNELS);
        imageRecordReader.initialize(fileSplit, tr);
        DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, BATCHSIZE );
        trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);

        MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
        MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());

        LayerConfiguration[] disLayers = App2Config.discriminator().getFlattenedLayerConfigurations().stream()
                .map((layer) -> {
                    if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
                        return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
                    } else {
                        return layer;
                    }
                }).toArray(LayerConfiguration[]::new);

    NeuralNetConfiguration netConfiguration =
        NeuralNetConfiguration.builder()
            .name("GAN")
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .gradientNormalizationThreshold(100)
                .updater(App2Config.UPDATER)
            .innerConfigurations(new ArrayList<>(List.of(App2Config.generator())))
            .layersFromList(new ArrayList<>(Arrays.asList(disLayers)))
            // .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
            // .inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
            //.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
           // .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
                //.inputPreProcessor(2, new CnnToFeedForwardPreProcessor())
                //.dataType(DataType.FLOAT)
            .build();

        MultiLayerNetwork gan = new MultiLayerNetwork(netConfiguration );

        dis.init(); log.debug("Discriminator network: {}", dis);
        gen.init(); log.debug("Generator network: {}", gen);
        gan.init(); log.debug("GAN network: {}", gan);

        log.info("Generator Summary:\n{}", gen.summary());
        log.info("GAN Summary:\n{}", gan.summary());
        dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
        gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
        gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));

        int j = 0;
        for (int i = 0; i < 51; i++) { //epoch
            while (trainData.hasNext()) {
                j++;
                DataSet next = trainData.next();
                // generate data
                INDArray real = next.getFeatures(); //.muli(2).subi(1);;//.div(255f);

                //start next round if there are not enough images left to have a full batchsize dataset
                if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
                    log.warn("Your total number of input images is not a multiple of {}, "
                            + "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
                    break;
                }

                //if(i%20 == 0) {
               // frame2 = visualize(new INDArray[]{real}, BATCHSIZE,
               //         frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
                //}
                //real.divi(255f);

//        int batchSize = (int) real.shape()[0];

                //INDArray fakeIn = Nd4j.rand(BATCHSIZE, CHANNELS, DIMENSIONS, DIMENSIONS);
                //INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
                INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);

                INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
                // when generator has TANH as activation - value range is -1 to 1
                // when generator has SIGMOID, then range is 0 to 1
                fake.addi(1f).divi(2f);

                DataSet realSet = new DataSet(real, label_real);
                DataSet fakeSet = new DataSet(fake, label_fake);

                DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));

                dis.fit(data);
                dis.fit(data);
                // Update the discriminator in the GAN network
                updateGan(gen, dis, gan);

                gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_fake));

                //Visualize and reporting
                if (j % 10 == 1) {
                    System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
                    INDArray[] samples = BATCHSIZE > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[BATCHSIZE];

                    for (int k = 0; k < samples.length; k++) {
                        DataSet fakeSet2 = new DataSet(fakeIn, label_fake);
                        INDArray input = fakeSet2.get(k).getFeatures();

                        //input = input.reshape(1,CHANNELS, DIMENSIONS, DIMENSIONS); //batch size will be 1 here for images
                        input = input.reshape(1, App2Config.INPUT);

                        //samples[k] = gen.output(input, false);
                        samples[k] = gen.activateSelectedLayers(0, gen.getLayers().length - 1, input);
                        samples[k] = samples[k].reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS);
                        //samples[k] =
                        //samples[k].muli(255f);

                    }
                    frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
                }
            }

            if (trainData.resetSupported()) {
                trainData.reset();
            } else {
                log.error("Trainingdata {} does not support reset.", trainData.toString());
            }
            // Copy the GANs generator to gen.
            updateGen(gen, gan);
            log.info("Updated GAN's generator from gen.");
            gen.save(new File("mnist-mlp-generator.dlj"));
        }
    }
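
Note: the range handling in runTest() matters because the generator ends in TANH, so fakes come out in [-1, 1], while the MNIST iterator delivers features in [0, 1]; the fake.addi(1f).divi(2f) line maps both onto the same scale before the discriminator sees them. A sketch of that sampling step, assuming the fields of this class:

    static INDArray sampleFakes(MultiLayerNetwork gen, MultiLayerNetwork gan) {
      INDArray noise = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
      // run only the generator half of the combined network
      INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, noise);
      return fake.addi(1f).divi(2f); // TANH range [-1, 1] -> MNIST range [0, 1]
    }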

    private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
        if (isOrig) {
            frame.setTitle("Viz Original");
        } else {
            frame.setTitle("Generated");
        }

        frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
        frame.setLayout(new BorderLayout());

        JPanel panelx = new JPanel();

        panelx.setLayout(new GridLayout(4, 4, 8, 8));
        for (INDArray sample : samples) {
            for(int i = 0; i<batchElements; i++) {
                panelx.add(getImage(sample, i, isOrig));
            }
        }
        frame.add(panelx, BorderLayout.CENTER);
        frame.setVisible(true);

        frame.revalidate();
        frame.setMinimumSize(new Dimension(300, 20));
        frame.pack();
        return frame;
    }

    private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
        final BufferedImage bi;
        if(CHANNELS >1) {
            bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
        } else {
            bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
        }
        final int imageSize = DIMENSIONS * DIMENSIONS;
        final int offset = batchElement * imageSize;
        int pxl = offset * CHANNELS; //where to start in the INDArray

        //Image in NCHW - channels first format
        for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
            for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x
                for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y
                    float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
                    if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl);
                    bi.getRaster().setSample(x, y, c, f_pxl);
                    pxl++; //next item in INDArray
                }
            }
        }
        ImageIcon orig = new ImageIcon(bi);
        Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT);
        ImageIcon scaled = new ImageIcon(imageScaled);
        if(! isOrig)  saveImage(imageScaled, batchElement, isOrig);
        return new JLabel(scaled);

    }

    private static void saveImage(Image image, int batchElement, boolean isOrig) {
        String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved

        try {
            // Save the images to disk
            saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");

            log.debug("Images saved successfully.");
        } catch (IOException e) {
            log.error("Error saving the images: {}", e.getMessage());
        }
    }
    private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
        File directory = new File(outputDirectory);
        if (!directory.exists()) {
            directory.mkdir();
        }

        File outputFile = new File(directory, fileName);
        ImageIO.write(imageToBufferedImage(image), "png", outputFile);
    }

    public static BufferedImage imageToBufferedImage(Image image) {
        if (image instanceof BufferedImage) {
            return (BufferedImage) image;
        }

    // Create a buffered image with the same dimensions and transparency as the original image
        BufferedImage bufferedImage;
    if (CHANNELS > 1) {
      bufferedImage =
          new BufferedImage(
              image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
        } else {
        bufferedImage =
                new BufferedImage(
                        image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
        }

        // Draw the original image onto the buffered image
        Graphics2D g2d = bufferedImage.createGraphics();
        g2d.drawImage(image, 0, 0, null);
        g2d.dispose();

        return bufferedImage;
    }
    private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
        for (int i = 0; i < gen.getLayers().length; i++) {
            gen.getLayer(i).setParams(gan.getLayer(i).getParams());
        }
    }

    private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
        int genLayerCount = gen.getLayers().length;
        for (int i = genLayerCount; i < gan.getLayers().length; i++) {
            gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
        }
    }

}
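
Note: updateGen() and updateGan() rely on a fixed layout in the combined network: indices [0, genLayerCount) hold the generator, [genLayerCount, total) hold the frozen discriminator, so gan layer i mirrors dis layer i - genLayerCount. A plain-Java illustration of that index mapping (layer counts are illustrative only):

    int genLayerCount = 7, ganLayerCount = 17;
    for (int i = genLayerCount; i < ganLayerCount; i++) {
      System.out.printf("gan[%d] <-> dis[%d]%n", i, i - genLayerCount);
    }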
@@ -0,0 +1,176 @@
 | 
			
		||||
/*
 | 
			
		||||
 *
 | 
			
		||||
 *    ******************************************************************************
 | 
			
		||||
 *    *
 | 
			
		||||
 *    * This program and the accompanying materials are made available under the
 | 
			
		||||
 *    * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 *    * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *    *
 | 
			
		||||
 *    *  See the NOTICE file distributed with this work for additional
 | 
			
		||||
 *    *  information regarding copyright ownership.
 | 
			
		||||
 *    * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 *    * License for the specific language governing permissions and limitations
 | 
			
		||||
 *    * under the License.
 | 
			
		||||
 *    *
 | 
			
		||||
 *    * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 *    *****************************************************************************
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package net.brutex.gan;
 | 
			
		||||
 | 
			
		||||
import static net.brutex.ai.dnn.api.NN.*;
 | 
			
		||||
 | 
			
		||||
import org.deeplearning4j.nn.conf.GradientNormalization;
 | 
			
		||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 | 
			
		||||
import org.deeplearning4j.nn.conf.layers.*;
 | 
			
		||||
import org.deeplearning4j.nn.weights.WeightInit;
 | 
			
		||||
import org.nd4j.linalg.activations.Activation;
 | 
			
		||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
 | 
			
		||||
import org.nd4j.linalg.learning.config.Adam;
 | 
			
		||||
import org.nd4j.linalg.learning.config.IUpdater;
 | 
			
		||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
 | 
			
		||||
 | 
			
		||||
public class App2Config {
 | 
			
		||||
 | 
			
		||||
  public static final int INPUT = 100;
 | 
			
		||||
  public static final int X_DIM = 28;
 | 
			
		||||
  public static final int y_DIM = 28;
 | 
			
		||||
  public static final int CHANNELS = 1;
 | 
			
		||||
  public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();

  static LayerConfiguration[] genLayerConfig() {
    return new LayerConfiguration[] {
            /*
      DenseLayer.builder().name("L-0").nIn(INPUT).nOut(INPUT + (INPUT / 2)).activation(Activation.RELU).build(),
      ActivationLayer.builder().activation(Activation.RELU).build(), /*
                Deconvolution2D.builder().name("L-Deconv-01").nIn(CHANNELS).nOut(CHANNELS)
                        .kernelSize(2,2)
                        .stride(1,1)
                        .padding(0,0)
                        .convolutionMode(ConvolutionMode.Truncate)
                        .activation(Activation.RELU)
                        .hasBias(BIAS).build(),
                //BatchNormalization.builder().nOut(CHANNELS).build(),
                Deconvolution2D.builder().name("L-Deconv-02").nIn(CHANNELS).nOut(CHANNELS)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .padding(0,0)
                        .convolutionMode(ConvolutionMode.Truncate)
                        .activation(Activation.RELU)
                        .hasBias(BIAS).build(),
               //BatchNormalization.builder().name("L-batch").nOut(CHANNELS).build(),


      DenseLayer.builder().name("L-x").nIn(INPUT + (INPUT / 2)).nOut(2 * INPUT).build(),
      ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
      DenseLayer.builder().name("L-x").nIn(2 * INPUT).nOut(3 * INPUT).build(),
      ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
      DenseLayer.builder().name("L-x").nIn(3 * INPUT).nOut(2 * INPUT).build(),
      ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
      // DropoutLayer.builder(0.001).build(),
      DenseLayer.builder().nIn(2 * INPUT).nOut(INPUT).activation(Activation.TANH).build()    */

            dense().nIn(INPUT).nOut(256).weightInit(WeightInit.NORMAL).build(),
            ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
            dense().nIn(256).nOut(512).build(),
            ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
            dense().nIn(512).nOut(1024).build(),
            ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
            dense().nIn(1024).nOut(784).activation(Activation.TANH).build(),

    };
  }

  static LayerConfiguration[] disLayerConfig() {
    return new LayerConfiguration[] {/*
                Convolution2D.builder().nIn(CHANNELS).kernelSize(2,2).padding(1,1).stride(1,1).nOut(CHANNELS)
                        .build(),
                Convolution2D.builder().nIn(CHANNELS).kernelSize(3,3).padding(1,1).stride(2,2).nOut(CHANNELS)
                        .build(),
                ActivationLayer.builder().activation(Activation.LEAKYRELU).build(),
                BatchNormalization.builder().build(),
                OutputLayer.builder().nOut(1).lossFunction(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SIGMOID)
                        .build()


            dense().name("L-dense").nIn(INPUT).nOut(INPUT).build(),
            ActivationLayer.builder().activation(Activation.RELU).build(),
            DropoutLayer.builder(0.5).build(),

            DenseLayer.builder().nIn(INPUT).nOut(INPUT/2).build(),
            ActivationLayer.builder().activation(Activation.RELU).build(),
            DropoutLayer.builder(0.5).build(),

            DenseLayer.builder().nIn(INPUT/2).nOut(INPUT/4).build(),
            ActivationLayer.builder().activation(Activation.RELU).build(),
            DropoutLayer.builder(0.5).build(),

            OutputLayer.builder().nIn(INPUT/4).nOut(1).lossFunction(LossFunctions.LossFunction.XENT)
                    .activation(Activation.SIGMOID)
                    .build() */
            dense().nIn(784).nOut(1024).hasBias(true).build(),
            ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
            DropoutLayer.builder(1 - 0.5).build(),
            dense().nIn(1024).nOut(512).hasBias(true).build(),
            ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
            DropoutLayer.builder(1 - 0.5).build(),
            dense().nIn(512).nOut(256).hasBias(true).build(),
            ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
            DropoutLayer.builder(1 - 0.5).build(),
            OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
    };
  }


  static NeuralNetConfiguration generator() {
    NeuralNetConfiguration conf =
            NeuralNetConfiguration.builder()
                    .name("generator")
                    .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                    .gradientNormalizationThreshold(100)
                    .seed(42)
                    .updater(UPDATER)
                    .weightInit(WeightInit.XAVIER)
                    //.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
                    .weightNoise(null)
                    // .weightInitFn(new WeightInitXavier())
                    // .activationFn(new ActivationIdentity())
                    .activation(Activation.IDENTITY)
                    .layersFromArray(App2Config.genLayerConfig())
                    // .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
                    //.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
                    //.inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
                    //.inputPreProcessor(4, new CnnToFeedForwardPreProcessor())

                    .build();
    conf.init();
    return conf;
  }

  static NeuralNetConfiguration discriminator() {
    NeuralNetConfiguration conf =
            NeuralNetConfiguration.builder()
                    .name("discriminator")
                    .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                    .gradientNormalizationThreshold(100)
                    .seed(42)
                    .updater(UPDATER)
                    .weightInit(WeightInit.XAVIER)
                    // .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
                    .weightNoise(null)
                    // .weightInitFn(new WeightInitXavier())
                    // .activationFn(new ActivationIdentity())
                    .activation(Activation.IDENTITY)
                    .layersFromArray(disLayerConfig())
                    //.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
                    //.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
                    //.dataType(DataType.FLOAT)
                    .build();
    conf.init();
    return conf;
  }
}
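// Usage sketch (an assumption, not part of this commit): the two configurations
// above would be wrapped into trainable networks roughly like this, assuming
// MultiLayerNetwork in this codebase accepts a NeuralNetConfiguration directly:
//
//   MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());
//   MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
//   gen.init();
//   dis.init();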

@@ -52,28 +52,44 @@ public class KerasSequentialModel extends KerasModel {
   * @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
   */
  public KerasSequentialModel(KerasModelBuilder modelBuilder)
            throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(),
                modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(),
                modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape());
      throws UnsupportedKerasConfigurationException,
          IOException,
          InvalidKerasConfigurationException {
    this(
        modelBuilder.getModelJson(),
        modelBuilder.getModelYaml(),
        modelBuilder.getWeightsArchive(),
        modelBuilder.getWeightsRoot(),
        modelBuilder.getTrainingJson(),
        modelBuilder.getTrainingArchive(),
        modelBuilder.isEnforceTrainingConfig(),
        modelBuilder.getInputShape());
  }

  /**
     * (Not recommended) Constructor for Sequential model from model configuration
     * (JSON or YAML), training configuration (JSON), weights, and "training mode"
     * boolean indicator. When built in training mode, certain unsupported configurations
     * (e.g., unknown regularizers) will throw Exceptions. When enforceTrainingConfig=false, these
     * will generate warnings but will be otherwise ignored.
   * (Not recommended) Constructor for Sequential model from model configuration (JSON or YAML),
   * training configuration (JSON), weights, and "training mode" boolean indicator. When built in
   * training mode, certain unsupported configurations (e.g., unknown regularizers) will throw
   * Exceptions. When enforceTrainingConfig=false, these will generate warnings but will be
   * otherwise ignored.
   *
   * @param modelJson model configuration JSON string
   * @param modelYaml model configuration YAML string
   * @param trainingJson training configuration JSON string
   * @throws IOException I/O exception
   */
    public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
                                String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig,
  public KerasSequentialModel(
      String modelJson,
      String modelYaml,
      Hdf5Archive weightsArchive,
      String weightsRoot,
      String trainingJson,
      Hdf5Archive trainingArchive,
      boolean enforceTrainingConfig,
      int[] inputShape)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
      throws IOException,
          InvalidKerasConfigurationException,
          UnsupportedKerasConfigurationException {

    Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
    this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
@@ -83,19 +99,29 @@ public class KerasSequentialModel extends KerasModel {
    /* Determine model configuration type. */
    if (!modelConfig.containsKey(config.getFieldClassName()))
      throw new InvalidKerasConfigurationException(
                    "Could not determine Keras model class (no " + config.getFieldClassName() + " field found)");
          "Could not determine Keras model class (no "
              + config.getFieldClassName()
              + " field found)");
    this.className = (String) modelConfig.get(config.getFieldClassName());
    if (!this.className.equals(config.getFieldClassNameSequential()))
            throw new InvalidKerasConfigurationException("Model class name must be " + config.getFieldClassNameSequential()
                    + " (found " + this.className + ")");
      throw new InvalidKerasConfigurationException(
          "Model class name must be "
              + config.getFieldClassNameSequential()
              + " (found "
              + this.className
              + ")");

    /* Process layer configurations. */
    if (!modelConfig.containsKey(config.getModelFieldConfig()))
      throw new InvalidKerasConfigurationException(
                    "Could not find layer configurations (no " + config.getModelFieldConfig() + " field found)");
          "Could not find layer configurations (no "
              + config.getModelFieldConfig()
              + " field found)");

        // Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations. For consistency
        // "config" is now an object containing a "name" and "layers", the latter contain the same data as before.
    // Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations.
    // For consistency
    // "config" is now an object containing a "name" and "layers", the latter contain the same data
    // as before.
    // This change only affects Sequential models.
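    // For illustration (hypothetical layer entries), the two shapes look roughly like:
    //   pre-2.2.3:  "config": [ {"class_name": "Dense", ...}, ... ]
    //   2.2.3+:     "config": {"name": "sequential_1", "layers": [ {"class_name": "Dense", ...}, ... ]}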
    List<Object> layerList;
    try {
@@ -105,8 +131,7 @@ public class KerasSequentialModel extends KerasModel {
      layerList = (List<Object>) layerMap.get("layers");
    }

        Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair =
                prepareLayers(layerList);
    Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = prepareLayers(layerList);
    this.layers = layerPair.getFirst();
    this.layersOrdered = layerPair.getSecond();

@@ -116,15 +141,18 @@ public class KerasSequentialModel extends KerasModel {
    } else {
      /* Add placeholder input layer and update lists of input and output layers. */
      int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
            Preconditions.checkState(ArrayUtil.prod(firstLayerInputShape) > 0,"Input shape must not be zero!");
      Preconditions.checkState(
          ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!");
      inputLayer = new KerasInput("input1", firstLayerInputShape);
      inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
      this.layers.put(inputLayer.getName(), inputLayer);
      this.layersOrdered.add(0, inputLayer);
    }
    this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
        this.outputLayerNames = new ArrayList<>(
                Collections.singletonList(this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
    this.outputLayerNames =
        new ArrayList<>(
            Collections.singletonList(
                this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));

    /* Update each layer's inbound layer list to include (only) previous layer. */
    KerasLayer prevLayer = null;
@@ -136,12 +164,13 @@ public class KerasSequentialModel extends KerasModel {

    /* Import training configuration. */
    if (enforceTrainingConfig) {
            if (trainingJson != null)
                importTrainingConfiguration(trainingJson);
            else log.warn("If enforceTrainingConfig is true, a training " +
                    "configuration object has to be provided. Usually the only practical way to do this is to store" +
                    " your keras model with `model.save('model_path.h5')`. If you store model config and weights" +
                    " separately no training configuration is attached.");
      if (trainingJson != null) importTrainingConfiguration(trainingJson);
      else
        log.warn(
            "If enforceTrainingConfig is true, a training "
                + "configuration object has to be provided. Usually the only practical way to do this is to store"
                + " your keras model with `model.save('model_path.h5')`. If you store model config and weights"
                + " separately no training configuration is attached.");
    }

    this.outputTypes = inferOutputTypes(inputShape);
@@ -150,9 +179,7 @@ public class KerasSequentialModel extends KerasModel {
      importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
  }

    /**
     * Default constructor
     */
  /** Default constructor */
  public KerasSequentialModel() {
    super();
  }
@@ -174,13 +201,13 @@ public class KerasSequentialModel extends KerasModel {
      throw new InvalidKerasConfigurationException(
          "MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");

        NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder();
    NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder =
        NeuralNetConfiguration.builder();

    if (optimizer != null) {
      modelBuilder.updater(optimizer);
    }


    // don't forcibly override for keras import
    modelBuilder.overrideNinUponBuild(false);
    /* Add layers one at a time. */
@@ -192,7 +219,10 @@ public class KerasSequentialModel extends KerasModel {
        if (nbInbound != 1)
          throw new InvalidKerasConfigurationException(
              "Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
                                    + nbInbound + " for layer " + layer.getName() + ")");
                  + nbInbound
                  + " for layer "
                  + layer.getName()
                  + ")");
        if (prevLayer != null) {
          InputType[] inputTypes = new InputType[1];
          InputPreProcessor preprocessor;
@@ -207,35 +237,37 @@ public class KerasSequentialModel extends KerasModel {
            if (preprocessor != null) {
              InputType outputType = preprocessor.getOutputType(inputTypes[0]);
              layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
            } else layer.getLayer().setNIn(inputTypes[0], modelBuilder.isOverrideNinUponBuild());
          }
                        else
                            layer.getLayer().setNIn(inputTypes[0],modelBuilder.isOverrideNinUponBuild());
          if (preprocessor != null) {

            Map<Integer, InputPreProcessor> map = new HashMap<>();
            map.put(layerIndex, preprocessor);
            modelBuilder.inputPreProcessors(map);
          }
                    if (preprocessor != null)
                        modelBuilder.inputPreProcessor(layerIndex, preprocessor);


        }

        modelBuilder.layer(layerIndex++, layer.getLayer());
      } else if (layer.getVertex() != null)
                throw new InvalidKerasConfigurationException("Cannot add vertex to NeuralNetConfiguration (class name "
                        + layer.getClassName() + ", layer name " + layer.getName() + ")");
        throw new InvalidKerasConfigurationException(
            "Cannot add vertex to NeuralNetConfiguration (class name "
                + layer.getClassName()
                + ", layer name "
                + layer.getName()
                + ")");
      prevLayer = layer;
    }

    /* Whether to use standard backprop (or BPTT) or truncated BPTT. */
    if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
            modelBuilder.backpropType(BackpropType.TruncatedBPTT)
      modelBuilder
          .backpropType(BackpropType.TruncatedBPTT)
          .tbpttFwdLength(truncatedBPTT)
          .tbpttBackLength(truncatedBPTT);
        else
            modelBuilder.backpropType(BackpropType.Standard);
    else modelBuilder.backpropType(BackpropType.Standard);

    NeuralNetConfiguration build = modelBuilder.build();


    return build;
  }


@@ -23,6 +23,7 @@ package net.brutex.ai.dnn.api;

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.layers.DenseLayer;

/**
 * A fluent API to configure and create artificial neural networks
@@ -30,9 +31,11 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationB
public class NN {


  public static NeuralNetConfigurationBuilder<?, ?> net() {
  public static NeuralNetConfigurationBuilder<?, ?> nn() {
    return NeuralNetConfiguration.builder();
  }

  public static DenseLayer.DenseLayerBuilder<?,?> dense() { return DenseLayer.builder(); }


}
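// Usage sketch: combined with a static import of NN.dense (as used by the GAN
// example above), a dense layer can be declared as
//   dense().nIn(784).nOut(256).build()
// instead of spelling out DenseLayer.builder()...build().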

@@ -152,7 +152,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
  @Getter @Setter @NonNull @lombok.Builder.Default
  protected BackpropType backpropType = BackpropType.Standard;

  @Getter @lombok.Builder.Default
  @Getter @Setter @Singular
  protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
  /**
   * When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated)
@@ -524,12 +524,11 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
     * @param processor what to use to preProcess the data.
     * @return builder pattern
     */
    public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
      if(inputPreProcessors$value==null) inputPreProcessors$value=new LinkedHashMap<>();
      inputPreProcessors$value.put(layer, processor);
      inputPreProcessors$set = true;
      return self();
    }
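    // Example call, mirroring the preprocessor wiring commented out in the GAN
    // configuration above:
    //   builder.inputPreProcessor(0, new CnnToFeedForwardPreProcessor());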
   //public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
   //   inputPreProcessors$value.put(layer, processor);
   //   inputPreProcessors$set = true;
   //   return self();
   // }

    /**
     * Set layer at index

@@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import lombok.*;
import lombok.experimental.SuperBuilder;
@@ -317,6 +318,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
          @NonNull
          InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType);
          if (inputPreProcessor != null) {
            inputPreProcessors = new HashMap<>(inputPreProcessors);
            inputPreProcessors.put(i, inputPreProcessor);
          }
        }
@@ -538,6 +540,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
                    obj.getClass().getSimpleName());
              }
            });
    // make sure the indexes are sequenced properly
    AtomicInteger i = new AtomicInteger();
    ret.forEach(obj -> {
      obj.setIndex(i.getAndIncrement());
    });
    return ret;
  }


@@ -219,7 +219,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
      throw new IllegalStateException(
          "Invalid input for Convolution layer (layer name=\""
              + getName()
              + "\"): Expected CNN input, got "
              + "\" at index '"+getIndex()+"') : Expected CNN input, got "
              + inputType);
    }

@@ -372,7 +372,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
     * @param kernelSize kernel size
     */
    public B kernelSize(int... kernelSize) {
      this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize");
      //this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize");
      this.kernelSize$value = kernelSize;
      this.kernelSize$set = true;
      return self();
    }
@@ -383,7 +384,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
     * @param stride kernel size
     */
    public B stride(int... stride) {
      this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride");
      //this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride");
      this.stride$value = stride;
      this.stride$set = true;
      return self();
    }
@@ -394,7 +396,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
     * @param padding kernel size
     */
    public B padding(int... padding) {
      this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding");
      //this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding");
      this.padding$value = padding;
      this.padding$set = true;
      return self();
    }
@@ -404,7 +407,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
     * @param dilation kernel size
     */
    public B dilation(int... dilation) {
      this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation");
      //this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation");
      this.dilation$value = dilation;
      this.dilation$set = true;
      return self();
    }

@@ -20,14 +20,19 @@

package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.stream.IntStream;

import lombok.*;
import lombok.experimental.SuperBuilder;
import lombok.extern.jackson.Jacksonized;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer;
@@ -84,6 +89,8 @@ public class Deconvolution2D extends ConvolutionLayer {
      boolean initializeParams,
      DataType networkDataType) {
    setNetConfiguration(conf);


    LayerValidation.assertNInNOutSet("Deconvolution2D", getName(), layerIndex, getNIn(), getNOut());
    LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
    runInheritance();
@@ -127,11 +134,25 @@ public class Deconvolution2D extends ConvolutionLayer {
        getName(),
        Deconvolution2DLayer.class);
  }

@Slf4j
  private static final class Deconvolution2DBuilderImpl
      extends Deconvolution2DBuilder<Deconvolution2D, Deconvolution2DBuilderImpl> {
    public Deconvolution2D build() {
      Deconvolution2D l = new Deconvolution2D(this);
      if( l.getConvolutionMode() == ConvolutionMode.Same
              && IntStream.of(l.getPadding()).sum() != 0) {
        log.warn("Invalid input for layer '{}'. "
                + "You cannot have a padding of {} when Convolution Mode is set to 'Same'."
                + " Padding will be ignored."
        , l.getName(), l.getPadding());
      }
      /* strides * (input_size-1) + kernel_size - 2*padding */
      //TODO: This is wrong, also depends on convolutionMode, etc ...
      /*l.nOut = l.getStride()[0] * (l.getNIn()-1)
              + IntStream.of(l.getKernelSize()).reduce(1, (a,b) -> a*b)
              - 2L * IntStream.of(l.getPadding()).sum();
              */
      //l.nOut =264;
      l.initializeConstraints();
      return l;
    }

@@ -62,6 +62,7 @@ public abstract class LayerConfiguration
    implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration

  @Getter @Setter protected String name;
  @Getter @Setter private int index;
  @Getter @Setter protected List<LayerConstraint> allParamConstraints;
  @Getter @Setter protected List<LayerConstraint> weightConstraints;
  @Getter @Setter protected List<LayerConstraint> biasConstraints;
@@ -72,6 +73,7 @@ public abstract class LayerConfiguration
  /** The type of the layer, basically defines the base class and its properties */
  @Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;


  /**
   * Number of parameters this layer has as a result of its configuration
   * @return number of parameters
@@ -80,7 +82,6 @@ public abstract class LayerConfiguration
    return initializer().numParams(this);
  }


  /**
   * A reference to the neural net configuration. This field is excluded from json serialization as
   * well as from equals check to avoid circular references.

@@ -53,13 +53,20 @@ public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
   * @param numChannels the channels
   */
  @JsonCreator
    public FeedForwardToCnnPreProcessor(@JsonProperty("inputHeight") long inputHeight,
                    @JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) {
  public FeedForwardToCnnPreProcessor(
      @JsonProperty("inputHeight") long inputHeight,
      @JsonProperty("inputWidth") long inputWidth,
      @JsonProperty("numChannels") long numChannels) {
    this.inputHeight = inputHeight;
    this.inputWidth = inputWidth;
    this.numChannels = numChannels;
  }

  /**
   * Reshape to a channels x rows x columns tensor
   *
   * @param inputHeight the rows
   * @param inputWidth the columns
   */
  public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) {
    this.inputHeight = inputHeight;
    this.inputWidth = inputWidth;
@@ -69,18 +76,24 @@ public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
  @Override
  public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
    this.shape = input.shape();
        if (input.rank() == 4)
            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
    if (input.rank() == 4) return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);

    if (input.columns() != inputWidth * inputHeight * numChannels)
            throw new IllegalArgumentException("Invalid input: expect output columns must be equal to rows "
                    + inputHeight + " x columns " + inputWidth + " x channels " + numChannels
                    + " but was instead " + Arrays.toString(input.shape()));
      throw new IllegalArgumentException(
          "Invalid input: expect output columns must be equal to rows "
              + inputHeight
              + " x columns "
              + inputWidth
              + " x channels "
              + numChannels
              + " but was instead "
              + Arrays.toString(input.shape()));

    if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
      input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');

        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS,
    return workspaceMgr.leverageTo(
        ArrayType.ACTIVATIONS,
        input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth));
  }

@@ -100,13 +113,11 @@ public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
    return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape));
  }


  @Override
  public FeedForwardToCnnPreProcessor clone() {
    try {
      FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
            if (clone.shape != null)
                clone.shape = clone.shape.clone();
      if (clone.shape != null) clone.shape = clone.shape.clone();
      return clone;
    } catch (CloneNotSupportedException e) {
      throw new RuntimeException(e);
@@ -121,26 +132,60 @@ public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
        InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
        val expSize = inputHeight * inputWidth * numChannels;
        if (c.getSize() != expSize) {
                    throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize
                                    + " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got "
          throw new IllegalStateException(
              "Invalid input: expected FeedForward input of size "
                  + expSize
                  + " = (d="
                  + numChannels
                  + " * w="
                  + inputWidth
                  + " * h="
                  + inputHeight
                  + "), got "
                  + inputType);
        }
        return InputType.convolutional(inputHeight, inputWidth, numChannels);
      case CNN:
        InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;

                if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
                    throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getChannels()
                                    + "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels
                                    + "," + inputHeight + "," + inputWidth + ")");
        if (c2.getChannels() != numChannels
            || c2.getHeight() != inputHeight
            || c2.getWidth() != inputWidth) {
          throw new IllegalStateException(
              "Invalid input: Got CNN input type with (d,w,h)=("
                  + c2.getChannels()
                  + ","
                  + c2.getWidth()
                  + ","
                  + c2.getHeight()
                  + ") but expected ("
                  + numChannels
                  + ","
                  + inputHeight
                  + ","
                  + inputWidth
                  + ")");
        }
        return c2;
      case CNNFlat:
        InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
                if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
                    throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c3.getDepth()
                                    + "," + c3.getWidth() + "," + c3.getHeight() + ") but expected (" + numChannels
                                    + "," + inputHeight + "," + inputWidth + ")");
        if (c3.getDepth() != numChannels
            || c3.getHeight() != inputHeight
            || c3.getWidth() != inputWidth) {
          throw new IllegalStateException(
              "Invalid input: Got CNN input type with (d,w,h)=("
                  + c3.getDepth()
                  + ","
                  + c3.getWidth()
                  + ","
                  + c3.getHeight()
                  + ") but expected ("
                  + numChannels
                  + ","
                  + inputHeight
                  + ","
                  + inputWidth
                  + ")");
        }
        return c3.getUnflattenedType();
      default:
@@ -149,10 +194,9 @@ public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
  }

  @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
                    int minibatchSize) {
  public Pair<INDArray, MaskState> feedForwardMaskArray(
      INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
    // Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example)
    return new Pair<>(maskArray, currentMaskState);
  }

}

@@ -369,7 +369,7 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl

  protected String layerId() {
    String name = this.layerConfiguration.getName();
    return "(layer name: "
    return "(network: " + getNetConfiguration().getName() + " layer name: "
        + (name == null ? "\"\"" : name)
        + ", layer index: "
        + index

@@ -101,8 +101,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {

        int[] args = new int[] {
                (int)kH, (int)kW, strides[0], strides[1],
                pad[0], pad[1], dilation[0], dilation[1], sameMode,
                nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
                pad[0], pad[1], dilation[0], dilation[1], sameMode //,
                //nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
        };

        INDArray delta;
@@ -224,8 +224,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {

        int[] args = new int[] {
                kH, kW, strides[0], strides[1],
                pad[0], pad[1], dilation[0], dilation[1], sameMode,
                nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
                pad[0], pad[1], dilation[0], dilation[1], sameMode //,
                //nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
        };

        //DL4J Deconv weights: [inputDepth, outputDepth, kH, kW]
@@ -238,6 +238,20 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
        } else {
            opInputs = new INDArray[]{input, weights};
        }
        /**
         * 2D deconvolution implementation
         *
         * IntArgs:
         * 0: kernel height
         * 1: kernel width
         * 2: stride height
         * 3: stride width
         * 4: padding height
         * 5: padding width
         * 6: dilation height
         * 7: dilation width
         * 8: same mode: 0 false, 1 true
         */
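        // Worked example (values matching the Deconvolution2D config in the GAN
        // example above, with dilation left at its assumed default of 1,1):
        // kernelSize(2,2), stride(2,2), padding(0,0) and ConvolutionMode.Truncate
        // (same mode = false) would give
        //   args = {2, 2, 2, 2, 0, 0, 1, 1, 0}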
        CustomOp op = DynamicCustomOp.builder("deconv2d")
                .addInputs(opInputs)
                .addIntegerArguments(args)

@@ -773,7 +773,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork
        LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i);
        layers[i] =
            lc.instantiate(
                lc.getNetConfiguration(),
                this.getNetConfiguration(),
                trainingListeners,
                i,
                paramsView,

@@ -101,8 +101,10 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer

            params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
            conf.getNetConfiguration().addNetWideVariable(GAMMA);
            conf.addVariable(GAMMA);
            params.put(BETA, createBeta(conf, betaView, initializeParams));
            conf.getNetConfiguration().addNetWideVariable(BETA);
            conf.addVariable(BETA);

            meanOffset = 2 * nOut;
        }
@@ -125,12 +127,15 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer

        params.put(GLOBAL_MEAN, globalMeanView);
        conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN);
        conf.addVariable(GLOBAL_MEAN);
        if(layer.isUseLogStd()){
            params.put(GLOBAL_LOG_STD, globalVarView);
            conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD);
            conf.addVariable(GLOBAL_LOG_STD);
        } else {
            params.put(GLOBAL_VAR, globalVarView);
            conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR);
            conf.addVariable(GLOBAL_VAR);
        }

        return params;

@@ -114,11 +114,13 @@ public class ConvolutionParamInitializer extends AbstractParamInitializer {
            params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
            conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
            conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
            conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
            conf.addVariable(WEIGHT_KEY);
            conf.addVariable(BIAS_KEY);
        } else {
            INDArray weightView = paramsView;
            params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
            conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
            conf.addVariable(WEIGHT_KEY);
        }

        return params;

@@ -34,7 +34,7 @@ systemProp.org.gradle.internal.publish.checksums.insecure=true

#for whatever reason we had to add MaxMetaspace and file encoding = utf8, gradle crashed otherwise
org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=768m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 -XX:ErrorFile=/var/log/java/hs_err_pid%p.log

#-DsocksProxyHost=sshtunnel -DsocksProxyPort=8888 -Djava.net.preferIPv4Stack=true

# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit

BIN  gradle/wrapper/gradle-wrapper.jar (vendored)
Binary file not shown.

16  gradlew (vendored)
@@ -116,6 +116,7 @@ esac

CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar


# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
@@ -144,14 +145,12 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop"; then
      max*)
        MAX_FD=$( ulimit -H -n ) ||
            warn "Could not query maximum file descriptor limit"
    ;;
    esac
    case $MAX_FD in  #(
      '' | soft) :;; #(
      *)
        ulimit -n "$MAX_FD" ||
            warn "Could not set maximum file descriptor limit to $MAX_FD"
    ;;
    esac
fi

@@ -171,14 +170,12 @@ if "$cygwin" || "$msys"; then
    JAVACMD=$( cygpath --unix "$JAVACMD" )

    # Now convert the arguments - kludge to limit ourselves to /bin/sh
  for arg; do
    for arg do
        if
            case $arg in                                #(
              -*)   false ;;                            # don't mess with options #(
      /?*)
        t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
        [ -e "$t" ]
        ;; #(
              /?*)  t=${arg#/} t=/${t%%/*}              # looks like a POSIX filepath
                    [ -e "$t" ] ;;                      #(
              *)    false ;;
            esac
        then
@@ -208,11 +205,6 @@ set -- \
        org.gradle.wrapper.GradleWrapperMain \
        "$@"

# Stop when "xargs" is not available.
if ! command -v xargs >/dev/null 2>&1; then
  die "xargs is not available"
fi

# Use "xargs" to parse quoted args.
#
# With -n1 it outputs one arg per line, with the quotes and backslashes removed.

10  gradlew.bat (vendored)

@@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome

set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if %ERRORLEVEL% equ 0 goto execute
if "%ERRORLEVEL%" == "0" goto execute

echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
@@ -75,15 +75,13 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar

:end
@rem End local scope for the variables with windows NT shell
if %ERRORLEVEL% equ 0 goto mainEnd
if "%ERRORLEVEL%"=="0" goto mainEnd

:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
set EXIT_CODE=%ERRORLEVEL%
if %EXIT_CODE% equ 0 set EXIT_CODE=1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%
if  not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1

:mainEnd
if "%OS%"=="Windows_NT" endlocal