Fixed missing imports
Signed-off-by: brian <brian@brutex.de>
This commit is contained in:
		
							parent
							
								
									d0342fc939
								
							
						
					
					
						commit
						e27fb8422f
					
				@ -21,14 +21,14 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
package net.brutex.gan;
 | 
					package net.brutex.gan;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import java.awt.BorderLayout;
 | 
					import java.awt.*;
 | 
				
			||||||
import java.awt.Dimension;
 | 
					 | 
				
			||||||
import java.awt.GridLayout;
 | 
					 | 
				
			||||||
import java.awt.Image;
 | 
					 | 
				
			||||||
import java.awt.image.BufferedImage;
 | 
					import java.awt.image.BufferedImage;
 | 
				
			||||||
import java.io.File;
 | 
					import java.io.File;
 | 
				
			||||||
 | 
					import java.io.IOException;
 | 
				
			||||||
import java.util.Arrays;
 | 
					import java.util.Arrays;
 | 
				
			||||||
import java.util.Random;
 | 
					import java.util.Random;
 | 
				
			||||||
 | 
					import java.util.UUID;
 | 
				
			||||||
 | 
					import javax.imageio.ImageIO;
 | 
				
			||||||
import javax.swing.ImageIcon;
 | 
					import javax.swing.ImageIcon;
 | 
				
			||||||
import javax.swing.JFrame;
 | 
					import javax.swing.JFrame;
 | 
				
			||||||
import javax.swing.JLabel;
 | 
					import javax.swing.JLabel;
 | 
				
			||||||
@ -82,9 +82,9 @@ public class App {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  private static final int X_DIM = 20 ;
 | 
					  private static final int X_DIM = 20 ;
 | 
				
			||||||
  private static final int Y_DIM = 20;
 | 
					  private static final int Y_DIM = 20;
 | 
				
			||||||
  private static final int CHANNELS = 3;
 | 
					  private static final int CHANNELS = 1;
 | 
				
			||||||
  private static final int batchSize = 50;
 | 
					  private static final int batchSize = 1;
 | 
				
			||||||
  private static final int INPUT = 128;
 | 
					  private static final int INPUT = 10;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private static final int OUTPUT_PER_PANEL = 16;
 | 
					  private static final int OUTPUT_PER_PANEL = 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -96,6 +96,8 @@ public class App {
 | 
				
			|||||||
  private static JPanel panel;
 | 
					  private static JPanel panel;
 | 
				
			||||||
  private static JPanel panel2;
 | 
					  private static JPanel panel2;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  private static final String OUTPUT_DIR = "C:/temp/output/";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private static LayerConfiguration[] genLayers() {
 | 
					  private static LayerConfiguration[] genLayers() {
 | 
				
			||||||
    return new LayerConfiguration[] {
 | 
					    return new LayerConfiguration[] {
 | 
				
			||||||
        DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
 | 
					        DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
 | 
				
			||||||
@ -103,6 +105,7 @@ public class App {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
 | 
					        DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
 | 
				
			||||||
        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 | 
					        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 | 
				
			||||||
 | 
					            DropoutLayer.builder(1 - 0.5).build(),
 | 
				
			||||||
        DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
 | 
					        DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
 | 
				
			||||||
        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 | 
					        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -207,6 +210,12 @@ public class App {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  @Test
 | 
					  @Test
 | 
				
			||||||
  public void runTest() throws Exception {
 | 
					  public void runTest() throws Exception {
 | 
				
			||||||
 | 
					      if(! log.isDebugEnabled()) {
 | 
				
			||||||
 | 
					          log.info("Logging is not set to DEBUG");
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					      else {
 | 
				
			||||||
 | 
					          log.info("Logging is set to DEBUG");
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
    main();
 | 
					    main();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -240,25 +249,25 @@ public class App {
 | 
				
			|||||||
    MultiLayerNetwork gan = new MultiLayerNetwork(gan());
 | 
					    MultiLayerNetwork gan = new MultiLayerNetwork(gan());
 | 
				
			||||||
    gen.init(); log.debug("Generator network: {}", gen);
 | 
					    gen.init(); log.debug("Generator network: {}", gen);
 | 
				
			||||||
    dis.init(); log.debug("Discriminator network: {}", dis);
 | 
					    dis.init(); log.debug("Discriminator network: {}", dis);
 | 
				
			||||||
    gan.init(); log.debug("Complete GAN network: {}", gan);
 | 
					    gan.init(); log.info("Complete GAN network: {}", gan);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    copyParams(gen, dis, gan);
 | 
					    copyParams(gen, dis, gan);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    gen.addTrainingListeners(new PerformanceListener(15, true));
 | 
					    //gen.addTrainingListeners(new PerformanceListener(15, true, "GEN"));
 | 
				
			||||||
    //dis.addTrainingListeners(new PerformanceListener(10, true));
 | 
					    dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
 | 
				
			||||||
    //gan.addTrainingListeners(new PerformanceListener(10, true));
 | 
					    gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
 | 
				
			||||||
    //gan.addTrainingListeners(new ScoreToChartListener("gan"));
 | 
					    //gan.addTrainingListeners(new ScoreToChartListener("gan"));
 | 
				
			||||||
    //dis.setListeners(new ScoreToChartListener("dis"));
 | 
					    //dis.setListeners(new ScoreToChartListener("dis"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    System.out.println(gan.toString());
 | 
					    //System.out.println(gan.toString());
 | 
				
			||||||
    gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
 | 
					    //gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    //gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
 | 
					    //gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
 | 
				
			||||||
    //trainData.reset();
 | 
					    //trainData.reset();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int j = 0;
 | 
					    int j = 0;
 | 
				
			||||||
    for (int i = 0; i < 201; i++) { //epoch
 | 
					    for (int i = 0; i < 51; i++) { //epoch
 | 
				
			||||||
      while (trainData.hasNext()) {
 | 
					      while (trainData.hasNext()) {
 | 
				
			||||||
        j++;
 | 
					        j++;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -282,6 +291,9 @@ public class App {
 | 
				
			|||||||
//        int batchSize = (int) real.shape()[0];
 | 
					//        int batchSize = (int) real.shape()[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
 | 
					        INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
 | 
				
			||||||
 | 
					        //INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
 | 
					        INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
 | 
				
			||||||
        fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
 | 
					        fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -299,11 +311,11 @@ public class App {
 | 
				
			|||||||
        updateGan(gen, dis, gan);
 | 
					        updateGan(gen, dis, gan);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        //gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
 | 
					        //gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
 | 
				
			||||||
        gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)));
 | 
					        gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        //Visualize and reporting
 | 
					        //Visualize and reporting
 | 
				
			||||||
        if (j % 10 == 1) {
 | 
					        if (j % 10 == 1) {
 | 
				
			||||||
          System.out.println("Iteration " + j + " Visualizing...");
 | 
					          System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
 | 
				
			||||||
          INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
 | 
					          INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -330,11 +342,12 @@ public class App {
 | 
				
			|||||||
      }
 | 
					      }
 | 
				
			||||||
        // Copy the GANs generator to gen.
 | 
					        // Copy the GANs generator to gen.
 | 
				
			||||||
        updateGen(gen, gan);
 | 
					        updateGen(gen, gan);
 | 
				
			||||||
 | 
					        gen.save(new File("mnist-mlp-generator.dlj"));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    gen.save(new File("mnist-mlp-generator.dlj"));
 | 
					
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
 | 
					  private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
 | 
				
			||||||
@ -342,10 +355,10 @@ public class App {
 | 
				
			|||||||
    for (int i = 0; i < gan.getLayers().length; i++) {
 | 
					    for (int i = 0; i < gan.getLayers().length; i++) {
 | 
				
			||||||
      if (i < genLayerCount) {
 | 
					      if (i < genLayerCount) {
 | 
				
			||||||
        if(gan.getLayer(i).getParams() != null)
 | 
					        if(gan.getLayer(i).getParams() != null)
 | 
				
			||||||
         gen.getLayer(i).setParams(gan.getLayer(i).getParams());
 | 
					         gan.getLayer(i).setParams(gen.getLayer(i).getParams());
 | 
				
			||||||
      } else {
 | 
					      } else {
 | 
				
			||||||
        if(gan.getLayer(i).getParams() != null)
 | 
					        if(gan.getLayer(i).getParams() != null)
 | 
				
			||||||
        dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
 | 
					        gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams());
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
@ -411,14 +424,50 @@ public class App {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    ImageIcon orig = new ImageIcon(bi);
 | 
					    ImageIcon orig = new ImageIcon(bi);
 | 
				
			||||||
 | 
					 | 
				
			||||||
    Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
 | 
					    Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
 | 
				
			||||||
 | 
					 | 
				
			||||||
    ImageIcon scaled = new ImageIcon(imageScaled);
 | 
					    ImageIcon scaled = new ImageIcon(imageScaled);
 | 
				
			||||||
 | 
					    if(! isOrig)  saveImage(imageScaled, batchElement, isOrig);
 | 
				
			||||||
    return new JLabel(scaled);
 | 
					    return new JLabel(scaled);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  private static void saveImage(Image image, int batchElement, boolean isOrig) {
 | 
				
			||||||
 | 
					    String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try {
 | 
				
			||||||
 | 
					      // Save the images to disk
 | 
				
			||||||
 | 
					      saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					     log.debug("Images saved successfully.");
 | 
				
			||||||
 | 
					    } catch (IOException e) {
 | 
				
			||||||
 | 
					      log.error("Error saving the images: {}", e.getMessage());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					    private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
 | 
				
			||||||
 | 
					        File directory = new File(outputDirectory);
 | 
				
			||||||
 | 
					        if (!directory.exists()) {
 | 
				
			||||||
 | 
					            directory.mkdir();
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        File outputFile = new File(directory, fileName);
 | 
				
			||||||
 | 
					        ImageIO.write(imageToBufferedImage(image), "png", outputFile);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public static BufferedImage imageToBufferedImage(Image image) {
 | 
				
			||||||
 | 
					        if (image instanceof BufferedImage) {
 | 
				
			||||||
 | 
					            return (BufferedImage) image;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Create a buffered image with the same dimensions and transparency as the original image
 | 
				
			||||||
 | 
					        BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Draw the original image onto the buffered image
 | 
				
			||||||
 | 
					        Graphics2D g2d = bufferedImage.createGraphics();
 | 
				
			||||||
 | 
					        g2d.drawImage(image, 0, 0, null);
 | 
				
			||||||
 | 
					        g2d.dispose();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return bufferedImage;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -25,7 +25,7 @@
 | 
				
			|||||||
# Default logging detail level for all instances of SimpleLogger.
 | 
					# Default logging detail level for all instances of SimpleLogger.
 | 
				
			||||||
# Must be one of ("trace", "debug", "info", "warn", or "error").
 | 
					# Must be one of ("trace", "debug", "info", "warn", or "error").
 | 
				
			||||||
# If not specified, defaults to "info".
 | 
					# If not specified, defaults to "info".
 | 
				
			||||||
org.slf4j.simpleLogger.defaultLogLevel=trace
 | 
					org.slf4j.simpleLogger.defaultLogLevel=debug
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Logging detail level for a SimpleLogger instance named "xxxxx".
 | 
					# Logging detail level for a SimpleLogger instance named "xxxxx".
 | 
				
			||||||
# Must be one of ("trace", "debug", "info", "warn", or "error").
 | 
					# Must be one of ("trace", "debug", "info", "warn", or "error").
 | 
				
			||||||
@ -42,8 +42,8 @@ org.slf4j.simpleLogger.defaultLogLevel=trace
 | 
				
			|||||||
# If the format is not specified or is invalid, the default format is used.
 | 
					# If the format is not specified or is invalid, the default format is used.
 | 
				
			||||||
# The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
 | 
					# The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
 | 
				
			||||||
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
 | 
					#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
 | 
				
			||||||
org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
 | 
					#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Set to true if you want to output the current thread name.
 | 
					# Set to true if you want to output the current thread name.
 | 
				
			||||||
# Defaults to true.
 | 
					# Defaults to true.
 | 
				
			||||||
org.slf4j.simpleLogger.showThreadName=true
 | 
					#org.slf4j.simpleLogger.showThreadName=true
 | 
				
			||||||
@ -23,6 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations;
 | 
				
			|||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
 | 
					import org.deeplearning4j.nn.conf.InputPreProcessor;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 | 
					import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 | 
				
			||||||
import org.deeplearning4j.BaseDL4JTest;
 | 
					import org.deeplearning4j.BaseDL4JTest;
 | 
				
			||||||
 | 
					import org.deeplearning4j.nn.conf.serde.CavisMapper;
 | 
				
			||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
 | 
					import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
 | 
				
			||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
 | 
					import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
 | 
				
			||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
 | 
					import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
 | 
				
			||||||
 | 
				
			|||||||
@ -55,6 +55,8 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali
 | 
				
			|||||||
    private boolean reportEtl = true;
 | 
					    private boolean reportEtl = true;
 | 
				
			||||||
    private boolean reportTime = true;
 | 
					    private boolean reportTime = true;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private final String name;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public PerformanceListener(int frequency) {
 | 
					    public PerformanceListener(int frequency) {
 | 
				
			||||||
@ -66,14 +68,22 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public PerformanceListener(int frequency, boolean reportScore, boolean reportGC) {
 | 
					    public PerformanceListener(int frequency, boolean reportScore, boolean reportGC) {
 | 
				
			||||||
 | 
					       this(frequency, reportScore, reportGC, "no-name");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    public PerformanceListener(int frequency, boolean reportScore, boolean reportGC, String name) {
 | 
				
			||||||
        Preconditions.checkArgument(frequency > 0, "Invalid frequency, must be > 0: Got " + frequency);
 | 
					        Preconditions.checkArgument(frequency > 0, "Invalid frequency, must be > 0: Got " + frequency);
 | 
				
			||||||
        this.frequency = frequency;
 | 
					        this.frequency = frequency;
 | 
				
			||||||
        this.reportScore = reportScore;
 | 
					        this.reportScore = reportScore;
 | 
				
			||||||
        this.reportGC = reportGC;
 | 
					        this.reportGC = reportGC;
 | 
				
			||||||
 | 
					        this.name = name;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        lastTime.set(System.currentTimeMillis());
 | 
					        lastTime.set(System.currentTimeMillis());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public PerformanceListener(int frequency, boolean reportScore, String name) {
 | 
				
			||||||
 | 
					        this(frequency, reportScore, false, name);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public void iterationDone(IModel model, int iteration, int epoch) {
 | 
					    public void iterationDone(IModel model, int iteration, int epoch) {
 | 
				
			||||||
        // we update lastTime on every iteration
 | 
					        // we update lastTime on every iteration
 | 
				
			||||||
@ -116,6 +126,7 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            StringBuilder builder = new StringBuilder();
 | 
					            StringBuilder builder = new StringBuilder();
 | 
				
			||||||
 | 
					            builder.append("Name: '"+name+"'");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if (Nd4j.getAffinityManager().getNumberOfDevices() > 1)
 | 
					            if (Nd4j.getAffinityManager().getNumberOfDevices() > 1)
 | 
				
			||||||
                builder.append("Device: [").append(Nd4j.getAffinityManager().getDeviceForCurrentThread()).append("]; ");
 | 
					                builder.append("Device: [").append(Nd4j.getAffinityManager().getDeviceForCurrentThread()).append("]; ");
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										
											BIN
										
									
								
								gradle/wrapper/gradle-wrapper.jar
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										
											BIN
										
									
								
								gradle/wrapper/gradle-wrapper.jar
									
									
									
									
										vendored
									
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										2
									
								
								gradle/wrapper/gradle-wrapper.properties
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								gradle/wrapper/gradle-wrapper.properties
									
									
									
									
										vendored
									
									
								
							@ -1,5 +1,5 @@
 | 
				
			|||||||
distributionBase=GRADLE_USER_HOME
 | 
					distributionBase=GRADLE_USER_HOME
 | 
				
			||||||
distributionPath=wrapper/dists
 | 
					distributionPath=wrapper/dists
 | 
				
			||||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.5.1-bin.zip
 | 
					distributionUrl=https\://services.gradle.org/distributions/gradle-8.2.1-bin.zip
 | 
				
			||||||
zipStoreBase=GRADLE_USER_HOME
 | 
					zipStoreBase=GRADLE_USER_HOME
 | 
				
			||||||
zipStorePath=wrapper/dists
 | 
					zipStorePath=wrapper/dists
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										16
									
								
								gradlew
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										16
									
								
								gradlew
									
									
									
									
										vendored
									
									
								
							@ -116,7 +116,6 @@ esac
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
 | 
					CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
# Determine the Java command to use to start the JVM.
 | 
					# Determine the Java command to use to start the JVM.
 | 
				
			||||||
if [ -n "$JAVA_HOME" ]; then
 | 
					if [ -n "$JAVA_HOME" ]; then
 | 
				
			||||||
  if [ -x "$JAVA_HOME/jre/sh/java" ]; then
 | 
					  if [ -x "$JAVA_HOME/jre/sh/java" ]; then
 | 
				
			||||||
@ -145,12 +144,14 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
 | 
				
			|||||||
  max*)
 | 
					  max*)
 | 
				
			||||||
    MAX_FD=$(ulimit -H -n) ||
 | 
					    MAX_FD=$(ulimit -H -n) ||
 | 
				
			||||||
      warn "Could not query maximum file descriptor limit"
 | 
					      warn "Could not query maximum file descriptor limit"
 | 
				
			||||||
 | 
					    ;;
 | 
				
			||||||
  esac
 | 
					  esac
 | 
				
			||||||
  case $MAX_FD in #(
 | 
					  case $MAX_FD in #(
 | 
				
			||||||
  '' | soft) : ;; #(
 | 
					  '' | soft) : ;; #(
 | 
				
			||||||
  *)
 | 
					  *)
 | 
				
			||||||
    ulimit -n "$MAX_FD" ||
 | 
					    ulimit -n "$MAX_FD" ||
 | 
				
			||||||
      warn "Could not set maximum file descriptor limit to $MAX_FD"
 | 
					      warn "Could not set maximum file descriptor limit to $MAX_FD"
 | 
				
			||||||
 | 
					    ;;
 | 
				
			||||||
  esac
 | 
					  esac
 | 
				
			||||||
fi
 | 
					fi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -170,12 +171,14 @@ if "$cygwin" || "$msys" ; then
 | 
				
			|||||||
  JAVACMD=$(cygpath --unix "$JAVACMD")
 | 
					  JAVACMD=$(cygpath --unix "$JAVACMD")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # Now convert the arguments - kludge to limit ourselves to /bin/sh
 | 
					  # Now convert the arguments - kludge to limit ourselves to /bin/sh
 | 
				
			||||||
    for arg do
 | 
					  for arg; do
 | 
				
			||||||
    if
 | 
					    if
 | 
				
			||||||
      case $arg in #(
 | 
					      case $arg in #(
 | 
				
			||||||
      -*) false ;; # don't mess with options #(
 | 
					      -*) false ;; # don't mess with options #(
 | 
				
			||||||
              /?*)  t=${arg#/} t=/${t%%/*}              # looks like a POSIX filepath
 | 
					      /?*)
 | 
				
			||||||
                    [ -e "$t" ] ;;                      #(
 | 
					        t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
 | 
				
			||||||
 | 
					        [ -e "$t" ]
 | 
				
			||||||
 | 
					        ;; #(
 | 
				
			||||||
      *) false ;;
 | 
					      *) false ;;
 | 
				
			||||||
      esac
 | 
					      esac
 | 
				
			||||||
    then
 | 
					    then
 | 
				
			||||||
@ -205,6 +208,11 @@ set -- \
 | 
				
			|||||||
  org.gradle.wrapper.GradleWrapperMain \
 | 
					  org.gradle.wrapper.GradleWrapperMain \
 | 
				
			||||||
  "$@"
 | 
					  "$@"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Stop when "xargs" is not available.
 | 
				
			||||||
 | 
					if ! command -v xargs >/dev/null 2>&1; then
 | 
				
			||||||
 | 
					  die "xargs is not available"
 | 
				
			||||||
 | 
					fi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Use "xargs" to parse quoted args.
 | 
					# Use "xargs" to parse quoted args.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# With -n1 it outputs one arg per line, with the quotes and backslashes removed.
 | 
					# With -n1 it outputs one arg per line, with the quotes and backslashes removed.
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										10
									
								
								gradlew.bat
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										10
									
								
								gradlew.bat
									
									
									
									
										vendored
									
									
								
							@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
set JAVA_EXE=java.exe
 | 
					set JAVA_EXE=java.exe
 | 
				
			||||||
%JAVA_EXE% -version >NUL 2>&1
 | 
					%JAVA_EXE% -version >NUL 2>&1
 | 
				
			||||||
if "%ERRORLEVEL%" == "0" goto execute
 | 
					if %ERRORLEVEL% equ 0 goto execute
 | 
				
			||||||
 | 
					
 | 
				
			||||||
echo.
 | 
					echo.
 | 
				
			||||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
 | 
					echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
 | 
				
			||||||
@ -75,13 +75,15 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
:end
 | 
					:end
 | 
				
			||||||
@rem End local scope for the variables with windows NT shell
 | 
					@rem End local scope for the variables with windows NT shell
 | 
				
			||||||
if "%ERRORLEVEL%"=="0" goto mainEnd
 | 
					if %ERRORLEVEL% equ 0 goto mainEnd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
:fail
 | 
					:fail
 | 
				
			||||||
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
 | 
					rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
 | 
				
			||||||
rem the _cmd.exe /c_ return code!
 | 
					rem the _cmd.exe /c_ return code!
 | 
				
			||||||
if  not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
 | 
					set EXIT_CODE=%ERRORLEVEL%
 | 
				
			||||||
exit /b 1
 | 
					if %EXIT_CODE% equ 0 set EXIT_CODE=1
 | 
				
			||||||
 | 
					if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
 | 
				
			||||||
 | 
					exit /b %EXIT_CODE%
 | 
				
			||||||
 | 
					
 | 
				
			||||||
:mainEnd
 | 
					:mainEnd
 | 
				
			||||||
if "%OS%"=="Windows_NT" endlocal
 | 
					if "%OS%"=="Windows_NT" endlocal
 | 
				
			||||||
 | 
				
			|||||||
@ -70,7 +70,7 @@ apply from: "chooseBackend.gradle"
 | 
				
			|||||||
rootProject.name = "Cavis"
 | 
					rootProject.name = "Cavis"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
 | 
					enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
 | 
				
			||||||
enableFeaturePreview("VERSION_CATALOGS")
 | 
					//enableFeaturePreview("VERSION_CATALOGS") //only needed for gradle <8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
sourceControl {
 | 
					sourceControl {
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user