Fixed missing imports

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-07-27 09:03:58 +02:00
parent deb436036b
commit ca127d8b88
9 changed files with 186 additions and 115 deletions

View File

@ -21,14 +21,14 @@
package net.brutex.gan; package net.brutex.gan;
import java.awt.BorderLayout; import java.awt.*;
import java.awt.Dimension;
import java.awt.GridLayout;
import java.awt.Image;
import java.awt.image.BufferedImage; import java.awt.image.BufferedImage;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Random; import java.util.Random;
import java.util.UUID;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon; import javax.swing.ImageIcon;
import javax.swing.JFrame; import javax.swing.JFrame;
import javax.swing.JLabel; import javax.swing.JLabel;
@ -82,9 +82,9 @@ public class App {
private static final int X_DIM = 20 ; private static final int X_DIM = 20 ;
private static final int Y_DIM = 20; private static final int Y_DIM = 20;
private static final int CHANNELS = 3; private static final int CHANNELS = 1;
private static final int batchSize = 50; private static final int batchSize = 1;
private static final int INPUT = 128; private static final int INPUT = 10;
private static final int OUTPUT_PER_PANEL = 16; private static final int OUTPUT_PER_PANEL = 16;
@ -96,6 +96,8 @@ public class App {
private static JPanel panel; private static JPanel panel;
private static JPanel panel2; private static JPanel panel2;
private static final String OUTPUT_DIR = "C:/temp/output/";
private static LayerConfiguration[] genLayers() { private static LayerConfiguration[] genLayers() {
return new LayerConfiguration[] { return new LayerConfiguration[] {
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(), DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
@ -103,6 +105,7 @@ public class App {
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(), ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(), DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(), ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
@ -207,6 +210,12 @@ public class App {
@Test @Test
public void runTest() throws Exception { public void runTest() throws Exception {
if(! log.isDebugEnabled()) {
log.info("Logging is not set to DEBUG");
}
else {
log.info("Logging is set to DEBUG");
}
main(); main();
} }
@ -240,25 +249,25 @@ public class App {
MultiLayerNetwork gan = new MultiLayerNetwork(gan()); MultiLayerNetwork gan = new MultiLayerNetwork(gan());
gen.init(); log.debug("Generator network: {}", gen); gen.init(); log.debug("Generator network: {}", gen);
dis.init(); log.debug("Discriminator network: {}", dis); dis.init(); log.debug("Discriminator network: {}", dis);
gan.init(); log.debug("Complete GAN network: {}", gan); gan.init(); log.info("Complete GAN network: {}", gan);
copyParams(gen, dis, gan); copyParams(gen, dis, gan);
gen.addTrainingListeners(new PerformanceListener(15, true)); //gen.addTrainingListeners(new PerformanceListener(15, true, "GEN"));
//dis.addTrainingListeners(new PerformanceListener(10, true)); dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
//gan.addTrainingListeners(new PerformanceListener(10, true)); gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
//gan.addTrainingListeners(new ScoreToChartListener("gan")); //gan.addTrainingListeners(new ScoreToChartListener("gan"));
//dis.setListeners(new ScoreToChartListener("dis")); //dis.setListeners(new ScoreToChartListener("dis"));
System.out.println(gan.toString()); //System.out.println(gan.toString());
gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)); //gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
//gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1))); //gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
//trainData.reset(); //trainData.reset();
int j = 0; int j = 0;
for (int i = 0; i < 201; i++) { //epoch for (int i = 0; i < 51; i++) { //epoch
while (trainData.hasNext()) { while (trainData.hasNext()) {
j++; j++;
@ -282,6 +291,9 @@ public class App {
// int batchSize = (int) real.shape()[0]; // int batchSize = (int) real.shape()[0];
INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM); INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn); INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM); fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
@ -299,11 +311,11 @@ public class App {
updateGan(gen, dis, gan); updateGan(gen, dis, gan);
//gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1))); //gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1))); gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1)));
//Visualize and reporting //Visualize and reporting
if (j % 10 == 1) { if (j % 10 == 1) {
System.out.println("Iteration " + j + " Visualizing..."); System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize]; INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
@ -330,11 +342,12 @@ public class App {
} }
// Copy the GANs generator to gen. // Copy the GANs generator to gen.
updateGen(gen, gan); updateGen(gen, gan);
gen.save(new File("mnist-mlp-generator.dlj"));
} }
gen.save(new File("mnist-mlp-generator.dlj"));
} }
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) { private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
@ -342,10 +355,10 @@ public class App {
for (int i = 0; i < gan.getLayers().length; i++) { for (int i = 0; i < gan.getLayers().length; i++) {
if (i < genLayerCount) { if (i < genLayerCount) {
if(gan.getLayer(i).getParams() != null) if(gan.getLayer(i).getParams() != null)
gen.getLayer(i).setParams(gan.getLayer(i).getParams()); gan.getLayer(i).setParams(gen.getLayer(i).getParams());
} else { } else {
if(gan.getLayer(i).getParams() != null) if(gan.getLayer(i).getParams() != null)
dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams()); gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams());
} }
} }
} }
@ -411,14 +424,50 @@ public class App {
} }
} }
} }
ImageIcon orig = new ImageIcon(bi); ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT); Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
ImageIcon scaled = new ImageIcon(imageScaled); ImageIcon scaled = new ImageIcon(imageScaled);
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
return new JLabel(scaled); return new JLabel(scaled);
} }
private static void saveImage(Image image, int batchElement, boolean isOrig) {
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
try {
// Save the images to disk
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
log.debug("Images saved successfully.");
} catch (IOException e) {
log.error("Error saving the images: {}", e.getMessage());
}
}
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
File directory = new File(outputDirectory);
if (!directory.exists()) {
directory.mkdir();
}
File outputFile = new File(directory, fileName);
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
}
public static BufferedImage imageToBufferedImage(Image image) {
if (image instanceof BufferedImage) {
return (BufferedImage) image;
}
// Create a buffered image with the same dimensions and transparency as the original image
BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
// Draw the original image onto the buffered image
Graphics2D g2d = bufferedImage.createGraphics();
g2d.drawImage(image, 0, 0, null);
g2d.dispose();
return bufferedImage;
}
} }

View File

@ -25,7 +25,7 @@
# Default logging detail level for all instances of SimpleLogger. # Default logging detail level for all instances of SimpleLogger.
# Must be one of ("trace", "debug", "info", "warn", or "error"). # Must be one of ("trace", "debug", "info", "warn", or "error").
# If not specified, defaults to "info". # If not specified, defaults to "info".
org.slf4j.simpleLogger.defaultLogLevel=trace org.slf4j.simpleLogger.defaultLogLevel=debug
# Logging detail level for a SimpleLogger instance named "xxxxx". # Logging detail level for a SimpleLogger instance named "xxxxx".
# Must be one of ("trace", "debug", "info", "warn", or "error"). # Must be one of ("trace", "debug", "info", "warn", or "error").
@ -42,8 +42,8 @@ org.slf4j.simpleLogger.defaultLogLevel=trace
# If the format is not specified or is invalid, the default format is used. # If the format is not specified or is invalid, the default format is used.
# The default format is yyyy-MM-dd HH:mm:ss:SSS Z. # The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z #org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss #org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
# Set to true if you want to output the current thread name. # Set to true if you want to output the current thread name.
# Defaults to true. # Defaults to true.
org.slf4j.simpleLogger.showThreadName=true #org.slf4j.simpleLogger.showThreadName=true

View File

@ -23,6 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.serde.CavisMapper;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;

View File

@ -55,6 +55,8 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali
private boolean reportEtl = true; private boolean reportEtl = true;
private boolean reportTime = true; private boolean reportTime = true;
private final String name;
public PerformanceListener(int frequency) { public PerformanceListener(int frequency) {
@ -66,14 +68,22 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali
} }
public PerformanceListener(int frequency, boolean reportScore, boolean reportGC) { public PerformanceListener(int frequency, boolean reportScore, boolean reportGC) {
this(frequency, reportScore, reportGC, "no-name");
}
public PerformanceListener(int frequency, boolean reportScore, boolean reportGC, String name) {
Preconditions.checkArgument(frequency > 0, "Invalid frequency, must be > 0: Got " + frequency); Preconditions.checkArgument(frequency > 0, "Invalid frequency, must be > 0: Got " + frequency);
this.frequency = frequency; this.frequency = frequency;
this.reportScore = reportScore; this.reportScore = reportScore;
this.reportGC = reportGC; this.reportGC = reportGC;
this.name = name;
lastTime.set(System.currentTimeMillis()); lastTime.set(System.currentTimeMillis());
} }
public PerformanceListener(int frequency, boolean reportScore, String name) {
this(frequency, reportScore, false, name);
}
@Override @Override
public void iterationDone(IModel model, int iteration, int epoch) { public void iterationDone(IModel model, int iteration, int epoch) {
// we update lastTime on every iteration // we update lastTime on every iteration
@ -116,6 +126,7 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
builder.append("Name: '"+name+"'");
if (Nd4j.getAffinityManager().getNumberOfDevices() > 1) if (Nd4j.getAffinityManager().getNumberOfDevices() > 1)
builder.append("Device: [").append(Nd4j.getAffinityManager().getDeviceForCurrentThread()).append("]; "); builder.append("Device: [").append(Nd4j.getAffinityManager().getDeviceForCurrentThread()).append("]; ");

Binary file not shown.

View File

@ -1,5 +1,5 @@
distributionBase=GRADLE_USER_HOME distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.5.1-bin.zip distributionUrl=https\://services.gradle.org/distributions/gradle-8.2.1-bin.zip
zipStoreBase=GRADLE_USER_HOME zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists zipStorePath=wrapper/dists

16
gradlew vendored
View File

@ -116,7 +116,6 @@ esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM. # Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ]; then if [ -n "$JAVA_HOME" ]; then
if [ -x "$JAVA_HOME/jre/sh/java" ]; then if [ -x "$JAVA_HOME/jre/sh/java" ]; then
@ -145,12 +144,14 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
max*) max*)
MAX_FD=$(ulimit -H -n) || MAX_FD=$(ulimit -H -n) ||
warn "Could not query maximum file descriptor limit" warn "Could not query maximum file descriptor limit"
;;
esac esac
case $MAX_FD in #( case $MAX_FD in #(
'' | soft) : ;; #( '' | soft) : ;; #(
*) *)
ulimit -n "$MAX_FD" || ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD" warn "Could not set maximum file descriptor limit to $MAX_FD"
;;
esac esac
fi fi
@ -170,12 +171,14 @@ if "$cygwin" || "$msys" ; then
JAVACMD=$(cygpath --unix "$JAVACMD") JAVACMD=$(cygpath --unix "$JAVACMD")
# Now convert the arguments - kludge to limit ourselves to /bin/sh # Now convert the arguments - kludge to limit ourselves to /bin/sh
for arg do for arg; do
if if
case $arg in #( case $arg in #(
-*) false ;; # don't mess with options #( -*) false ;; # don't mess with options #(
/?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath /?*)
[ -e "$t" ] ;; #( t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
[ -e "$t" ]
;; #(
*) false ;; *) false ;;
esac esac
then then
@ -205,6 +208,11 @@ set -- \
org.gradle.wrapper.GradleWrapperMain \ org.gradle.wrapper.GradleWrapperMain \
"$@" "$@"
# Stop when "xargs" is not available.
if ! command -v xargs >/dev/null 2>&1; then
die "xargs is not available"
fi
# Use "xargs" to parse quoted args. # Use "xargs" to parse quoted args.
# #
# With -n1 it outputs one arg per line, with the quotes and backslashes removed. # With -n1 it outputs one arg per line, with the quotes and backslashes removed.

10
gradlew.bat vendored
View File

@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1 %JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto execute if %ERRORLEVEL% equ 0 goto execute
echo. echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
@ -75,13 +75,15 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
:end :end
@rem End local scope for the variables with windows NT shell @rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd if %ERRORLEVEL% equ 0 goto mainEnd
:fail :fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code! rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 set EXIT_CODE=%ERRORLEVEL%
exit /b 1 if %EXIT_CODE% equ 0 set EXIT_CODE=1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%
:mainEnd :mainEnd
if "%OS%"=="Windows_NT" endlocal if "%OS%"=="Windows_NT" endlocal

View File

@ -70,7 +70,7 @@ apply from: "chooseBackend.gradle"
rootProject.name = "Cavis" rootProject.name = "Cavis"
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS") enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
enableFeaturePreview("VERSION_CATALOGS") //enableFeaturePreview("VERSION_CATALOGS") //only needed for gradle <8
sourceControl { sourceControl {