Fixed missing imports
Signed-off-by: brian <brian@brutex.de>
This commit is contained in:
		
							parent
							
								
									d0342fc939
								
							
						
					
					
						commit
						e27fb8422f
					
				@ -21,14 +21,14 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
package net.brutex.gan;
 | 
					package net.brutex.gan;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import java.awt.BorderLayout;
 | 
					import java.awt.*;
 | 
				
			||||||
import java.awt.Dimension;
 | 
					 | 
				
			||||||
import java.awt.GridLayout;
 | 
					 | 
				
			||||||
import java.awt.Image;
 | 
					 | 
				
			||||||
import java.awt.image.BufferedImage;
 | 
					import java.awt.image.BufferedImage;
 | 
				
			||||||
import java.io.File;
 | 
					import java.io.File;
 | 
				
			||||||
 | 
					import java.io.IOException;
 | 
				
			||||||
import java.util.Arrays;
 | 
					import java.util.Arrays;
 | 
				
			||||||
import java.util.Random;
 | 
					import java.util.Random;
 | 
				
			||||||
 | 
					import java.util.UUID;
 | 
				
			||||||
 | 
					import javax.imageio.ImageIO;
 | 
				
			||||||
import javax.swing.ImageIcon;
 | 
					import javax.swing.ImageIcon;
 | 
				
			||||||
import javax.swing.JFrame;
 | 
					import javax.swing.JFrame;
 | 
				
			||||||
import javax.swing.JLabel;
 | 
					import javax.swing.JLabel;
 | 
				
			||||||
@ -82,9 +82,9 @@ public class App {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  private static final int X_DIM = 20 ;
 | 
					  private static final int X_DIM = 20 ;
 | 
				
			||||||
  private static final int Y_DIM = 20;
 | 
					  private static final int Y_DIM = 20;
 | 
				
			||||||
  private static final int CHANNELS = 3;
 | 
					  private static final int CHANNELS = 1;
 | 
				
			||||||
  private static final int batchSize = 50;
 | 
					  private static final int batchSize = 1;
 | 
				
			||||||
  private static final int INPUT = 128;
 | 
					  private static final int INPUT = 10;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private static final int OUTPUT_PER_PANEL = 16;
 | 
					  private static final int OUTPUT_PER_PANEL = 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -96,6 +96,8 @@ public class App {
 | 
				
			|||||||
  private static JPanel panel;
 | 
					  private static JPanel panel;
 | 
				
			||||||
  private static JPanel panel2;
 | 
					  private static JPanel panel2;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  private static final String OUTPUT_DIR = "C:/temp/output/";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private static LayerConfiguration[] genLayers() {
 | 
					  private static LayerConfiguration[] genLayers() {
 | 
				
			||||||
    return new LayerConfiguration[] {
 | 
					    return new LayerConfiguration[] {
 | 
				
			||||||
        DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
 | 
					        DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
 | 
				
			||||||
@ -103,6 +105,7 @@ public class App {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
 | 
					        DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
 | 
				
			||||||
        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 | 
					        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 | 
				
			||||||
 | 
					            DropoutLayer.builder(1 - 0.5).build(),
 | 
				
			||||||
        DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
 | 
					        DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
 | 
				
			||||||
        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 | 
					        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -207,6 +210,12 @@ public class App {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  @Test
 | 
					  @Test
 | 
				
			||||||
  public void runTest() throws Exception {
 | 
					  public void runTest() throws Exception {
 | 
				
			||||||
 | 
					      if(! log.isDebugEnabled()) {
 | 
				
			||||||
 | 
					          log.info("Logging is not set to DEBUG");
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					      else {
 | 
				
			||||||
 | 
					          log.info("Logging is set to DEBUG");
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
    main();
 | 
					    main();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -240,25 +249,25 @@ public class App {
 | 
				
			|||||||
    MultiLayerNetwork gan = new MultiLayerNetwork(gan());
 | 
					    MultiLayerNetwork gan = new MultiLayerNetwork(gan());
 | 
				
			||||||
    gen.init(); log.debug("Generator network: {}", gen);
 | 
					    gen.init(); log.debug("Generator network: {}", gen);
 | 
				
			||||||
    dis.init(); log.debug("Discriminator network: {}", dis);
 | 
					    dis.init(); log.debug("Discriminator network: {}", dis);
 | 
				
			||||||
    gan.init(); log.debug("Complete GAN network: {}", gan);
 | 
					    gan.init(); log.info("Complete GAN network: {}", gan);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    copyParams(gen, dis, gan);
 | 
					    copyParams(gen, dis, gan);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    gen.addTrainingListeners(new PerformanceListener(15, true));
 | 
					    //gen.addTrainingListeners(new PerformanceListener(15, true, "GEN"));
 | 
				
			||||||
    //dis.addTrainingListeners(new PerformanceListener(10, true));
 | 
					    dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
 | 
				
			||||||
    //gan.addTrainingListeners(new PerformanceListener(10, true));
 | 
					    gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
 | 
				
			||||||
    //gan.addTrainingListeners(new ScoreToChartListener("gan"));
 | 
					    //gan.addTrainingListeners(new ScoreToChartListener("gan"));
 | 
				
			||||||
    //dis.setListeners(new ScoreToChartListener("dis"));
 | 
					    //dis.setListeners(new ScoreToChartListener("dis"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    System.out.println(gan.toString());
 | 
					    //System.out.println(gan.toString());
 | 
				
			||||||
    gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
 | 
					    //gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    //gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
 | 
					    //gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
 | 
				
			||||||
    //trainData.reset();
 | 
					    //trainData.reset();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int j = 0;
 | 
					    int j = 0;
 | 
				
			||||||
    for (int i = 0; i < 201; i++) { //epoch
 | 
					    for (int i = 0; i < 51; i++) { //epoch
 | 
				
			||||||
      while (trainData.hasNext()) {
 | 
					      while (trainData.hasNext()) {
 | 
				
			||||||
        j++;
 | 
					        j++;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -282,6 +291,9 @@ public class App {
 | 
				
			|||||||
//        int batchSize = (int) real.shape()[0];
 | 
					//        int batchSize = (int) real.shape()[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
 | 
					        INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
 | 
				
			||||||
 | 
					        //INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
 | 
					        INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
 | 
				
			||||||
        fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
 | 
					        fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -299,11 +311,11 @@ public class App {
 | 
				
			|||||||
        updateGan(gen, dis, gan);
 | 
					        updateGan(gen, dis, gan);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        //gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
 | 
					        //gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
 | 
				
			||||||
        gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)));
 | 
					        gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        //Visualize and reporting
 | 
					        //Visualize and reporting
 | 
				
			||||||
        if (j % 10 == 1) {
 | 
					        if (j % 10 == 1) {
 | 
				
			||||||
          System.out.println("Iteration " + j + " Visualizing...");
 | 
					          System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
 | 
				
			||||||
          INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
 | 
					          INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -330,11 +342,12 @@ public class App {
 | 
				
			|||||||
      }
 | 
					      }
 | 
				
			||||||
        // Copy the GANs generator to gen.
 | 
					        // Copy the GANs generator to gen.
 | 
				
			||||||
        updateGen(gen, gan);
 | 
					        updateGen(gen, gan);
 | 
				
			||||||
 | 
					        gen.save(new File("mnist-mlp-generator.dlj"));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    gen.save(new File("mnist-mlp-generator.dlj"));
 | 
					
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
 | 
					  private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
 | 
				
			||||||
@ -342,10 +355,10 @@ public class App {
 | 
				
			|||||||
    for (int i = 0; i < gan.getLayers().length; i++) {
 | 
					    for (int i = 0; i < gan.getLayers().length; i++) {
 | 
				
			||||||
      if (i < genLayerCount) {
 | 
					      if (i < genLayerCount) {
 | 
				
			||||||
        if(gan.getLayer(i).getParams() != null)
 | 
					        if(gan.getLayer(i).getParams() != null)
 | 
				
			||||||
         gen.getLayer(i).setParams(gan.getLayer(i).getParams());
 | 
					         gan.getLayer(i).setParams(gen.getLayer(i).getParams());
 | 
				
			||||||
      } else {
 | 
					      } else {
 | 
				
			||||||
        if(gan.getLayer(i).getParams() != null)
 | 
					        if(gan.getLayer(i).getParams() != null)
 | 
				
			||||||
        dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
 | 
					        gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams());
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
@ -411,14 +424,50 @@ public class App {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    ImageIcon orig = new ImageIcon(bi);
 | 
					    ImageIcon orig = new ImageIcon(bi);
 | 
				
			||||||
 | 
					 | 
				
			||||||
    Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
 | 
					    Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
 | 
				
			||||||
 | 
					 | 
				
			||||||
    ImageIcon scaled = new ImageIcon(imageScaled);
 | 
					    ImageIcon scaled = new ImageIcon(imageScaled);
 | 
				
			||||||
 | 
					    if(! isOrig)  saveImage(imageScaled, batchElement, isOrig);
 | 
				
			||||||
    return new JLabel(scaled);
 | 
					    return new JLabel(scaled);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  private static void saveImage(Image image, int batchElement, boolean isOrig) {
 | 
				
			||||||
 | 
					    String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try {
 | 
				
			||||||
 | 
					      // Save the images to disk
 | 
				
			||||||
 | 
					      saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					     log.debug("Images saved successfully.");
 | 
				
			||||||
 | 
					    } catch (IOException e) {
 | 
				
			||||||
 | 
					      log.error("Error saving the images: {}", e.getMessage());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					    private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
 | 
				
			||||||
 | 
					        File directory = new File(outputDirectory);
 | 
				
			||||||
 | 
					        if (!directory.exists()) {
 | 
				
			||||||
 | 
					            directory.mkdir();
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        File outputFile = new File(directory, fileName);
 | 
				
			||||||
 | 
					        ImageIO.write(imageToBufferedImage(image), "png", outputFile);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public static BufferedImage imageToBufferedImage(Image image) {
 | 
				
			||||||
 | 
					        if (image instanceof BufferedImage) {
 | 
				
			||||||
 | 
					            return (BufferedImage) image;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Create a buffered image with the same dimensions and transparency as the original image
 | 
				
			||||||
 | 
					        BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Draw the original image onto the buffered image
 | 
				
			||||||
 | 
					        Graphics2D g2d = bufferedImage.createGraphics();
 | 
				
			||||||
 | 
					        g2d.drawImage(image, 0, 0, null);
 | 
				
			||||||
 | 
					        g2d.dispose();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return bufferedImage;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -25,7 +25,7 @@
 | 
				
			|||||||
# Default logging detail level for all instances of SimpleLogger.
 | 
					# Default logging detail level for all instances of SimpleLogger.
 | 
				
			||||||
# Must be one of ("trace", "debug", "info", "warn", or "error").
 | 
					# Must be one of ("trace", "debug", "info", "warn", or "error").
 | 
				
			||||||
# If not specified, defaults to "info".
 | 
					# If not specified, defaults to "info".
 | 
				
			||||||
org.slf4j.simpleLogger.defaultLogLevel=trace
 | 
					org.slf4j.simpleLogger.defaultLogLevel=debug
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Logging detail level for a SimpleLogger instance named "xxxxx".
 | 
					# Logging detail level for a SimpleLogger instance named "xxxxx".
 | 
				
			||||||
# Must be one of ("trace", "debug", "info", "warn", or "error").
 | 
					# Must be one of ("trace", "debug", "info", "warn", or "error").
 | 
				
			||||||
@ -42,8 +42,8 @@ org.slf4j.simpleLogger.defaultLogLevel=trace
 | 
				
			|||||||
# If the format is not specified or is invalid, the default format is used.
 | 
					# If the format is not specified or is invalid, the default format is used.
 | 
				
			||||||
# The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
 | 
					# The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
 | 
				
			||||||
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
 | 
					#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
 | 
				
			||||||
org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
 | 
					#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Set to true if you want to output the current thread name.
 | 
					# Set to true if you want to output the current thread name.
 | 
				
			||||||
# Defaults to true.
 | 
					# Defaults to true.
 | 
				
			||||||
org.slf4j.simpleLogger.showThreadName=true
 | 
					#org.slf4j.simpleLogger.showThreadName=true
 | 
				
			||||||
@ -23,6 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations;
 | 
				
			|||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
 | 
					import org.deeplearning4j.nn.conf.InputPreProcessor;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 | 
					import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 | 
				
			||||||
import org.deeplearning4j.BaseDL4JTest;
 | 
					import org.deeplearning4j.BaseDL4JTest;
 | 
				
			||||||
 | 
					import org.deeplearning4j.nn.conf.serde.CavisMapper;
 | 
				
			||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
 | 
					import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
 | 
				
			||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
 | 
					import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
 | 
				
			||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
 | 
					import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
 | 
				
			||||||
 | 
				
			|||||||
@ -55,6 +55,8 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali
 | 
				
			|||||||
    private boolean reportEtl = true;
 | 
					    private boolean reportEtl = true;
 | 
				
			||||||
    private boolean reportTime = true;
 | 
					    private boolean reportTime = true;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private final String name;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public PerformanceListener(int frequency) {
 | 
					    public PerformanceListener(int frequency) {
 | 
				
			||||||
@ -66,14 +68,22 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public PerformanceListener(int frequency, boolean reportScore, boolean reportGC) {
 | 
					    public PerformanceListener(int frequency, boolean reportScore, boolean reportGC) {
 | 
				
			||||||
 | 
					       this(frequency, reportScore, reportGC, "no-name");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    public PerformanceListener(int frequency, boolean reportScore, boolean reportGC, String name) {
 | 
				
			||||||
        Preconditions.checkArgument(frequency > 0, "Invalid frequency, must be > 0: Got " + frequency);
 | 
					        Preconditions.checkArgument(frequency > 0, "Invalid frequency, must be > 0: Got " + frequency);
 | 
				
			||||||
        this.frequency = frequency;
 | 
					        this.frequency = frequency;
 | 
				
			||||||
        this.reportScore = reportScore;
 | 
					        this.reportScore = reportScore;
 | 
				
			||||||
        this.reportGC = reportGC;
 | 
					        this.reportGC = reportGC;
 | 
				
			||||||
 | 
					        this.name = name;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        lastTime.set(System.currentTimeMillis());
 | 
					        lastTime.set(System.currentTimeMillis());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public PerformanceListener(int frequency, boolean reportScore, String name) {
 | 
				
			||||||
 | 
					        this(frequency, reportScore, false, name);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public void iterationDone(IModel model, int iteration, int epoch) {
 | 
					    public void iterationDone(IModel model, int iteration, int epoch) {
 | 
				
			||||||
        // we update lastTime on every iteration
 | 
					        // we update lastTime on every iteration
 | 
				
			||||||
@ -116,6 +126,7 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            StringBuilder builder = new StringBuilder();
 | 
					            StringBuilder builder = new StringBuilder();
 | 
				
			||||||
 | 
					            builder.append("Name: '"+name+"'");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if (Nd4j.getAffinityManager().getNumberOfDevices() > 1)
 | 
					            if (Nd4j.getAffinityManager().getNumberOfDevices() > 1)
 | 
				
			||||||
                builder.append("Device: [").append(Nd4j.getAffinityManager().getDeviceForCurrentThread()).append("]; ");
 | 
					                builder.append("Device: [").append(Nd4j.getAffinityManager().getDeviceForCurrentThread()).append("]; ");
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										
											BIN
										
									
								
								gradle/wrapper/gradle-wrapper.jar
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										
											BIN
										
									
								
								gradle/wrapper/gradle-wrapper.jar
									
									
									
									
										vendored
									
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										2
									
								
								gradle/wrapper/gradle-wrapper.properties
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								gradle/wrapper/gradle-wrapper.properties
									
									
									
									
										vendored
									
									
								
							@ -1,5 +1,5 @@
 | 
				
			|||||||
distributionBase=GRADLE_USER_HOME
 | 
					distributionBase=GRADLE_USER_HOME
 | 
				
			||||||
distributionPath=wrapper/dists
 | 
					distributionPath=wrapper/dists
 | 
				
			||||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.5.1-bin.zip
 | 
					distributionUrl=https\://services.gradle.org/distributions/gradle-8.2.1-bin.zip
 | 
				
			||||||
zipStoreBase=GRADLE_USER_HOME
 | 
					zipStoreBase=GRADLE_USER_HOME
 | 
				
			||||||
zipStorePath=wrapper/dists
 | 
					zipStorePath=wrapper/dists
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										170
									
								
								gradlew
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										170
									
								
								gradlew
									
									
									
									
										vendored
									
									
								
							@ -69,18 +69,18 @@ app_path=$0
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# Need this for daisy-chained symlinks.
 | 
					# Need this for daisy-chained symlinks.
 | 
				
			||||||
while
 | 
					while
 | 
				
			||||||
    APP_HOME=${app_path%"${app_path##*/}"}  # leaves a trailing /; empty if no leading path
 | 
					  APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
 | 
				
			||||||
    [ -h "$app_path" ]
 | 
					  [ -h "$app_path" ]
 | 
				
			||||||
do
 | 
					do
 | 
				
			||||||
    ls=$( ls -ld "$app_path" )
 | 
					  ls=$(ls -ld "$app_path")
 | 
				
			||||||
    link=${ls#*' -> '}
 | 
					  link=${ls#*' -> '}
 | 
				
			||||||
    case $link in             #(
 | 
					  case $link in         #(
 | 
				
			||||||
      /*)   app_path=$link ;; #(
 | 
					  /*) app_path=$link ;; #(
 | 
				
			||||||
      *)    app_path=$APP_HOME$link ;;
 | 
					  *) app_path=$APP_HOME$link ;;
 | 
				
			||||||
    esac
 | 
					  esac
 | 
				
			||||||
done
 | 
					done
 | 
				
			||||||
 | 
					
 | 
				
			||||||
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
 | 
					APP_HOME=$(cd "${APP_HOME:-./}" && pwd -P) || exit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
APP_NAME="Gradle"
 | 
					APP_NAME="Gradle"
 | 
				
			||||||
APP_BASE_NAME=${0##*/}
 | 
					APP_BASE_NAME=${0##*/}
 | 
				
			||||||
@ -91,15 +91,15 @@ DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
 | 
				
			|||||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
 | 
					# Use the maximum available, or set MAX_FD != -1 to use that value.
 | 
				
			||||||
MAX_FD=maximum
 | 
					MAX_FD=maximum
 | 
				
			||||||
 | 
					
 | 
				
			||||||
warn () {
 | 
					warn() {
 | 
				
			||||||
    echo "$*"
 | 
					  echo "$*"
 | 
				
			||||||
} >&2
 | 
					} >&2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
die () {
 | 
					die() {
 | 
				
			||||||
    echo
 | 
					  echo
 | 
				
			||||||
    echo "$*"
 | 
					  echo "$*"
 | 
				
			||||||
    echo
 | 
					  echo
 | 
				
			||||||
    exit 1
 | 
					  exit 1
 | 
				
			||||||
} >&2
 | 
					} >&2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# OS specific support (must be 'true' or 'false').
 | 
					# OS specific support (must be 'true' or 'false').
 | 
				
			||||||
@ -107,51 +107,52 @@ cygwin=false
 | 
				
			|||||||
msys=false
 | 
					msys=false
 | 
				
			||||||
darwin=false
 | 
					darwin=false
 | 
				
			||||||
nonstop=false
 | 
					nonstop=false
 | 
				
			||||||
case "$( uname )" in                #(
 | 
					case "$(uname)" in           #(
 | 
				
			||||||
  CYGWIN* )         cygwin=true  ;; #(
 | 
					CYGWIN*) cygwin=true ;;      #(
 | 
				
			||||||
  Darwin* )         darwin=true  ;; #(
 | 
					Darwin*) darwin=true ;;      #(
 | 
				
			||||||
  MSYS* | MINGW* )  msys=true    ;; #(
 | 
					MSYS* | MINGW*) msys=true ;; #(
 | 
				
			||||||
  NONSTOP* )        nonstop=true ;;
 | 
					NONSTOP*) nonstop=true ;;
 | 
				
			||||||
esac
 | 
					esac
 | 
				
			||||||
 | 
					
 | 
				
			||||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
 | 
					CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
# Determine the Java command to use to start the JVM.
 | 
					# Determine the Java command to use to start the JVM.
 | 
				
			||||||
if [ -n "$JAVA_HOME" ] ; then
 | 
					if [ -n "$JAVA_HOME" ]; then
 | 
				
			||||||
    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
 | 
					  if [ -x "$JAVA_HOME/jre/sh/java" ]; then
 | 
				
			||||||
        # IBM's JDK on AIX uses strange locations for the executables
 | 
					    # IBM's JDK on AIX uses strange locations for the executables
 | 
				
			||||||
        JAVACMD=$JAVA_HOME/jre/sh/java
 | 
					    JAVACMD=$JAVA_HOME/jre/sh/java
 | 
				
			||||||
    else
 | 
					  else
 | 
				
			||||||
        JAVACMD=$JAVA_HOME/bin/java
 | 
					    JAVACMD=$JAVA_HOME/bin/java
 | 
				
			||||||
    fi
 | 
					  fi
 | 
				
			||||||
    if [ ! -x "$JAVACMD" ] ; then
 | 
					  if [ ! -x "$JAVACMD" ]; then
 | 
				
			||||||
        die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
 | 
					    die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Please set the JAVA_HOME variable in your environment to match the
 | 
					Please set the JAVA_HOME variable in your environment to match the
 | 
				
			||||||
location of your Java installation."
 | 
					location of your Java installation."
 | 
				
			||||||
    fi
 | 
					  fi
 | 
				
			||||||
else
 | 
					else
 | 
				
			||||||
    JAVACMD=java
 | 
					  JAVACMD=java
 | 
				
			||||||
    which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
 | 
					  which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Please set the JAVA_HOME variable in your environment to match the
 | 
					Please set the JAVA_HOME variable in your environment to match the
 | 
				
			||||||
location of your Java installation."
 | 
					location of your Java installation."
 | 
				
			||||||
fi
 | 
					fi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Increase the maximum file descriptors if we can.
 | 
					# Increase the maximum file descriptors if we can.
 | 
				
			||||||
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
 | 
					if ! "$cygwin" && ! "$darwin" && ! "$nonstop"; then
 | 
				
			||||||
    case $MAX_FD in #(
 | 
					  case $MAX_FD in #(
 | 
				
			||||||
      max*)
 | 
					  max*)
 | 
				
			||||||
        MAX_FD=$( ulimit -H -n ) ||
 | 
					    MAX_FD=$(ulimit -H -n) ||
 | 
				
			||||||
            warn "Could not query maximum file descriptor limit"
 | 
					      warn "Could not query maximum file descriptor limit"
 | 
				
			||||||
    esac
 | 
					    ;;
 | 
				
			||||||
    case $MAX_FD in  #(
 | 
					  esac
 | 
				
			||||||
      '' | soft) :;; #(
 | 
					  case $MAX_FD in #(
 | 
				
			||||||
      *)
 | 
					  '' | soft) : ;; #(
 | 
				
			||||||
        ulimit -n "$MAX_FD" ||
 | 
					  *)
 | 
				
			||||||
            warn "Could not set maximum file descriptor limit to $MAX_FD"
 | 
					    ulimit -n "$MAX_FD" ||
 | 
				
			||||||
    esac
 | 
					      warn "Could not set maximum file descriptor limit to $MAX_FD"
 | 
				
			||||||
 | 
					    ;;
 | 
				
			||||||
 | 
					  esac
 | 
				
			||||||
fi
 | 
					fi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Collect all arguments for the java command, stacking in reverse order:
 | 
					# Collect all arguments for the java command, stacking in reverse order:
 | 
				
			||||||
@ -163,34 +164,36 @@ fi
 | 
				
			|||||||
#   * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
 | 
					#   * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# For Cygwin or MSYS, switch paths to Windows format before running java
 | 
					# For Cygwin or MSYS, switch paths to Windows format before running java
 | 
				
			||||||
if "$cygwin" || "$msys" ; then
 | 
					if "$cygwin" || "$msys"; then
 | 
				
			||||||
    APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
 | 
					  APP_HOME=$(cygpath --path --mixed "$APP_HOME")
 | 
				
			||||||
    CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
 | 
					  CLASSPATH=$(cygpath --path --mixed "$CLASSPATH")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    JAVACMD=$( cygpath --unix "$JAVACMD" )
 | 
					  JAVACMD=$(cygpath --unix "$JAVACMD")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Now convert the arguments - kludge to limit ourselves to /bin/sh
 | 
					  # Now convert the arguments - kludge to limit ourselves to /bin/sh
 | 
				
			||||||
    for arg do
 | 
					  for arg; do
 | 
				
			||||||
        if
 | 
					    if
 | 
				
			||||||
            case $arg in                                #(
 | 
					      case $arg in #(
 | 
				
			||||||
              -*)   false ;;                            # don't mess with options #(
 | 
					      -*) false ;; # don't mess with options #(
 | 
				
			||||||
              /?*)  t=${arg#/} t=/${t%%/*}              # looks like a POSIX filepath
 | 
					      /?*)
 | 
				
			||||||
                    [ -e "$t" ] ;;                      #(
 | 
					        t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
 | 
				
			||||||
              *)    false ;;
 | 
					        [ -e "$t" ]
 | 
				
			||||||
            esac
 | 
					        ;; #(
 | 
				
			||||||
        then
 | 
					      *) false ;;
 | 
				
			||||||
            arg=$( cygpath --path --ignore --mixed "$arg" )
 | 
					      esac
 | 
				
			||||||
        fi
 | 
					    then
 | 
				
			||||||
        # Roll the args list around exactly as many times as the number of
 | 
					      arg=$(cygpath --path --ignore --mixed "$arg")
 | 
				
			||||||
        # args, so each arg winds up back in the position where it started, but
 | 
					    fi
 | 
				
			||||||
        # possibly modified.
 | 
					    # Roll the args list around exactly as many times as the number of
 | 
				
			||||||
        #
 | 
					    # args, so each arg winds up back in the position where it started, but
 | 
				
			||||||
        # NB: a `for` loop captures its iteration list before it begins, so
 | 
					    # possibly modified.
 | 
				
			||||||
        # changing the positional parameters here affects neither the number of
 | 
					    #
 | 
				
			||||||
        # iterations, nor the values presented in `arg`.
 | 
					    # NB: a `for` loop captures its iteration list before it begins, so
 | 
				
			||||||
        shift                   # remove old arg
 | 
					    # changing the positional parameters here affects neither the number of
 | 
				
			||||||
        set -- "$@" "$arg"      # push replacement arg
 | 
					    # iterations, nor the values presented in `arg`.
 | 
				
			||||||
    done
 | 
					    shift              # remove old arg
 | 
				
			||||||
 | 
					    set -- "$@" "$arg" # push replacement arg
 | 
				
			||||||
 | 
					  done
 | 
				
			||||||
fi
 | 
					fi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Collect all arguments for the java command;
 | 
					# Collect all arguments for the java command;
 | 
				
			||||||
@ -200,10 +203,15 @@ fi
 | 
				
			|||||||
#   * put everything else in single quotes, so that it's not re-expanded.
 | 
					#   * put everything else in single quotes, so that it's not re-expanded.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
set -- \
 | 
					set -- \
 | 
				
			||||||
        "-Dorg.gradle.appname=$APP_BASE_NAME" \
 | 
					  "-Dorg.gradle.appname=$APP_BASE_NAME" \
 | 
				
			||||||
        -classpath "$CLASSPATH" \
 | 
					  -classpath "$CLASSPATH" \
 | 
				
			||||||
        org.gradle.wrapper.GradleWrapperMain \
 | 
					  org.gradle.wrapper.GradleWrapperMain \
 | 
				
			||||||
        "$@"
 | 
					  "$@"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Stop when "xargs" is not available.
 | 
				
			||||||
 | 
					if ! command -v xargs >/dev/null 2>&1; then
 | 
				
			||||||
 | 
					  die "xargs is not available"
 | 
				
			||||||
 | 
					fi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Use "xargs" to parse quoted args.
 | 
					# Use "xargs" to parse quoted args.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
@ -225,10 +233,10 @@ set -- \
 | 
				
			|||||||
#
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
eval "set -- $(
 | 
					eval "set -- $(
 | 
				
			||||||
        printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
 | 
					  printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
 | 
				
			||||||
        xargs -n1 |
 | 
					    xargs -n1 |
 | 
				
			||||||
        sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
 | 
					    sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
 | 
				
			||||||
        tr '\n' ' '
 | 
					    tr '\n' ' '
 | 
				
			||||||
    )" '"$@"'
 | 
					)" '"$@"'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
exec "$JAVACMD" "$@"
 | 
					exec "$JAVACMD" "$@"
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										14
									
								
								gradlew.bat
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								gradlew.bat
									
									
									
									
										vendored
									
									
								
							@ -14,7 +14,7 @@
 | 
				
			|||||||
@rem limitations under the License.
 | 
					@rem limitations under the License.
 | 
				
			||||||
@rem
 | 
					@rem
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@if "%DEBUG%" == "" @echo off
 | 
					@if "%DEBUG%"=="" @echo off
 | 
				
			||||||
@rem ##########################################################################
 | 
					@rem ##########################################################################
 | 
				
			||||||
@rem
 | 
					@rem
 | 
				
			||||||
@rem  Gradle startup script for Windows
 | 
					@rem  Gradle startup script for Windows
 | 
				
			||||||
@ -25,7 +25,7 @@
 | 
				
			|||||||
if "%OS%"=="Windows_NT" setlocal
 | 
					if "%OS%"=="Windows_NT" setlocal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
set DIRNAME=%~dp0
 | 
					set DIRNAME=%~dp0
 | 
				
			||||||
if "%DIRNAME%" == "" set DIRNAME=.
 | 
					if "%DIRNAME%"=="" set DIRNAME=.
 | 
				
			||||||
set APP_BASE_NAME=%~n0
 | 
					set APP_BASE_NAME=%~n0
 | 
				
			||||||
set APP_HOME=%DIRNAME%
 | 
					set APP_HOME=%DIRNAME%
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
set JAVA_EXE=java.exe
 | 
					set JAVA_EXE=java.exe
 | 
				
			||||||
%JAVA_EXE% -version >NUL 2>&1
 | 
					%JAVA_EXE% -version >NUL 2>&1
 | 
				
			||||||
if "%ERRORLEVEL%" == "0" goto execute
 | 
					if %ERRORLEVEL% equ 0 goto execute
 | 
				
			||||||
 | 
					
 | 
				
			||||||
echo.
 | 
					echo.
 | 
				
			||||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
 | 
					echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
 | 
				
			||||||
@ -75,13 +75,15 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
:end
 | 
					:end
 | 
				
			||||||
@rem End local scope for the variables with windows NT shell
 | 
					@rem End local scope for the variables with windows NT shell
 | 
				
			||||||
if "%ERRORLEVEL%"=="0" goto mainEnd
 | 
					if %ERRORLEVEL% equ 0 goto mainEnd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
:fail
 | 
					:fail
 | 
				
			||||||
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
 | 
					rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
 | 
				
			||||||
rem the _cmd.exe /c_ return code!
 | 
					rem the _cmd.exe /c_ return code!
 | 
				
			||||||
if  not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
 | 
					set EXIT_CODE=%ERRORLEVEL%
 | 
				
			||||||
exit /b 1
 | 
					if %EXIT_CODE% equ 0 set EXIT_CODE=1
 | 
				
			||||||
 | 
					if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
 | 
				
			||||||
 | 
					exit /b %EXIT_CODE%
 | 
				
			||||||
 | 
					
 | 
				
			||||||
:mainEnd
 | 
					:mainEnd
 | 
				
			||||||
if "%OS%"=="Windows_NT" endlocal
 | 
					if "%OS%"=="Windows_NT" endlocal
 | 
				
			||||||
 | 
				
			|||||||
@ -70,7 +70,7 @@ apply from: "chooseBackend.gradle"
 | 
				
			|||||||
rootProject.name = "Cavis"
 | 
					rootProject.name = "Cavis"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
 | 
					enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
 | 
				
			||||||
enableFeaturePreview("VERSION_CATALOGS")
 | 
					//enableFeaturePreview("VERSION_CATALOGS") //only needed for gradle <8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
sourceControl {
 | 
					sourceControl {
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user