gan example

Signed-off-by: brian <brian@brutex.de>

parent 1b3338f809
commit dd151aec3f
brutex-extended-tests/src/test/java/net/brutex/gan/App.java
@@ -1,473 +1,261 @@
-/*
- *
- *    ******************************************************************************
- *    *
- *    * This program and the accompanying materials are made available under the
- *    * terms of the Apache License, Version 2.0 which is available at
- *    * https://www.apache.org/licenses/LICENSE-2.0.
- *    *
- *    *  See the NOTICE file distributed with this work for additional
- *    *  information regarding copyright ownership.
- *    * Unless required by applicable law or agreed to in writing, software
- *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- *    * License for the specific language governing permissions and limitations
- *    * under the License.
- *    *
- *    * SPDX-License-Identifier: Apache-2.0
- *    *****************************************************************************
- *
- */
-
 package net.brutex.gan;
 
+import static net.brutex.ai.dnn.api.NN.dense;
+
 import java.awt.*;
 import java.awt.image.BufferedImage;
 import java.io.File;
-import java.io.IOException;
 import java.util.Arrays;
-import java.util.Random;
-import java.util.UUID;
-import javax.imageio.ImageIO;
-import javax.swing.ImageIcon;
-import javax.swing.JFrame;
-import javax.swing.JLabel;
-import javax.swing.JPanel;
-import javax.swing.WindowConstants;
-import lombok.extern.slf4j.Slf4j;
+import javax.swing.*;
 import org.apache.commons.lang3.ArrayUtils;
-import org.datavec.api.split.FileSplit;
-import org.datavec.image.loader.NativeImageLoader;
-import org.datavec.image.recordreader.ImageRecordReader;
-import org.datavec.image.transform.ColorConversionTransform;
-import org.datavec.image.transform.ImageTransform;
-import org.datavec.image.transform.PipelineImageTransform;
-import org.datavec.image.transform.ResizeImageTransform;
-import org.datavec.image.transform.ShowImageTransform;
-import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.distribution.Distribution;
-import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
-import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.ActivationLayer;
-import org.deeplearning4j.nn.conf.layers.DenseLayer;
-import org.deeplearning4j.nn.conf.layers.DropoutLayer;
-import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
-import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
-import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.nn.weights.WeightInitXavier;
 import org.deeplearning4j.optimize.listeners.PerformanceListener;
-import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.impl.ActivationLReLU;
-import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
-import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.Adam;
 import org.nd4j.linalg.learning.config.IUpdater;
-import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
 
-@Slf4j
 public class App {
-  private static final double LEARNING_RATE = 0.000002;
-  private static final double GRADIENT_THRESHOLD = 100.0;
-
-  private static final int X_DIM = 20 ;
-  private static final int Y_DIM = 20;
-  private static final int CHANNELS = 1;
-  private static final int batchSize = 1;
-  private static final int INPUT = 10;
-
-  private static final int OUTPUT_PER_PANEL = 16;
-
-  private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS;
-  private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
-
-  private  static JFrame frame;
-  private static  JFrame frame2;
-  private static JPanel panel;
-  private static JPanel panel2;
-
-  private static final String OUTPUT_DIR = "C:/temp/output/";
+    private static final double LEARNING_RATE = 0.002;
+    private static final double GRADIENT_THRESHOLD = 100.0;
+    private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
+    private static final int BATCHSIZE = 128;
+    private static JFrame frame;
+    private static JPanel panel;
 
-  private static LayerConfiguration[] genLayers() {
-    return new LayerConfiguration[] {
-        DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
-        ActivationLayer.builder(Activation.LEAKYRELU).build(),
-        DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
-        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-            DropoutLayer.builder(1 - 0.5).build(),
-        DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
-        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-        DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build()
-    };
- }
+    private static LayerConfiguration[] genLayers() {
+        return new LayerConfiguration[] {
+                dense().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
+                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+                dense().nIn(256).nOut(512).build(),
+                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+                dense().nIn(512).nOut(1024).build(),
+                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+                dense().nIn(1024).nOut(784).activation(Activation.TANH).build()
+        };
+    }
 
-  /**
-   * Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
-   *
-   * @return config
-   */
-  private static NeuralNetConfiguration generator() {
-    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
-        .seed(42)
-        .updater(UPDATER)
-        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
-        .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
-        //.weightInit(WeightInit.XAVIER)
-        .weightInit(WeightInit.XAVIER)
-        .activation(Activation.IDENTITY)
-        .layersFromArray(genLayers())
-        .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
-       // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
-        .build();
-    ((NeuralNetConfiguration) conf).init();
-   return conf;
-  }
+    /**
+     * Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
+     *
+     * @return config
+     */
+    private static NeuralNetConfiguration generator() {
+        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
+                .seed(42)
+                .updater(UPDATER)
+                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
+                .weightInit(WeightInit.XAVIER)
+                .activation(Activation.IDENTITY)
+                .layersFromArray(genLayers())
+                .name("generator")
+                .build();
 
-  private static LayerConfiguration[] disLayers() {
-    return new LayerConfiguration[]{
-        DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network
-        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-        DropoutLayer.builder(1 - 0.5).build(),
-        DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
-        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-        DropoutLayer.builder(1 - 0.5).build(),
-        DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
-        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-        DropoutLayer.builder(1 - 0.5).build(),
-        DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
-        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-        DropoutLayer.builder(1 - 0.5).build(),
-        OutputLayer.builder().name("dis-output").lossFunction(LossFunction.MCXENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
-    };
-  }
+        return conf;
+    }
 
-  private static NeuralNetConfiguration discriminator() {
-    NeuralNetConfiguration conf =
-        NeuralNetConfiguration.builder()
-            .seed(42)
-            .updater(UPDATER)
-            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
-            .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
-            .weightInit(WeightInit.XAVIER)
-            //.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
-                .weightNoise(null)
-            // .weightInitFn(new WeightInitXavier())
-            // .activationFn(new ActivationIdentity())
-            .activation(Activation.IDENTITY)
-            .layersFromArray(disLayers())
-            .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
-            .build();
-    ((NeuralNetConfiguration) conf).init();
-    return conf;
-  }
+    private static LayerConfiguration[] disLayers() {
+        return new LayerConfiguration[]{
+                dense().nIn(784).nOut(1024).build(),
+                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+                DropoutLayer.builder(1 - 0.5).build(),
+                dense().nIn(1024).nOut(512).build(),
+                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+                DropoutLayer.builder(1 - 0.5).build(),
+                dense().nIn(512).nOut(256).build(),
+                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+                DropoutLayer.builder(1 - 0.5).build(),
+                OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
+        };
+    }
 
-  private static NeuralNetConfiguration gan() {
-    LayerConfiguration[] genLayers = genLayers();
-    LayerConfiguration[] disLayers = Arrays.stream(disLayers())
-        .map((layer) -> {
-         if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
-          return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
-          } else {
-            return layer;
-          }
-        }).toArray(LayerConfiguration[]::new);
-    LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers);
-
-    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
-        .seed(42)
-        .updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() )
-        .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
-        .gradientNormalizationThreshold( 100 )
-        //.weightInitFn( new WeightInitXavier() ) //this is internal
-            .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
-        .weightInit( WeightInit.XAVIER)
-        //.activationFn( new ActivationIdentity()) //this is internal
-        .activation( Activation.IDENTITY )
-        .layersFromArray(  layers  )
-        .inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
-            .dataType(DataType.FLOAT)
-        .build();
-((NeuralNetConfiguration) conf).init();
-    return conf;
-  }
+    private static NeuralNetConfiguration discriminator() {
+        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
+                .seed(42)
+                .updater(UPDATER)
+                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
+                .weightInit(WeightInit.XAVIER)
+                .activation(Activation.IDENTITY)
+                .layersFromArray(disLayers())
+                .name("discriminator")
+                .build();
 
-  @Test
-  public void runTest() throws Exception {
-      if(! log.isDebugEnabled()) {
-          log.info("Logging is not set to DEBUG");
-          }
-      else {
-          log.info("Logging is set to DEBUG");
-      }
-    main();
-  }
+        return conf;
+    }
 
-  public static void main(String... args) throws Exception {
-    log.info("\u001B[32m  Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m   ");
-    Nd4j.getMemoryManager().setAutoGcWindow(500);
-
-   //MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
-   //FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
-   FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS());
-
-    ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
-    ImageTransform transform2 = new ShowImageTransform("Tester", 30);
-    ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);
-
-    ImageTransform tr = new PipelineImageTransform.Builder()
-        //.addImageTransform(transform) //convert to GREY SCALE
-        .addImageTransform(transform3)
-        //.addImageTransform(transform2)
-        .build();
-
-    ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
-    imageRecordReader.initialize(fileSplit, tr);
-    DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
-
-    MultiLayerNetwork gen = new MultiLayerNetwork(generator());
-    MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
-    MultiLayerNetwork gan = new MultiLayerNetwork(gan());
-    gen.init(); log.debug("Generator network: {}", gen);
-    dis.init(); log.debug("Discriminator network: {}", dis);
-    gan.init(); log.info("Complete GAN network: {}", gan);
-
-    copyParams(gen, dis, gan);
-
-    //gen.addTrainingListeners(new PerformanceListener(15, true, "GEN"));
-    dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
-    gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
-    //gan.addTrainingListeners(new ScoreToChartListener("gan"));
-    //dis.setListeners(new ScoreToChartListener("dis"));
-
-    //System.out.println(gan.toString());
-    //gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
-
-    //gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
-    //trainData.reset();
-
-    int j = 0;
-    for (int i = 0; i < 51; i++) { //epoch
-      while (trainData.hasNext()) {
-        j++;
-
-        DataSet next = trainData.next();
-        // generate data
-        INDArray real = next.getFeatures();//.div(255f);
-
-        //start next round if there are not enough images left to have a full batchsize dataset
-        if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
-          log.warn("Your total number of input images is not a multiple of {}, "
-              + "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
-        break;
-        }
-
-        //if(i%20 == 0) {
-         frame2 = visualize(new INDArray[]{real}, batchSize,
-         frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
-        //}
-       real.divi(255f);
-
-//        int batchSize = (int) real.shape()[0];
-
-        INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
-        //INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
-
-        INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
-        fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
-
-        //log.info("real has {} items.", real.length());
-        DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
-        DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
-
-        DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
-
-        dis.fit(data);
-        //dis.fit(data);
-
-        // Update the discriminator in the GAN network
-        updateGan(gen, dis, gan);
-
-        //gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
-        gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1)));
-
-        //Visualize and reporting
-        if (j % 10 == 1) {
-          System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
-          INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
-
-          for (int k = 0; k < samples.length; k++) {
-            //INDArray input = fakeSet2.get(k).getFeatures();
-            DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
-            INDArray input = fakeSet2.get(k).getFeatures();
-            input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here
-
-            //samples[k] = gen.output(input, false);
-            samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
-            samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
-            //samples[k] =
-            samples[k].addi(1f).divi(2f).muli(255f);
-          }
-          frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
-        }
-      }
-      if (trainData.resetSupported()) {
-          trainData.reset();
-      } else {
-          log.error("Trainingdata {} does not support reset.", trainData.toString());
-      }
-        // Copy the GANs generator to gen.
-        updateGen(gen, gan);
-
-        gen.save(new File("mnist-mlp-generator.dlj"));
-    }
-  }
+    private static NeuralNetConfiguration gan() {
+        LayerConfiguration[] genLayers = genLayers();
+        LayerConfiguration[] disLayers = discriminator().getFlattenedLayerConfigurations().stream()
+                .map((layer) -> {
+                    if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
+                        return FrozenLayerWithBackprop.builder(layer).build();
+                    } else {
+                        return layer;
+                    }
+                }).toArray(LayerConfiguration[]::new);
+        LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers);
+
+        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
+                .seed(42)
+                .updater(UPDATER)
+                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
+                .weightInit(WeightInit.XAVIER)
+                .activation(Activation.IDENTITY)
+                .layersFromArray(layers)
+                .name("GAN")
+                .build();
+
+        return conf;
+    }
+
+    @Test
+    public void runTest() throws Exception {
+        App.main(null);
+    }
+
+    public static void main(String... args) throws Exception {
+        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
+
+        MnistDataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
+
+        MultiLayerNetwork gen = new MultiLayerNetwork(generator());
+        MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
+        MultiLayerNetwork gan = new MultiLayerNetwork(gan());
+        gen.init();
+        dis.init();
+        gan.init();
+
+        copyParams(gen, dis, gan);
+
+        gen.addTrainingListeners(new PerformanceListener(10, true));
+        dis.addTrainingListeners(new PerformanceListener(10, true));
+        gan.addTrainingListeners(new PerformanceListener(10, true));
+
+        trainData.reset();
+
+        int j = 0;
+        for (int i = 0; i < 50; i++) {
+            while (trainData.hasNext()) {
+                j++;
+
+                // generate data
+                INDArray real = trainData.next().getFeatures().muli(2).subi(1);
+                int batchSize = (int) real.shape()[0];
+
+                INDArray fakeIn = Nd4j.rand(batchSize, 100);
+                INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
+
+                DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
+                DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
+
+                DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
+
+                dis.fit(data);
+                dis.fit(data);
+
+                // Update the discriminator in the GAN network
+                updateGan(gen, dis, gan);
+
+                gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
+
+                if (j % 10 == 1) {
+                    System.out.println("Epoch " + i +" Iteration " + j + " Visualizing...");
+                    INDArray[] samples = new INDArray[9];
+                    DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
+
+                    for (int k = 0; k < 9; k++) {
+                        INDArray input = fakeSet2.get(k).getFeatures();
+                        //samples[k] = gen.output(input, false);
+                        samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
+                    }
+                    visualize(samples);
+                }
+            }
+            trainData.reset();
+            // Copy the GANs generator to gen.
+            //updateGen(gen, gan);
+        }
+
+        // Copy the GANs generator to gen.
+        updateGen(gen, gan);
+
+        gen.save(new File("mnist-mlp-generator.dlj"));
+    }
 
-  private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
-    int genLayerCount = gen.getLayers().length;
-    for (int i = 0; i < gan.getLayers().length; i++) {
-      if (i < genLayerCount) {
-        if(gan.getLayer(i).getParams() != null)
-         gan.getLayer(i).setParams(gen.getLayer(i).getParams());
-      } else {
-        if(gan.getLayer(i).getParams() != null)
-        gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams());
-      }
-    }
-  }
+    private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
+        int genLayerCount = gen.getLayers().length;
+        for (int i = 0; i < gan.getLayers().length; i++) {
+            if (i < genLayerCount) {
+                gen.getLayer(i).setParams(gan.getLayer(i).getParams());
+            } else {
+                dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
+            }
+        }
+    }
 
-  private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
-    for (int i = 0; i < gen.getLayers().length; i++) {
-      gen.getLayer(i).setParams(gan.getLayer(i).getParams());
-    }
-  }
+    private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
+        for (int i = 0; i < gen.getLayers().length; i++) {
+            gen.getLayer(i).setParams(gan.getLayer(i).getParams());
+        }
+    }
 
-  private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
-    int genLayerCount = gen.getLayers().length;
-    for (int i = genLayerCount; i < gan.getLayers().length; i++) {
-      gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
-    }
-  }
+    private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
+        int genLayerCount = gen.getLayers().length;
+        for (int i = genLayerCount; i < gan.getLayers().length; i++) {
+            gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
+        }
+    }
 
-  private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
-    if (isOrig) {
-      frame.setTitle("Viz Original");
-    } else {
-      frame.setTitle("Generated");
-    }
-
-    frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
-    frame.setLayout(new BorderLayout());
-
-    JPanel panelx = new JPanel();
-
-    panelx.setLayout(new GridLayout(4, 4, 8, 8));
-    for (INDArray sample : samples) {
-      for(int i = 0; i<batchElements; i++) {
-        panelx.add(getImage(sample, i, isOrig));
-      }
-    }
-    frame.add(panelx, BorderLayout.CENTER);
-    frame.setVisible(true);
-
-    frame.revalidate();
-    frame.setMinimumSize(new Dimension(300, 20));
-    frame.pack();
-    return frame;
-  }
+    private static void visualize(INDArray[] samples) {
+        if (frame == null) {
+            frame = new JFrame();
+            frame.setTitle("Viz");
+            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
+            frame.setLayout(new BorderLayout());
+
+            panel = new JPanel();
+
+            panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
+            frame.add(panel, BorderLayout.CENTER);
+            frame.setVisible(true);
+        }
+
+        panel.removeAll();
+
+        for (INDArray sample : samples) {
+            panel.add(getImage(sample));
+        }
+
+        frame.revalidate();
+        frame.pack();
+    }
 
-  private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
-    final BufferedImage bi;
-    if(CHANNELS>1) {
-        bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
-    } else {
-        bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
-    }
-    final int imageSize = X_DIM * Y_DIM;
-    final int offset = batchElement * imageSize;
-    int pxl = offset * CHANNELS; //where to start in the INDArray
-
-    //Image in NCHW - channels first format
-    for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
-      for (int y = 0; y < Y_DIM; y++) { // step through the columns x
-        for (int x = 0; x < X_DIM; x++) { //step through the rows y
-          if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
-          bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
-          pxl++; //next item in INDArray
-        }
-      }
-    }
-    ImageIcon orig = new ImageIcon(bi);
-    Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
-    ImageIcon scaled = new ImageIcon(imageScaled);
-    if(! isOrig)  saveImage(imageScaled, batchElement, isOrig);
-    return new JLabel(scaled);
-  }
-
-  private static void saveImage(Image image, int batchElement, boolean isOrig) {
-    String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
-
-    try {
-      // Save the images to disk
-      saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
-
-     log.debug("Images saved successfully.");
-    } catch (IOException e) {
-      log.error("Error saving the images: {}", e.getMessage());
+    private static JLabel getImage(INDArray tensor) {
+        BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
+        for (int i = 0; i < 784; i++) {
+            int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
+            bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
+        }
+        ImageIcon orig = new ImageIcon(bi);
+        Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
+
+        ImageIcon scaled = new ImageIcon(imageScaled);
+
+        return new JLabel(scaled);
     }
-}
-    private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
-        File directory = new File(outputDirectory);
-        if (!directory.exists()) {
-            directory.mkdir();
-        }
-
-        File outputFile = new File(directory, fileName);
-        ImageIO.write(imageToBufferedImage(image), "png", outputFile);
-    }
-
-    public static BufferedImage imageToBufferedImage(Image image) {
-        if (image instanceof BufferedImage) {
-            return (BufferedImage) image;
-        }
-
-        // Create a buffered image with the same dimensions and transparency as the original image
-        BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
-
-        // Draw the original image onto the buffered image
-        Graphics2D g2d = bufferedImage.createGraphics();
-        g2d.drawImage(image, 0, 0, null);
-        g2d.dispose();
-
-        return bufferedImage;
-    }
-
 }
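
Note on the training scheme in the new App.java: each inner-loop iteration first fits the discriminator on a merged batch of real MNIST images (label 0) and generator output (label 1), then copies the refreshed discriminator weights into the frozen tail of the stacked GAN network, and finally fits the GAN on fresh noise labelled 0 so that only the generator layers receive updates. A minimal sketch of one such iteration, using only calls that appear in the diff above (gen, dis, gan, trainData and updateGan are the objects and helper defined in App.java):

    // One GAN iteration, mirroring the loop in the new App.java.
    INDArray real = trainData.next().getFeatures().muli(2).subi(1); // scale MNIST pixels to [-1, 1] for the TANH generator
    int batchSize = (int) real.shape()[0];

    // Push noise through the generator half of the stacked GAN network.
    INDArray fakeIn = Nd4j.rand(batchSize, 100);
    INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);

    // Discriminator step: real batch labelled 0, generated batch labelled 1.
    DataSet data = DataSet.merge(Arrays.asList(
            new DataSet(real, Nd4j.zeros(batchSize, 1)),
            new DataSet(fake, Nd4j.ones(batchSize, 1))));
    dis.fit(data);

    // Sync the trained discriminator into the GAN's FrozenLayerWithBackprop tail,
    // then train the generator by fitting the GAN on noise labelled as real (0).
    updateGan(gen, dis, gan);
    gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));

Because the discriminator copies inside the GAN are wrapped in FrozenLayerWithBackprop, the gan.fit call backpropagates through them without changing their parameters.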

brutex-extended-tests/src/test/java/net/brutex/gan/App2.java (new file, 343 lines)
@@ -0,0 +1,343 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 *    ******************************************************************************
 | 
				
			||||||
 | 
					 *    *
 | 
				
			||||||
 | 
					 *    * This program and the accompanying materials are made available under the
 | 
				
			||||||
 | 
					 *    * terms of the Apache License, Version 2.0 which is available at
 | 
				
			||||||
 | 
					 *    * https://www.apache.org/licenses/LICENSE-2.0.
 | 
				
			||||||
 | 
					 *    *
 | 
				
			||||||
 | 
					 *    *  See the NOTICE file distributed with this work for additional
 | 
				
			||||||
 | 
					 *    *  information regarding copyright ownership.
 | 
				
			||||||
 | 
					 *    * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					 *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
				
			||||||
 | 
					 *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
				
			||||||
 | 
					 *    * License for the specific language governing permissions and limitations
 | 
				
			||||||
 | 
					 *    * under the License.
 | 
				
			||||||
 | 
					 *    *
 | 
				
			||||||
 | 
					 *    * SPDX-License-Identifier: Apache-2.0
 | 
				
			||||||
 | 
					 *    *****************************************************************************
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package net.brutex.gan;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.awt.*;
 | 
				
			||||||
 | 
					import java.awt.image.BufferedImage;
 | 
				
			||||||
 | 
					import java.io.File;
 | 
				
			||||||
 | 
					import java.io.IOException;
 | 
				
			||||||
 | 
					import java.util.*;
 | 
				
			||||||
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					import javax.imageio.ImageIO;
 | 
				
			||||||
 | 
					import javax.swing.*;
 | 
				
			||||||
 | 
					import lombok.extern.slf4j.Slf4j;
 | 
				
			||||||
 | 
					import org.datavec.api.split.FileSplit;
 | 
				
			||||||
 | 
					import org.datavec.image.loader.NativeImageLoader;
 | 
				
			||||||
 | 
					import org.datavec.image.recordreader.ImageRecordReader;
 | 
				
			||||||
 | 
					import org.datavec.image.transform.*;
 | 
				
			||||||
 | 
					import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 | 
				
			||||||
 | 
					import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 | 
				
			||||||
 | 
					import org.deeplearning4j.nn.conf.GradientNormalization;
 | 
				
			||||||
 | 
					import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 | 
				
			||||||
 | 
					import org.deeplearning4j.nn.conf.layers.*;
 | 
				
			||||||
 | 
					import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
 | 
				
			||||||
 | 
					import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
				
			||||||
 | 
					import org.deeplearning4j.optimize.listeners.PerformanceListener;
 | 
				
			||||||
 | 
					import org.junit.jupiter.api.Test;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ndarray.INDArray;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.dataset.DataSet;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.factory.Nd4j;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@Slf4j
 | 
				
			||||||
 | 
					public class App2 {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
 | 
				
			||||||
 | 
					    static final float COLORSPACE = 255f;
 | 
				
			||||||
 | 
					    static final int DIMENSIONS = 28;
 | 
				
			||||||
 | 
					    static final int CHANNELS = 1;
 | 
				
			||||||
 | 
					    final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
 | 
				
			||||||
 | 
					    final int OUTPUT_PER_PANEL = 10;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    final boolean BIAS = true;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    static final int BATCHSIZE=128;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private JFrame frame2, frame;
 | 
				
			||||||
 | 
					    static final String OUTPUT_DIR = "d:/out/";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
 | 
				
			||||||
 | 
					    final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    void runTest() throws IOException {
 | 
				
			||||||
 | 
					        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
 | 
				
			||||||
 | 
					        FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
 | 
				
			||||||
 | 
					        ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
 | 
				
			||||||
 | 
					        ImageTransform transform2 = new ShowImageTransform("Tester", 30);
 | 
				
			||||||
 | 
					        ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ImageTransform tr = new PipelineImageTransform.Builder()
 | 
				
			||||||
 | 
					               .addImageTransform(transform) //convert to GREY SCALE
 | 
				
			||||||
 | 
					                .addImageTransform(transform3)
 | 
				
			||||||
 | 
					                //.addImageTransform(transform2)
 | 
				
			||||||
 | 
					                .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ImageRecordReader imageRecordReader = new ImageRecordReader(DIMENSIONS, DIMENSIONS, CHANNELS);
 | 
				
			||||||
 | 
					        imageRecordReader.initialize(fileSplit, tr);
 | 
				
			||||||
 | 
					        DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, BATCHSIZE );
 | 
				
			||||||
 | 
					        trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
 | 
				
			||||||
 | 
					        MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        LayerConfiguration[] disLayers = App2Config.discriminator().getFlattenedLayerConfigurations().stream()
 | 
				
			||||||
 | 
					                .map((layer) -> {
 | 
				
			||||||
 | 
					                    if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
 | 
				
			||||||
 | 
					                        return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
 | 
				
			||||||
 | 
					                    } else {
 | 
				
			||||||
 | 
					                        return layer;
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                }).toArray(LayerConfiguration[]::new);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    NeuralNetConfiguration netConfiguration =
 | 
				
			||||||
 | 
					        NeuralNetConfiguration.builder()
 | 
				
			||||||
 | 
					            .name("GAN")
 | 
				
			||||||
 | 
					            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
 | 
				
			||||||
 | 
					                .gradientNormalizationThreshold(100)
 | 
				
			||||||
 | 
					                .updater(App2Config.UPDATER)
 | 
				
			||||||
 | 
					            .innerConfigurations(new ArrayList<>(List.of(App2Config.generator())))
 | 
				
			||||||
 | 
					            .layersFromList(new ArrayList<>(Arrays.asList(disLayers)))
 | 
				
			||||||
 | 
					            // .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
 | 
				
			||||||
 | 
					            // .inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
 | 
				
			||||||
 | 
					            //.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
 | 
				
			||||||
 | 
					           // .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
 | 
				
			||||||
 | 
					                //.inputPreProcessor(2, new CnnToFeedForwardPreProcessor())
 | 
				
			||||||
 | 
					                //.dataType(DataType.FLOAT)
 | 
				
			||||||
 | 
					            .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        MultiLayerNetwork gan = new MultiLayerNetwork(netConfiguration );
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        dis.init(); log.debug("Discriminator network: {}", dis);
 | 
				
			||||||
 | 
					        gen.init(); log.debug("Generator network: {}", gen);
 | 
				
			||||||
 | 
					        gan.init(); log.debug("GAN network: {}", gan);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        log.info("Generator Summary:\n{}", gen.summary());
 | 
				
			||||||
 | 
					        log.info("GAN Summary:\n{}", gan.summary());
 | 
				
			||||||
 | 
					        dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
 | 
				
			||||||
 | 
					        gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
 | 
				
			||||||
 | 
					        gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
        int j = 0;
        for (int i = 0; i < 51; i++) { // epochs
            while (trainData.hasNext()) {
                j++;
                DataSet next = trainData.next();
                INDArray real = next.getFeatures(); // optionally rescale: .muli(2).subi(1) or .div(255f)

                // Start the next epoch if there are not enough images left for a full batch.
                if (real.length() < ARRAY_SIZE_PER_SAMPLE * BATCHSIZE) {
                    log.warn("Your total number of input images is not a multiple of {}, "
                            + "thus skipping {} images to make it fit", BATCHSIZE, real.length() / ARRAY_SIZE_PER_SAMPLE);
                    break;
                }

                // if (i % 20 == 0) {
                //     frame2 = visualize(new INDArray[] {real}, BATCHSIZE,
                //             frame2 == null ? new JFrame() : frame2, true); // real holds BATCHSIZE images
                // }
                // real.divi(255f);

                // int batchSize = (int) real.shape()[0];
                // INDArray fakeIn = Nd4j.rand(BATCHSIZE, CHANNELS, DIMENSIONS, DIMENSIONS);
                // INDArray fakeIn = Nd4j.rand(new int[] {batchSize, X_DIM * Y_DIM}); // MNIST-only shortcut, use the line above otherwise
                INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);

                INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
                // The generator ends in TANH, so its output lies in [-1, 1]; rescale it to [0, 1].
                // (With a SIGMOID output the range would already be 0 to 1.)
                fake.addi(1f).divi(2f);

                DataSet realSet = new DataSet(real, label_real);
                DataSet fakeSet = new DataSet(fake, label_fake);
                DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));

                // Train the discriminator for two passes on the combined real/fake batch.
                dis.fit(data);
                dis.fit(data);

                // Update the discriminator copy inside the GAN network.
                updateGan(gen, dis, gan);

                gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_fake));
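                // The fit() call above only updates the generator layers (indices
                // 0 .. gen.getLayers().length - 1), since the discriminator layers inside `gan`
                // are frozen.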

                // Visualization and reporting
                if (j % 10 == 1) {
                    System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
                    INDArray[] samples =
                        BATCHSIZE > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[BATCHSIZE];

                    for (int k = 0; k < samples.length; k++) {
                        DataSet fakeSet2 = new DataSet(fakeIn, label_fake);
                        INDArray input = fakeSet2.get(k).getFeatures();
                        // input = input.reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS); // batch size is 1 here
                        input = input.reshape(1, App2Config.INPUT);

                        // samples[k] = gen.output(input, false);
                        samples[k] = gen.activateSelectedLayers(0, gen.getLayers().length - 1, input);
                        samples[k] = samples[k].reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS);
                        // samples[k].muli(255f);
                    }
                    // Each entry in samples holds a single image, hence batchElements = 1.
                    frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false);
                }
            }

            if (trainData.resetSupported()) {
                trainData.reset();
            } else {
                log.error("Training data {} does not support reset.", trainData);
            }

            // Copy the GAN's generator layers back into the standalone generator.
            updateGen(gen, gan);
            log.info("Updated gen with the GAN's generator weights.");
            gen.save(new File("mnist-mlp-generator.dlj"));
        }
    }

    private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
        if (isOrig) {
            frame.setTitle("Viz Original");
        } else {
            frame.setTitle("Generated");
        }
        frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
        frame.setLayout(new BorderLayout());

        JPanel panelx = new JPanel();
        panelx.setLayout(new GridLayout(4, 4, 8, 8));
        for (INDArray sample : samples) {
            for (int i = 0; i < batchElements; i++) {
                panelx.add(getImage(sample, i, isOrig));
            }
        }
        frame.add(panelx, BorderLayout.CENTER);
        frame.setVisible(true);

        frame.revalidate();
        frame.setMinimumSize(new Dimension(300, 20));
        frame.pack();
        return frame;
    }
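
    // Samples arrive flattened in NCHW (channels-first) order, i.e.
    // index = ((b * C + c) * H + y) * W + x with H = W = DIMENSIONS. getImage() therefore starts
    // at batchElement * CHANNELS * DIMENSIONS * DIMENSIONS and walks the array linearly while
    // iterating channel -> row -> column.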
    private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
        final BufferedImage bi;
        if (CHANNELS > 1) {
            bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); // adjust here when the channel count changes
        } else {
            bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); // adjust here when the channel count changes
        }
        final int imageSize = DIMENSIONS * DIMENSIONS;
        final int offset = batchElement * imageSize;
        int pxl = offset * CHANNELS; // where to start in the INDArray

        // Image is in NCHW - channels-first format
        for (int c = 0; c < CHANNELS; c++) { // step through the channels of each pixel
            for (int y = 0; y < DIMENSIONS; y++) { // step through the rows y
                for (int x = 0; x < DIMENSIONS; x++) { // step through the columns x
                    float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
                    if (isOrig)
                        log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray index {} and value '{}'",
                                batchElement, x, y, c, pxl, f_pxl);
                    bi.getRaster().setSample(x, y, c, f_pxl);
                    pxl++; // next item in the INDArray
                }
            }
        }
        ImageIcon orig = new ImageIcon(bi);
        Image imageScaled = orig.getImage().getScaledInstance(4 * DIMENSIONS, 4 * DIMENSIONS, Image.SCALE_DEFAULT);
        ImageIcon scaled = new ImageIcon(imageScaled);
        if (!isOrig) saveImage(imageScaled, batchElement, isOrig);
        return new JLabel(scaled);
    }

    private static void saveImage(Image image, int batchElement, boolean isOrig) {
        String outputDirectory = OUTPUT_DIR; // the directory the images are written to

        try {
            // Save the image to disk under a random file name.
            saveImage(image, outputDirectory, UUID.randomUUID() + ".png");
            log.debug("Image saved successfully.");
        } catch (IOException e) {
            log.error("Error saving the image: {}", e.getMessage());
        }
    }

    private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
        File directory = new File(outputDirectory);
        if (!directory.exists()) {
            directory.mkdirs(); // create intermediate directories as well
        }
        File outputFile = new File(directory, fileName);
        ImageIO.write(imageToBufferedImage(image), "png", outputFile);
    }

    public static BufferedImage imageToBufferedImage(Image image) {
        if (image instanceof BufferedImage) {
            return (BufferedImage) image;
        }

        // Create a buffered image with the same dimensions (and transparency, for RGB) as the original image
        BufferedImage bufferedImage;
        if (CHANNELS > 1) {
            bufferedImage = new BufferedImage(
                    image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
        } else {
            bufferedImage = new BufferedImage(
                    image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
        }

        // Draw the original image onto the buffered image
        Graphics2D g2d = bufferedImage.createGraphics();
        g2d.drawImage(image, 0, 0, null);
        g2d.dispose();

        return bufferedImage;
    }

    // Copy the generator layer parameters out of the GAN into the standalone generator.
    private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
        for (int i = 0; i < gen.getLayers().length; i++) {
            gen.getLayer(i).setParams(gan.getLayer(i).getParams());
        }
    }

    // Copy the standalone discriminator's parameters into the discriminator layers of the GAN
    // (the GAN's layer list holds the generator layers first, then the discriminator layers).
    private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
        int genLayerCount = gen.getLayers().length;
        for (int i = genLayerCount; i < gan.getLayers().length; i++) {
            gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
        }
    }
}
@ -0,0 +1,176 @@

/*
 *
 *    ******************************************************************************
 *    *
 *    * This program and the accompanying materials are made available under the
 *    * terms of the Apache License, Version 2.0 which is available at
 *    * https://www.apache.org/licenses/LICENSE-2.0.
 *    *
 *    *  See the NOTICE file distributed with this work for additional
 *    *  information regarding copyright ownership.
 *    * Unless required by applicable law or agreed to in writing, software
 *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *    * License for the specific language governing permissions and limitations
 *    * under the License.
 *    *
 *    * SPDX-License-Identifier: Apache-2.0
 *    *****************************************************************************
 *
 */

package net.brutex.gan;

import static net.brutex.ai.dnn.api.NN.*;

import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class App2Config {

  public static final int INPUT = 100;
  public static final int X_DIM = 28;
  public static final int Y_DIM = 28;
  public static final int CHANNELS = 1;
  // Adam with a small learning rate and beta1 = 0.5, a common hyper-parameter choice for GAN training.
  public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();

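  /*
   * Generator: a fully connected stack 100 -> 256 -> 512 -> 1024 -> 784 with LeakyReLU(0.2)
   * activations and a TANH output. 784 = 28 * 28, i.e. one flattened MNIST-sized image; the
   * TANH output range of [-1, 1] is rescaled to [0, 1] by the training loop in App2.
   */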
  static LayerConfiguration[] genLayerConfig() {
    return new LayerConfiguration[] {
      /*
      DenseLayer.builder().name("L-0").nIn(INPUT).nOut(INPUT + (INPUT / 2)).activation(Activation.RELU).build(),
      ActivationLayer.builder().activation(Activation.RELU).build(),

      Deconvolution2D.builder().name("L-Deconv-01").nIn(CHANNELS).nOut(CHANNELS)
              .kernelSize(2,2)
              .stride(1,1)
              .padding(0,0)
              .convolutionMode(ConvolutionMode.Truncate)
              .activation(Activation.RELU)
              .hasBias(BIAS).build(),
      // BatchNormalization.builder().nOut(CHANNELS).build(),
      Deconvolution2D.builder().name("L-Deconv-02").nIn(CHANNELS).nOut(CHANNELS)
              .kernelSize(2,2)
              .stride(2,2)
              .padding(0,0)
              .convolutionMode(ConvolutionMode.Truncate)
              .activation(Activation.RELU)
              .hasBias(BIAS).build(),
      // BatchNormalization.builder().name("L-batch").nOut(CHANNELS).build(),

      DenseLayer.builder().name("L-x").nIn(INPUT + (INPUT / 2)).nOut(2 * INPUT).build(),
      ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
      DenseLayer.builder().name("L-x").nIn(2 * INPUT).nOut(3 * INPUT).build(),
      ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
      DenseLayer.builder().name("L-x").nIn(3 * INPUT).nOut(2 * INPUT).build(),
      ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
      // DropoutLayer.builder(0.001).build(),
      DenseLayer.builder().nIn(2 * INPUT).nOut(INPUT).activation(Activation.TANH).build()
      */

      dense().nIn(INPUT).nOut(256).weightInit(WeightInit.NORMAL).build(),
      ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
      dense().nIn(256).nOut(512).build(),
      ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
      dense().nIn(512).nOut(1024).build(),
      ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
      dense().nIn(1024).nOut(784).activation(Activation.TANH).build(),
    };
  }

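  /*
   * Discriminator: a mirrored stack 784 -> 1024 -> 512 -> 256 -> 1 with LeakyReLU(0.2) and
   * dropout after each hidden layer, ending in a single SIGMOID unit trained with binary
   * cross-entropy (XENT) to separate real images from generated ones.
   */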
  static LayerConfiguration[] disLayerConfig() {
    return new LayerConfiguration[] {
      /*
      Convolution2D.builder().nIn(CHANNELS).kernelSize(2,2).padding(1,1).stride(1,1).nOut(CHANNELS)
              .build(),
      Convolution2D.builder().nIn(CHANNELS).kernelSize(3,3).padding(1,1).stride(2,2).nOut(CHANNELS)
              .build(),
      ActivationLayer.builder().activation(Activation.LEAKYRELU).build(),
      BatchNormalization.builder().build(),
      OutputLayer.builder().nOut(1).lossFunction(LossFunctions.LossFunction.MCXENT)
              .activation(Activation.SIGMOID)
              .build()

      dense().name("L-dense").nIn(INPUT).nOut(INPUT).build(),
      ActivationLayer.builder().activation(Activation.RELU).build(),
      DropoutLayer.builder(0.5).build(),

      DenseLayer.builder().nIn(INPUT).nOut(INPUT/2).build(),
      ActivationLayer.builder().activation(Activation.RELU).build(),
      DropoutLayer.builder(0.5).build(),

      DenseLayer.builder().nIn(INPUT/2).nOut(INPUT/4).build(),
      ActivationLayer.builder().activation(Activation.RELU).build(),
      DropoutLayer.builder(0.5).build(),

      OutputLayer.builder().nIn(INPUT/4).nOut(1).lossFunction(LossFunctions.LossFunction.XENT)
              .activation(Activation.SIGMOID)
              .build()
      */

      dense().nIn(784).nOut(1024).hasBias(true).build(),
      ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
      DropoutLayer.builder(1 - 0.5).build(),
      dense().nIn(1024).nOut(512).hasBias(true).build(),
      ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
      DropoutLayer.builder(1 - 0.5).build(),
      dense().nIn(512).nOut(256).hasBias(true).build(),
      ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
      DropoutLayer.builder(1 - 0.5).build(),
      OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
    };
  }

  static NeuralNetConfiguration generator() {
    NeuralNetConfiguration conf =
        NeuralNetConfiguration.builder()
            .name("generator")
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .gradientNormalizationThreshold(100)
            .seed(42)
            .updater(UPDATER)
            .weightInit(WeightInit.XAVIER)
            // .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
            .weightNoise(null)
            // .weightInitFn(new WeightInitXavier())
            // .activationFn(new ActivationIdentity())
            .activation(Activation.IDENTITY)
            .layersFromArray(App2Config.genLayerConfig())
            // .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
            // .inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
            // .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
            // .inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
            .build();
    conf.init();
    return conf;
  }

  static NeuralNetConfiguration discriminator() {
    NeuralNetConfiguration conf =
        NeuralNetConfiguration.builder()
            .name("discriminator")
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .gradientNormalizationThreshold(100)
            .seed(42)
            .updater(UPDATER)
            .weightInit(WeightInit.XAVIER)
            // .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
            .weightNoise(null)
            // .weightInitFn(new WeightInitXavier())
            // .activationFn(new ActivationIdentity())
            .activation(Activation.IDENTITY)
            .layersFromArray(disLayerConfig())
            // .inputPreProcessor(0, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
            // .inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
            // .dataType(DataType.FLOAT)
            .build();
    conf.init();
    return conf;
  }
}

@ -43,223 +43,255 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils.impo

@Slf4j
public class KerasSequentialModel extends KerasModel {

  /**
   * (Recommended) Builder-pattern constructor for Sequential model.
   *
   * @param modelBuilder builder object
   * @throws IOException I/O exception
   * @throws InvalidKerasConfigurationException Invalid Keras configuration
   * @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
   */
  public KerasSequentialModel(KerasModelBuilder modelBuilder)
      throws UnsupportedKerasConfigurationException,
          IOException,
          InvalidKerasConfigurationException {
    this(
        modelBuilder.getModelJson(),
        modelBuilder.getModelYaml(),
        modelBuilder.getWeightsArchive(),
        modelBuilder.getWeightsRoot(),
        modelBuilder.getTrainingJson(),
        modelBuilder.getTrainingArchive(),
        modelBuilder.isEnforceTrainingConfig(),
        modelBuilder.getInputShape());
  }
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * (Not recommended) Constructor for Sequential model from model configuration (JSON or YAML),
 | 
				
			||||||
 | 
					   * training configuration (JSON), weights, and "training mode" boolean indicator. When built in
 | 
				
			||||||
 | 
					   * training mode, certain unsupported configurations (e.g., unknown regularizers) will throw
 | 
				
			||||||
 | 
					   * Exceptions. When enforceTrainingConfig=false, these will generate warnings but will be
 | 
				
			||||||
 | 
					   * otherwise ignored.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param modelJson model configuration JSON string
 | 
				
			||||||
 | 
					   * @param modelYaml model configuration YAML string
 | 
				
			||||||
 | 
					   * @param trainingJson training configuration JSON string
 | 
				
			||||||
 | 
					   * @throws IOException I/O exception
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  public KerasSequentialModel(
 | 
				
			||||||
 | 
					      String modelJson,
 | 
				
			||||||
 | 
					      String modelYaml,
 | 
				
			||||||
 | 
					      Hdf5Archive weightsArchive,
 | 
				
			||||||
 | 
					      String weightsRoot,
 | 
				
			||||||
 | 
					      String trainingJson,
 | 
				
			||||||
 | 
					      Hdf5Archive trainingArchive,
 | 
				
			||||||
 | 
					      boolean enforceTrainingConfig,
 | 
				
			||||||
 | 
					      int[] inputShape)
 | 
				
			||||||
 | 
					      throws IOException,
 | 
				
			||||||
 | 
					          InvalidKerasConfigurationException,
 | 
				
			||||||
 | 
					          UnsupportedKerasConfigurationException {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
 | 
				
			||||||
 | 
					    this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
 | 
				
			||||||
 | 
					    this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
 | 
				
			||||||
 | 
					    this.enforceTrainingConfig = enforceTrainingConfig;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /* Determine model configuration type. */
 | 
				
			||||||
 | 
					    if (!modelConfig.containsKey(config.getFieldClassName()))
 | 
				
			||||||
 | 
					      throw new InvalidKerasConfigurationException(
 | 
				
			||||||
 | 
					          "Could not determine Keras model class (no "
 | 
				
			||||||
 | 
					              + config.getFieldClassName()
 | 
				
			||||||
 | 
					              + " field found)");
 | 
				
			||||||
 | 
					    this.className = (String) modelConfig.get(config.getFieldClassName());
 | 
				
			||||||
 | 
					    if (!this.className.equals(config.getFieldClassNameSequential()))
 | 
				
			||||||
 | 
					      throw new InvalidKerasConfigurationException(
 | 
				
			||||||
 | 
					          "Model class name must be "
 | 
				
			||||||
 | 
					              + config.getFieldClassNameSequential()
 | 
				
			||||||
 | 
					              + " (found "
 | 
				
			||||||
 | 
					              + this.className
 | 
				
			||||||
 | 
					              + ")");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /* Process layer configurations. */
 | 
				
			||||||
 | 
					    if (!modelConfig.containsKey(config.getModelFieldConfig()))
 | 
				
			||||||
 | 
					      throw new InvalidKerasConfigurationException(
 | 
				
			||||||
 | 
					          "Could not find layer configurations (no "
 | 
				
			||||||
 | 
					              + config.getModelFieldConfig()
 | 
				
			||||||
 | 
					              + " field found)");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations.
 | 
				
			||||||
 | 
					    // For consistency
 | 
				
			||||||
 | 
					    // "config" is now an object containing a "name" and "layers", the latter contain the same data
 | 
				
			||||||
 | 
					    // as before.
 | 
				
			||||||
 | 
					    // This change only affects Sequential models.
 | 
				
			||||||
 | 
					    List<Object> layerList;
 | 
				
			||||||
 | 
					    try {
 | 
				
			||||||
 | 
					      layerList = (List<Object>) modelConfig.get(config.getModelFieldConfig());
 | 
				
			||||||
 | 
					    } catch (Exception e) {
 | 
				
			||||||
 | 
					      HashMap layerMap = (HashMap<String, Object>) modelConfig.get(config.getModelFieldConfig());
 | 
				
			||||||
 | 
					      layerList = (List<Object>) layerMap.get("layers");
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
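    // Schematically (with an illustrative model name), the two shapes handled above are:
    //   pre-2.2.3:  "config": [ {"class_name": "Dense", ...}, ... ]
    //   2.2.3+:     "config": {"name": "sequential_1", "layers": [ {"class_name": "Dense", ...}, ... ]}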

    Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = prepareLayers(layerList);
    this.layers = layerPair.getFirst();
    this.layersOrdered = layerPair.getSecond();

    KerasLayer inputLayer;
    if (this.layersOrdered.get(0) instanceof KerasInput) {
      inputLayer = this.layersOrdered.get(0);
    } else {
      /* Add placeholder input layer and update lists of input and output layers. */
      int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
      Preconditions.checkState(
          ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!");
      inputLayer = new KerasInput("input1", firstLayerInputShape);
      inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
      this.layers.put(inputLayer.getName(), inputLayer);
      this.layersOrdered.add(0, inputLayer);
    }
    this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
    this.outputLayerNames =
        new ArrayList<>(
            Collections.singletonList(
                this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));

    /* Update each layer's inbound layer list to include (only) previous layer. */
    KerasLayer prevLayer = null;
    for (KerasLayer layer : this.layersOrdered) {
      if (prevLayer != null)
        layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName()));
      prevLayer = layer;
    }

    /* Import training configuration. */
    if (enforceTrainingConfig) {
      if (trainingJson != null) importTrainingConfiguration(trainingJson);
      else
        log.warn(
            "If enforceTrainingConfig is true, a training "
                + "configuration object has to be provided. Usually the only practical way to do this is to store"
                + " your Keras model with `model.save('model_path.h5')`. If you store model config and weights"
                + " separately, no training configuration is attached.");
    }

    this.outputTypes = inferOutputTypes(inputShape);

    if (weightsArchive != null)
      importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
  }

  /** Default constructor */
  public KerasSequentialModel() {
    super();
  }

  /**
   * Configure a NeuralNetConfiguration from this Keras Sequential model configuration.
   *
   * @return NeuralNetConfiguration
   */
  public NeuralNetConfiguration getNeuralNetConfiguration()
      throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    if (!this.className.equals(config.getFieldClassNameSequential()))
      throw new InvalidKerasConfigurationException(
          "Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
    if (this.inputLayerNames.size() != 1)
      throw new InvalidKerasConfigurationException(
          "MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
    if (this.outputLayerNames.size() != 1)
      throw new InvalidKerasConfigurationException(
          "MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");

    NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder =
        NeuralNetConfiguration.builder();

    if (optimizer != null) {
      modelBuilder.updater(optimizer);
    }

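    // The loop below walks the ordered Keras layers, infers each layer's nIn from the previous
    // layer's output type (or from an InputPreProcessor, when one sits between the two layers),
    // registers any required preprocessor at the corresponding layer index, and appends the
    // resulting DL4J layer to the builder.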
    // Don't forcibly override nIn for Keras import.
    modelBuilder.overrideNinUponBuild(false);

    /* Add layers one at a time. */
    KerasLayer prevLayer = null;
    int layerIndex = 0;
    for (KerasLayer layer : this.layersOrdered) {
      if (layer.isLayer()) {
        int nbInbound = layer.getInboundLayerNames().size();
        if (nbInbound != 1)
          throw new InvalidKerasConfigurationException(
              "Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
                  + nbInbound
                  + " for layer "
                  + layer.getName()
                  + ")");
        if (prevLayer != null) {
          InputType[] inputTypes = new InputType[1];
          InputPreProcessor preprocessor;
          if (prevLayer.isInputPreProcessor()) {
            inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
            preprocessor = prevLayer.getInputPreprocessor(inputTypes);
            InputType outputType = preprocessor.getOutputType(inputTypes[0]);
            layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
          } else {
            inputTypes[0] = this.outputTypes.get(prevLayer.getName());
            preprocessor = layer.getInputPreprocessor(inputTypes);
            if (preprocessor != null) {
              InputType outputType = preprocessor.getOutputType(inputTypes[0]);
              layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
            } else {
              layer.getLayer().setNIn(inputTypes[0], modelBuilder.isOverrideNinUponBuild());
            }
          }
          if (preprocessor != null) {
            Map<Integer, InputPreProcessor> map = new HashMap<>();
            map.put(layerIndex, preprocessor);
            modelBuilder.inputPreProcessors(map);
          }
        }

        modelBuilder.layer(layerIndex++, layer.getLayer());
      } else if (layer.getVertex() != null) {
        throw new InvalidKerasConfigurationException(
            "Cannot add vertex to NeuralNetConfiguration (class name "
                + layer.getClassName()
                + ", layer name "
                + layer.getName()
                + ")");
      }
      prevLayer = layer;
    }
    /**
 | 
					    /* Whether to use standard backprop (or BPTT) or truncated BPTT. */
 | 
				
			||||||
     * Build a MultiLayerNetwork from this Keras Sequential model configuration.
 | 
					    if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
 | 
				
			||||||
     *
 | 
					      modelBuilder
 | 
				
			||||||
     * @return MultiLayerNetwork
 | 
					          .backpropType(BackpropType.TruncatedBPTT)
 | 
				
			||||||
     */
 | 
					          .tbpttFwdLength(truncatedBPTT)
 | 
				
			||||||
    public MultiLayerNetwork getMultiLayerNetwork()
 | 
					          .tbpttBackLength(truncatedBPTT);
 | 
				
			||||||
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
 | 
					    else modelBuilder.backpropType(BackpropType.Standard);
 | 
				
			||||||
        return getMultiLayerNetwork(true);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    NeuralNetConfiguration build = modelBuilder.build();
 | 
				
			||||||
     * Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights.
 | 
					
 | 
				
			||||||
     *
 | 
					    return build;
 | 
				
			||||||
     * @return MultiLayerNetwork
 | 
					  }
 | 
				
			||||||
     */
 | 
					
 | 
				
			||||||
    public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
 | 
					  /**
 | 
				
			||||||
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
 | 
					   * Build a MultiLayerNetwork from this Keras Sequential model configuration.
 | 
				
			||||||
        MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration());
 | 
					   *
 | 
				
			||||||
        model.init();
 | 
					   * @return MultiLayerNetwork
 | 
				
			||||||
        if (importWeights)
 | 
					   */
 | 
				
			||||||
            model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers);
 | 
					  public MultiLayerNetwork getMultiLayerNetwork()
 | 
				
			||||||
        return model;
 | 
					      throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
 | 
				
			||||||
    }
 | 
					    return getMultiLayerNetwork(true);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @return MultiLayerNetwork
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
 | 
				
			||||||
 | 
					      throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
 | 
				
			||||||
 | 
					    MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration());
 | 
				
			||||||
 | 
					    model.init();
 | 
				
			||||||
 | 
					    if (importWeights)
 | 
				
			||||||
 | 
					      model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers);
 | 
				
			||||||
 | 
					    return model;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
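
Both versions of getMultiLayerNetwork() above end the same way: the builder is materialized into a NeuralNetConfiguration, wrapped in a MultiLayerNetwork, init() is called, and the Keras weights are optionally copied in. A minimal usage sketch, assuming a KerasSequentialModel instance named kerasModel was parsed elsewhere (the variable name is hypothetical):

    // Builds the configuration, calls init(), and copies the Keras weights in.
    MultiLayerNetwork net = kerasModel.getMultiLayerNetwork();
    // Same, but keeps the freshly initialized (random) parameters instead.
    MultiLayerNetwork untrained = kerasModel.getMultiLayerNetwork(false);
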
@@ -23,6 +23,7 @@ package net.brutex.ai.dnn.api;
 
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
 
 /**
  * A fluent API to configure and create artificial neural networks
@@ -30,9 +31,11 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationB
 public class NN {
 
 
-  public static NeuralNetConfigurationBuilder<?, ?> net() {
+  public static NeuralNetConfigurationBuilder<?, ?> nn() {
     return NeuralNetConfiguration.builder();
   }
 
+  public static DenseLayer.DenseLayerBuilder<?,?> dense() { return DenseLayer.builder(); }
+
 
 }
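
Taken together, the rename net() to nn() and the new dense() factory enable the fluent style used by the GAN example in this commit. A sketch of a two-layer configuration; layer(int, LayerConfiguration) is the builder method used earlier in this commit, while nOut(...) is an assumed setter on the Lombok-generated DenseLayerBuilder:

    import static net.brutex.ai.dnn.api.NN.dense;
    import static net.brutex.ai.dnn.api.NN.nn;

    // Hypothetical two-layer network; setter names on dense() are assumptions.
    NeuralNetConfiguration conf = nn()
        .layer(0, dense().nOut(64).build())
        .layer(1, dense().nOut(10).build())
        .build();
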
@@ -152,7 +152,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
   @Getter @Setter @NonNull @lombok.Builder.Default
   protected BackpropType backpropType = BackpropType.Standard;
 
-  @Getter @lombok.Builder.Default
+  @Getter @Setter @Singular
   protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
   /**
    * When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated)
@@ -524,12 +524,11 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
      * @param processor what to use to preProcess the data.
      * @return builder pattern
      */
-    public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
-      if(inputPreProcessors$value==null) inputPreProcessors$value=new LinkedHashMap<>();
-      inputPreProcessors$value.put(layer, processor);
-      inputPreProcessors$set = true;
-      return self();
-    }
+   //public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
+   //   inputPreProcessors$value.put(layer, processor);
+   //   inputPreProcessors$set = true;
+   //   return self();
+   // }
 
     /**
      * Set layer at index
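
With @Singular, Lombok now generates the per-entry adder that the commented-out method used to provide by hand: a put-style inputPreProcessor(Integer, InputPreProcessor) on the builder, plus a bulk inputPreProcessors(Map) setter and clearInputPreProcessors(). A sketch of the resulting call (the preprocessor arguments are illustrative only):

    // Equivalent of the removed hand-written adder, now Lombok-generated.
    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
        .inputPreProcessor(0, new FeedForwardToCnnPreProcessor(28, 28, 1))
        .build();
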
@@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.*;
 import java.util.*;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 import lombok.*;
 import lombok.experimental.SuperBuilder;
@@ -317,6 +318,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
           @NonNull
           InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType);
           if (inputPreProcessor != null) {
+            inputPreProcessors = new HashMap<>(inputPreProcessors);
             inputPreProcessors.put(i, inputPreProcessor);
           }
         }
@@ -538,6 +540,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
                     obj.getClass().getSimpleName());
               }
             });
+    // make sure the indexes are sequenced properly
+    AtomicInteger i = new AtomicInteger();
+    ret.forEach(obj -> {
+      obj.setIndex(i.getAndIncrement());
+    });
     return ret;
   }
 
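
The AtomicInteger here is not about concurrency: a lambda cannot mutate a captured local int, so a mutable counter object is used to hand out sequential indexes. The idiom in isolation, as a self-contained sketch:

    import java.util.List;
    import java.util.concurrent.atomic.AtomicInteger;

    public class IndexSequencing {
      public static void main(String[] args) {
        List<String> layers = List.of("dense", "deconv", "output");
        AtomicInteger i = new AtomicInteger();
        // Prints "0: dense", "1: deconv", "2: output"; indexes follow iteration order.
        layers.forEach(name -> System.out.println(i.getAndIncrement() + ": " + name));
      }
    }
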
@@ -219,7 +219,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
       throw new IllegalStateException(
           "Invalid input for Convolution layer (layer name=\""
               + getName()
-              + "\"): Expected CNN input, got "
+              + "\" at index '"+getIndex()+"') : Expected CNN input, got "
               + inputType);
     }
 
@@ -372,7 +372,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
      * @param kernelSize kernel size
      */
     public B kernelSize(int... kernelSize) {
-      this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize");
+      //this.kernelSize$value = ValidationUtils.validate3NonNegative(kernelSize, "kernelSize");
+      this.kernelSize$value = kernelSize;
       this.kernelSize$set = true;
       return self();
     }
@@ -383,7 +384,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
      * @param stride kernel size
      */
     public B stride(int... stride) {
-      this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride");
+      //this.stride$value = ValidationUtils.validate3NonNegative(stride, "stride");
+      this.stride$value = stride;
       this.stride$set = true;
       return self();
     }
@@ -394,7 +396,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
      * @param padding kernel size
      */
     public B padding(int... padding) {
-      this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding");
+      //this.padding$value = ValidationUtils.validate3NonNegative(padding, "padding");
+      this.padding$value = padding;
       this.padding$set = true;
       return self();
     }
@@ -404,7 +407,8 @@ public class ConvolutionLayer extends FeedForwardLayer {
      * @param dilation kernel size
      */
     public B dilation(int... dilation) {
-      this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation");
+      //this.dilation$value = ValidationUtils.validate3NonNegative(dilation, "dilation");
+      this.dilation$value = dilation;
       this.dilation$set = true;
       return self();
     }
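
After this change the four builders store the raw varargs unchecked, so the caller is responsible for array length and sign; nothing pads the values to the expected dimensionality or rejects negatives any more. A sketch of the affected calls (builder() is assumed to exist on this fork's ConvolutionLayer, mirroring DenseLayer.builder() above):

    // Each array is now stored exactly as passed, without ValidationUtils checks.
    ConvolutionLayer conv = ConvolutionLayer.builder()
        .kernelSize(3, 3)
        .stride(1, 1)
        .padding(0, 0)
        .build();
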
@@ -20,14 +20,19 @@
 
 package org.deeplearning4j.nn.conf.layers;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Map;
+import java.util.stream.IntStream;
+
 import lombok.*;
 import lombok.experimental.SuperBuilder;
 import lombok.extern.jackson.Jacksonized;
+import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.ParamInitializer;
 import org.deeplearning4j.nn.conf.CNN2DFormat;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer;
@@ -84,6 +89,8 @@ public class Deconvolution2D extends ConvolutionLayer {
       boolean initializeParams,
       DataType networkDataType) {
     setNetConfiguration(conf);
+
+
     LayerValidation.assertNInNOutSet("Deconvolution2D", getName(), layerIndex, getNIn(), getNOut());
     LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
     runInheritance();
@@ -127,11 +134,25 @@ public class Deconvolution2D extends ConvolutionLayer {
         getName(),
         Deconvolution2DLayer.class);
   }
+@Slf4j
   private static final class Deconvolution2DBuilderImpl
       extends Deconvolution2DBuilder<Deconvolution2D, Deconvolution2DBuilderImpl> {
     public Deconvolution2D build() {
       Deconvolution2D l = new Deconvolution2D(this);
+      if( l.getConvolutionMode() == ConvolutionMode.Same
+              && IntStream.of(l.getPadding()).sum() != 0) {
+        log.warn("Invalid input for layer '{}'. "
+                + "You cannot have a padding of {} when Convolution Mode is set to 'Same'."
+                + " Padding will be ignored."
+        , l.getName(), l.getPadding());
+      }
+      /* strides * (input_size-1) + kernel_size - 2*padding */
+      //TODO: This is wrong, also depends on convolutionMode, etc ...
+      /*l.nOut = l.getStride()[0] * (l.getNIn()-1)
+              + IntStream.of(l.getKernelSize()).reduce(1, (a,b) -> a*b)
+              - 2L * IntStream.of(l.getPadding()).sum();
+              */
+      //l.nOut =264;
       l.initializeConstraints();
       return l;
     }
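
The commented-out formula is the usual single-axis output size of a transposed convolution with no output padding; as the commit's own TODO notes, it ignores convolution mode and is therefore left disabled. Evaluated once, as a sketch:

    // out = stride * (in - 1) + kernel - 2 * padding
    long stride = 2, in = 16, kernel = 4, padding = 1;
    long out = stride * (in - 1) + kernel - 2 * padding;  // 2 * 15 + 4 - 2 = 32
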
@@ -62,6 +62,7 @@ public abstract class LayerConfiguration
     implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration
 
   @Getter @Setter protected String name;
+  @Getter @Setter private int index;
   @Getter @Setter protected List<LayerConstraint> allParamConstraints;
   @Getter @Setter protected List<LayerConstraint> weightConstraints;
   @Getter @Setter protected List<LayerConstraint> biasConstraints;
@@ -72,6 +73,7 @@ public abstract class LayerConfiguration
   /** The type of the layer, basically defines the base class and its properties */
   @Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
 
+
   /**
    * Number of parameters this layer has a result of its configuration
    * @return number or parameters
@@ -80,7 +82,6 @@ public abstract class LayerConfiguration
     return initializer().numParams(this);
   }
 
-
   /**
    * A reference to the neural net configuration. This field is excluded from json serialization as
    * well as from equals check to avoid circular referenced.
@@ -37,122 +37,166 @@ import org.nd4j.linalg.api.shape.Shape;
 @Data
 @EqualsAndHashCode(exclude = {"shape"})
 public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
-    private long inputHeight;
-    private long inputWidth;
-    private long numChannels;
-
-    @Getter(AccessLevel.NONE)
-    @Setter(AccessLevel.NONE)
-    private long[] shape;
-
-    /**
-     * Reshape to a channels x rows x columns tensor
-     *
-     * @param inputHeight the columns
-     * @param inputWidth  the rows
-     * @param numChannels the channels
-     */
-    @JsonCreator
-    public FeedForwardToCnnPreProcessor(@JsonProperty("inputHeight") long inputHeight,
-                    @JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) {
-        this.inputHeight = inputHeight;
-        this.inputWidth = inputWidth;
-        this.numChannels = numChannels;
-    }
-
-    public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) {
-        this.inputHeight = inputHeight;
-        this.inputWidth = inputWidth;
-        this.numChannels = 1;
-    }
-
-    @Override
-    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
-        this.shape = input.shape();
-        if (input.rank() == 4)
-            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
-
-        if (input.columns() != inputWidth * inputHeight * numChannels)
-            throw new IllegalArgumentException("Invalid input: expect output columns must be equal to rows "
-                    + inputHeight + " x columns " + inputWidth + " x channels " + numChannels
-                    + " but was instead " + Arrays.toString(input.shape()));
-
-        if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
-            input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
-
-        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS,
-                input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth));
-    }
-
-    @Override
-    // return 4 dimensions
-    public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
-        if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilons))
-            epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c');
-
-        if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) {
-            if (epsilons.rank() == 2)
-                return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons); //should never happen
-
-            return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
-        }
-
-        return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape));
-    }
-
-    @Override
-    public FeedForwardToCnnPreProcessor clone() {
-        try {
-            FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
-            if (clone.shape != null)
-                clone.shape = clone.shape.clone();
-            return clone;
-        } catch (CloneNotSupportedException e) {
-            throw new RuntimeException(e);
-        }
-    }
-
-    @Override
-    public InputType getOutputType(InputType inputType) {
-        switch (inputType.getType()) {
-            case FF:
-                InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
-                val expSize = inputHeight * inputWidth * numChannels;
-                if (c.getSize() != expSize) {
-                    throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize
-                                    + " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got "
-                                    + inputType);
-                }
-                return InputType.convolutional(inputHeight, inputWidth, numChannels);
-            case CNN:
-                InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;
-
-                if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
-                    throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getChannels()
-                                    + "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels
-                                    + "," + inputHeight + "," + inputWidth + ")");
-                }
-                return c2;
-            case CNNFlat:
-                InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
-                if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
-                    throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c3.getDepth()
-                                    + "," + c3.getWidth() + "," + c3.getHeight() + ") but expected (" + numChannels
-                                    + "," + inputHeight + "," + inputWidth + ")");
-                }
-                return c3.getUnflattenedType();
-            default:
-                throw new IllegalStateException("Invalid input type: got " + inputType);
-        }
-    }
-
-    @Override
-    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
-                    int minibatchSize) {
-        //Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example)
-        return new Pair<>(maskArray, currentMaskState);
-    }
-
+  private long inputHeight;
+  private long inputWidth;
+  private long numChannels;
+
+  @Getter(AccessLevel.NONE)
+  @Setter(AccessLevel.NONE)
+  private long[] shape;
+
+  /**
+   * Reshape to a channels x rows x columns tensor
+   *
+   * @param inputHeight the columns
+   * @param inputWidth the rows
+   * @param numChannels the channels
+   */
+  @JsonCreator
+  public FeedForwardToCnnPreProcessor(
+      @JsonProperty("inputHeight") long inputHeight,
+      @JsonProperty("inputWidth") long inputWidth,
+      @JsonProperty("numChannels") long numChannels) {
+    this.inputHeight = inputHeight;
+    this.inputWidth = inputWidth;
+    this.numChannels = numChannels;
+  }
+  /**
+   * Reshape to a channels x rows x columns tensor
+   *
+   * @param inputHeight the columns
+   * @param inputWidth the rows
+   */
+  public FeedForwardToCnnPreProcessor(long inputWidth, long inputHeight) {
+    this.inputHeight = inputHeight;
+    this.inputWidth = inputWidth;
+    this.numChannels = 1;
+  }
+
+  @Override
+  public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+    this.shape = input.shape();
+    if (input.rank() == 4) return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
+
+    if (input.columns() != inputWidth * inputHeight * numChannels)
+      throw new IllegalArgumentException(
+          "Invalid input: expect output columns must be equal to rows "
+              + inputHeight
+              + " x columns "
+              + inputWidth
+              + " x channels "
+              + numChannels
+              + " but was instead "
+              + Arrays.toString(input.shape()));
+
+    if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
+      input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
+
+    return workspaceMgr.leverageTo(
+        ArrayType.ACTIVATIONS,
+        input.reshape('c', input.size(0), numChannels, inputHeight, inputWidth));
+  }
+
+  @Override
+  // return 4 dimensions
+  public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+    if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilons))
+      epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c');
+
+    if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) {
+      if (epsilons.rank() == 2)
+        return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons); // should never happen
+
+      return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
+    }
+
+    return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilons.reshape('c', shape));
+  }
+
+  @Override
+  public FeedForwardToCnnPreProcessor clone() {
+    try {
+      FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
+      if (clone.shape != null) clone.shape = clone.shape.clone();
+      return clone;
+    } catch (CloneNotSupportedException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Override
+  public InputType getOutputType(InputType inputType) {
+
+    switch (inputType.getType()) {
+      case FF:
+        InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
+        val expSize = inputHeight * inputWidth * numChannels;
+        if (c.getSize() != expSize) {
+          throw new IllegalStateException(
+              "Invalid input: expected FeedForward input of size "
+                  + expSize
+                  + " = (d="
+                  + numChannels
+                  + " * w="
+                  + inputWidth
+                  + " * h="
+                  + inputHeight
+                  + "), got "
+                  + inputType);
+        }
+        return InputType.convolutional(inputHeight, inputWidth, numChannels);
+      case CNN:
+        InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;
+
+        if (c2.getChannels() != numChannels
+            || c2.getHeight() != inputHeight
+            || c2.getWidth() != inputWidth) {
+          throw new IllegalStateException(
+              "Invalid input: Got CNN input type with (d,w,h)=("
+                  + c2.getChannels()
+                  + ","
+                  + c2.getWidth()
+                  + ","
+                  + c2.getHeight()
+                  + ") but expected ("
+                  + numChannels
+                  + ","
+                  + inputHeight
+                  + ","
+                  + inputWidth
+                  + ")");
+        }
+        return c2;
+      case CNNFlat:
+        InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
+        if (c3.getDepth() != numChannels
+            || c3.getHeight() != inputHeight
+            || c3.getWidth() != inputWidth) {
+          throw new IllegalStateException(
+              "Invalid input: Got CNN input type with (d,w,h)=("
+                  + c3.getDepth()
+                  + ","
+                  + c3.getWidth()
+                  + ","
+                  + c3.getHeight()
+                  + ") but expected ("
+                  + numChannels
+                  + ","
+                  + inputHeight
+                  + ","
+                  + inputWidth
+                  + ")");
+        }
+        return c3.getUnflattenedType();
+      default:
+        throw new IllegalStateException("Invalid input type: got " + inputType);
+    }
+  }
+
+  @Override
+  public Pair<INDArray, MaskState> feedForwardMaskArray(
+      INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
+    // Pass-through, unmodified (assuming here that it's a 1d mask array - one value per example)
+    return new Pair<>(maskArray, currentMaskState);
+  }
 }
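
Behaviorally the class is unchanged by the reformat: preProcess turns a flattened [miniBatch, height*width*channels] activation into a [miniBatch, channels, height, width] tensor, and backprop reverses it. A sketch of that contract (Nd4j.rand and LayerWorkspaceMgr.noWorkspaces() are standard ND4J/DL4J helpers; their exact availability on this fork is assumed):

    INDArray ff = Nd4j.rand(10, 28 * 28);                        // 10 flattened examples
    FeedForwardToCnnPreProcessor pre = new FeedForwardToCnnPreProcessor(28, 28, 1);
    LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces();
    INDArray cnn = pre.preProcess(ff, 10, mgr);                  // shape [10, 1, 28, 28]
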
@@ -369,7 +369,7 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl
 
   protected String layerId() {
     String name = this.layerConfiguration.getName();
-    return "(layer name: "
+    return "(network: " + getNetConfiguration().getName() + " layer name: "
         + (name == null ? "\"\"" : name)
         + ", layer index: "
         + index
@@ -101,8 +101,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
 
         int[] args = new int[] {
                 (int)kH, (int)kW, strides[0], strides[1],
-                pad[0], pad[1], dilation[0], dilation[1], sameMode,
-                nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
+                pad[0], pad[1], dilation[0], dilation[1], sameMode //,
+                //nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
         };
 
         INDArray delta;
@@ -224,8 +224,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
 
         int[] args = new int[] {
                 kH, kW, strides[0], strides[1],
-                pad[0], pad[1], dilation[0], dilation[1], sameMode,
-                nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
+                pad[0], pad[1], dilation[0], dilation[1], sameMode //,
+                //nchw ? 0 : 1 //0 = NCHW; 1 = NHWC
         };
 
         //DL4J Deconv weights: [inputDepth, outputDepth, kH, kW]
@@ -238,6 +238,20 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
         } else {
             opInputs = new INDArray[]{input, weights};
         }
+        /**
+         * 2D deconvolution implementation
+         *
+         * IntArgs:
+         * 0: kernel height
+         * 1: kernel width
+         * 2: stride height
+         * 3: stride width
+         * 4: padding height
+         * 5: padding width
+         * 6: dilation height
+         * 7: dilation width
+         * 8: same mode: 0 false, 1 true
+         */
         CustomOp op = DynamicCustomOp.builder("deconv2d")
                 .addInputs(opInputs)
                 .addIntegerArguments(args)
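
With the NCHW/NHWC flag dropped, deconv2d now receives exactly the nine IntArgs documented above, in that order. For example, a 2x2 kernel with stride 2, no padding, unit dilation, and 'same' mode off would be encoded as (a sketch):

    int[] args = new int[] {
            2, 2,   // 0-1: kernel height, width
            2, 2,   // 2-3: stride height, width
            0, 0,   // 4-5: padding height, width
            1, 1,   // 6-7: dilation height, width
            0       // 8:   same mode (0 = false)
    };
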
@@ -773,7 +773,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork
         LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i);
         layers[i] =
             lc.instantiate(
-                lc.getNetConfiguration(),
+                this.getNetConfiguration(),
                 trainingListeners,
                 i,
                 paramsView,
@@ -101,8 +101,10 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer
 
             params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
             conf.getNetConfiguration().addNetWideVariable(GAMMA);
+            conf.addVariable(GAMMA);
             params.put(BETA, createBeta(conf, betaView, initializeParams));
             conf.getNetConfiguration().addNetWideVariable(BETA);
+            conf.addVariable(BETA);
 
             meanOffset = 2 * nOut;
         }
@@ -125,12 +127,15 @@ public class BatchNormalizationParamInitializer extends AbstractParamInitializer
 
         params.put(GLOBAL_MEAN, globalMeanView);
         conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN);
+        conf.addVariable(GLOBAL_MEAN);
         if(layer.isUseLogStd()){
             params.put(GLOBAL_LOG_STD, globalVarView);
             conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD);
+            conf.addVariable(GLOBAL_LOG_STD);
         } else {
             params.put(GLOBAL_VAR, globalVarView);
             conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR);
+            conf.addVariable(GLOBAL_VAR);
         }
 
         return params;
@@ -114,11 +114,13 @@ public class ConvolutionParamInitializer extends AbstractParamInitializer {
             params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
             conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
             conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
-            conf.getNetConfiguration().addNetWideVariable(BIAS_KEY);
+            conf.addVariable(WEIGHT_KEY);
+            conf.addVariable(BIAS_KEY);
         } else {
             INDArray weightView = paramsView;
             params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
             conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY);
+            conf.addVariable(WEIGHT_KEY);
         }
 
         return params;
@@ -34,7 +34,7 @@ systemProp.org.gradle.internal.publish.checksums.insecure=true
 
 #for whatever reason we had to add MaxMetaspace and file encoding = utf8, gradle crashed otherwise
 org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=768m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 -XX:ErrorFile=/var/log/java/hs_err_pid%p.log
+#-DsocksProxyHost=sshtunnel -DsocksProxyPort=8888 -Djava.net.preferIPv4Stack=true
 
 # When configured, Gradle will run in incubating parallel mode.
 # This option should only be used with decoupled projects. More details, visit
BIN  gradle/wrapper/gradle-wrapper.jar (vendored): Binary file not shown.
170  gradlew (vendored)

@@ -69,18 +69,18 @@ app_path=$0
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# Need this for daisy-chained symlinks.
 | 
					# Need this for daisy-chained symlinks.
 | 
				
			||||||
while
 | 
					while
 | 
				
			||||||
  APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
 | 
					    APP_HOME=${app_path%"${app_path##*/}"}  # leaves a trailing /; empty if no leading path
 | 
				
			||||||
  [ -h "$app_path" ]
 | 
					    [ -h "$app_path" ]
 | 
				
			||||||
do
 | 
					do
 | 
				
			||||||
  ls=$(ls -ld "$app_path")
 | 
					    ls=$( ls -ld "$app_path" )
 | 
				
			||||||
  link=${ls#*' -> '}
 | 
					    link=${ls#*' -> '}
 | 
				
			||||||
  case $link in         #(
 | 
					    case $link in             #(
 | 
				
			||||||
  /*) app_path=$link ;; #(
 | 
					      /*)   app_path=$link ;; #(
 | 
				
			||||||
  *) app_path=$APP_HOME$link ;;
 | 
					      *)    app_path=$APP_HOME$link ;;
 | 
				
			||||||
  esac
 | 
					    esac
 | 
				
			||||||
done
 | 
					done
 | 
				
			||||||
 | 
					
 | 
				
			||||||
APP_HOME=$(cd "${APP_HOME:-./}" && pwd -P) || exit
 | 
					APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
APP_NAME="Gradle"
 | 
					APP_NAME="Gradle"
 | 
				
			||||||
APP_BASE_NAME=${0##*/}
 | 
					APP_BASE_NAME=${0##*/}
 | 
				
			||||||
@ -91,15 +91,15 @@ DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
 | 
				
			|||||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
 | 
					# Use the maximum available, or set MAX_FD != -1 to use that value.
 | 
				
			||||||
MAX_FD=maximum
 | 
					MAX_FD=maximum
 | 
				
			||||||
 | 
					
 | 
				
			||||||
warn() {
 | 
					warn () {
 | 
				
			||||||
  echo "$*"
 | 
					    echo "$*"
 | 
				
			||||||
} >&2
 | 
					} >&2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
die() {
 | 
					die () {
 | 
				
			||||||
  echo
 | 
					    echo
 | 
				
			||||||
  echo "$*"
 | 
					    echo "$*"
 | 
				
			||||||
  echo
 | 
					    echo
 | 
				
			||||||
  exit 1
 | 
					    exit 1
 | 
				
			||||||
} >&2
 | 
					} >&2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# OS specific support (must be 'true' or 'false').
 | 
					# OS specific support (must be 'true' or 'false').
 | 
				
			||||||
@ -107,52 +107,51 @@ cygwin=false
 | 
				
			|||||||
msys=false
 | 
					msys=false
 | 
				
			||||||
darwin=false
 | 
					darwin=false
 | 
				
			||||||
nonstop=false
 | 
					nonstop=false
 | 
				
			||||||
case "$(uname)" in           #(
 | 
					case "$( uname )" in                #(
 | 
				
			||||||
CYGWIN*) cygwin=true ;;      #(
 | 
					  CYGWIN* )         cygwin=true  ;; #(
 | 
				
			||||||
Darwin*) darwin=true ;;      #(
 | 
					  Darwin* )         darwin=true  ;; #(
 | 
				
			||||||
MSYS* | MINGW*) msys=true ;; #(
 | 
					  MSYS* | MINGW* )  msys=true    ;; #(
 | 
				
			||||||
NONSTOP*) nonstop=true ;;
 | 
					  NONSTOP* )        nonstop=true ;;
 | 
				
			||||||
esac
 | 
					esac
 | 
				
			||||||
 | 
					
 | 
				
			||||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
 | 
					CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Determine the Java command to use to start the JVM.
 | 
					# Determine the Java command to use to start the JVM.
 | 
				
			||||||
if [ -n "$JAVA_HOME" ]; then
 | 
					if [ -n "$JAVA_HOME" ] ; then
 | 
				
			||||||
  if [ -x "$JAVA_HOME/jre/sh/java" ]; then
 | 
					    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
 | 
				
			||||||
    # IBM's JDK on AIX uses strange locations for the executables
 | 
					        # IBM's JDK on AIX uses strange locations for the executables
 | 
				
			||||||
    JAVACMD=$JAVA_HOME/jre/sh/java
 | 
					        JAVACMD=$JAVA_HOME/jre/sh/java
 | 
				
			||||||
  else
 | 
					    else
 | 
				
			||||||
    JAVACMD=$JAVA_HOME/bin/java
 | 
					        JAVACMD=$JAVA_HOME/bin/java
 | 
				
			||||||
  fi
 | 
					    fi
 | 
				
			||||||
  if [ ! -x "$JAVACMD" ]; then
 | 
					    if [ ! -x "$JAVACMD" ] ; then
 | 
				
			||||||
    die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
 | 
					        die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Please set the JAVA_HOME variable in your environment to match the
 | 
					Please set the JAVA_HOME variable in your environment to match the
 | 
				
			||||||
location of your Java installation."
 | 
					location of your Java installation."
 | 
				
			||||||
  fi
 | 
					    fi
 | 
				
			||||||
else
 | 
					else
 | 
				
			||||||
  JAVACMD=java
 | 
					    JAVACMD=java
 | 
				
			||||||
  which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
 | 
					    which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Please set the JAVA_HOME variable in your environment to match the
 | 
					Please set the JAVA_HOME variable in your environment to match the
 | 
				
			||||||
location of your Java installation."
 | 
					location of your Java installation."
 | 
				
			||||||
fi
 | 
					fi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Increase the maximum file descriptors if we can.
 | 
					# Increase the maximum file descriptors if we can.
 | 
				
			||||||
if ! "$cygwin" && ! "$darwin" && ! "$nonstop"; then
 | 
					if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
 | 
				
			||||||
  case $MAX_FD in #(
 | 
					    case $MAX_FD in #(
 | 
				
			||||||
  max*)
 | 
					      max*)
 | 
				
			||||||
    MAX_FD=$(ulimit -H -n) ||
 | 
					        MAX_FD=$( ulimit -H -n ) ||
 | 
				
			||||||
      warn "Could not query maximum file descriptor limit"
 | 
					            warn "Could not query maximum file descriptor limit"
 | 
				
			||||||
    ;;
 | 
					    esac
 | 
				
			||||||
  esac
 | 
					    case $MAX_FD in  #(
 | 
				
			||||||
  case $MAX_FD in #(
 | 
					      '' | soft) :;; #(
 | 
				
			||||||
  '' | soft) : ;; #(
 | 
					      *)
 | 
				
			||||||
  *)
 | 
					        ulimit -n "$MAX_FD" ||
 | 
				
			||||||
    ulimit -n "$MAX_FD" ||
 | 
					            warn "Could not set maximum file descriptor limit to $MAX_FD"
 | 
				
			||||||
      warn "Could not set maximum file descriptor limit to $MAX_FD"
 | 
					    esac
 | 
				
			||||||
    ;;
 | 
					 | 
				
			||||||
  esac
 | 
					 | 
				
			||||||
fi
 | 
					fi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Collect all arguments for the java command, stacking in reverse order:
 | 
					# Collect all arguments for the java command, stacking in reverse order:
 | 
				
			||||||
@ -164,36 +163,34 @@ fi
 | 
				
			|||||||
 #   * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
 
 # For Cygwin or MSYS, switch paths to Windows format before running java
-if "$cygwin" || "$msys"; then
-  APP_HOME=$(cygpath --path --mixed "$APP_HOME")
-  CLASSPATH=$(cygpath --path --mixed "$CLASSPATH")
+if "$cygwin" || "$msys" ; then
+    APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
+    CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
 
-  JAVACMD=$(cygpath --unix "$JAVACMD")
+    JAVACMD=$( cygpath --unix "$JAVACMD" )
 
-  # Now convert the arguments - kludge to limit ourselves to /bin/sh
-  for arg; do
-    if
-      case $arg in #(
-      -*) false ;; # don't mess with options #(
-      /?*)
-        t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
-        [ -e "$t" ]
-        ;; #(
-      *) false ;;
-      esac
-    then
-      arg=$(cygpath --path --ignore --mixed "$arg")
-    fi
-    # Roll the args list around exactly as many times as the number of
-    # args, so each arg winds up back in the position where it started, but
-    # possibly modified.
-    #
-    # NB: a `for` loop captures its iteration list before it begins, so
-    # changing the positional parameters here affects neither the number of
-    # iterations, nor the values presented in `arg`.
-    shift              # remove old arg
-    set -- "$@" "$arg" # push replacement arg
-  done
+    # Now convert the arguments - kludge to limit ourselves to /bin/sh
+    for arg do
+        if
+            case $arg in                                #(
+              -*)   false ;;                            # don't mess with options #(
+              /?*)  t=${arg#/} t=/${t%%/*}              # looks like a POSIX filepath
+                    [ -e "$t" ] ;;                      #(
+              *)    false ;;
+            esac
+        then
+            arg=$( cygpath --path --ignore --mixed "$arg" )
+        fi
+        # Roll the args list around exactly as many times as the number of
+        # args, so each arg winds up back in the position where it started, but
+        # possibly modified.
+        #
+        # NB: a `for` loop captures its iteration list before it begins, so
+        # changing the positional parameters here affects neither the number of
+        # iterations, nor the values presented in `arg`.
+        shift                   # remove old arg
+        set -- "$@" "$arg"      # push replacement arg
+    done
 fi
 
 # Collect all arguments for the java command;
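The shift / set -- "$@" "$arg" pair in the loop above is the standard POSIX idiom for rewriting positional parameters in place: each pass drops the original argument from the front and pushes its (possibly converted) replacement onto the back, and because `for` captures its word list before the first iteration, the loop still visits every original argument exactly once. A self-contained sketch, with a hypothetical upper-casing transform standing in for the cygpath conversion:

    #!/bin/sh
    # Rewrite every positional parameter without arrays (POSIX sh only).
    set -- one two three
    for arg do
        arg=$( printf '%s' "$arg" | tr '[:lower:]' '[:upper:]' )  # hypothetical transform
        shift                 # remove the original arg from the front
        set -- "$@" "$arg"    # push the replacement onto the back
    done
    echo "$@"    # -> ONE TWO THREE, each arg back in its original slot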
@@ -203,15 +200,10 @@ fi
 #   * put everything else in single quotes, so that it's not re-expanded.
 
 set -- \
-  "-Dorg.gradle.appname=$APP_BASE_NAME" \
-  -classpath "$CLASSPATH" \
-  org.gradle.wrapper.GradleWrapperMain \
-  "$@"
-
-# Stop when "xargs" is not available.
-if ! command -v xargs >/dev/null 2>&1; then
-  die "xargs is not available"
-fi
+        "-Dorg.gradle.appname=$APP_BASE_NAME" \
+        -classpath "$CLASSPATH" \
+        org.gradle.wrapper.GradleWrapperMain \
+        "$@"
 
 # Use "xargs" to parse quoted args.
 #
@@ -233,10 +225,10 @@ fi
 #
 
 eval "set -- $(
-  printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
-    xargs -n1 |
-    sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
-    tr '\n' ' '
-)" '"$@"'
+        printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
+        xargs -n1 |
+        sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
+        tr '\n' ' '
+    )" '"$@"'
 
 exec "$JAVACMD" "$@"
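The eval "set -- $(...)" pipeline is how the script turns the flat DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS strings into separate argv words without bash arrays: xargs -n1 splits them while honouring embedded quotes, the sed pass backslash-escapes any character the shell might otherwise re-expand, and eval rebuilds the parameter list with the options placed in front of the untouched user arguments. A minimal sketch with made-up option values:

    #!/bin/sh
    # Split quoted option strings into argv words, then prepend them to "$@".
    DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'   # made-up values for illustration
    JAVA_OPTS='-Dfoo=bar'

    set -- Main "user arg"                   # pretend these came from the caller
    eval "set -- $(
            printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS" |
            xargs -n1 |
            sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
            tr '\n' ' '
        )" '"$@"'

    for a do printf '<%s>\n' "$a"; done
    # -> <-Xmx64m> <-Xms64m> <-Dfoo=bar> <Main> <user arg>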
gradlew.bat (vendored, 14 changes)

@@ -14,7 +14,7 @@
 @rem limitations under the License.
 @rem
 
-@if "%DEBUG%"=="" @echo off
+@if "%DEBUG%" == "" @echo off
 @rem ##########################################################################
 @rem
 @rem  Gradle startup script for Windows
@@ -25,7 +25,7 @@
 if "%OS%"=="Windows_NT" setlocal
 
 set DIRNAME=%~dp0
-if "%DIRNAME%"=="" set DIRNAME=.
+if "%DIRNAME%" == "" set DIRNAME=.
 set APP_BASE_NAME=%~n0
 set APP_HOME=%DIRNAME%
 
@@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
 
 set JAVA_EXE=java.exe
 %JAVA_EXE% -version >NUL 2>&1
-if %ERRORLEVEL% equ 0 goto execute
+if "%ERRORLEVEL%" == "0" goto execute
 
 echo.
 echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
@@ -75,15 +75,13 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
 
 :end
 @rem End local scope for the variables with windows NT shell
-if %ERRORLEVEL% equ 0 goto mainEnd
+if "%ERRORLEVEL%"=="0" goto mainEnd
 
 :fail
 rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
 rem the _cmd.exe /c_ return code!
-set EXIT_CODE=%ERRORLEVEL%
-if %EXIT_CODE% equ 0 set EXIT_CODE=1
-if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
-exit /b %EXIT_CODE%
+if  not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
 
 :mainEnd
 if "%OS%"=="Windows_NT" endlocal