parent
deb436036b
commit
ca127d8b88
|
@ -21,14 +21,14 @@
|
|||
|
||||
package net.brutex.gan;
|
||||
|
||||
import java.awt.BorderLayout;
|
||||
import java.awt.Dimension;
|
||||
import java.awt.GridLayout;
|
||||
import java.awt.Image;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
import java.util.UUID;
|
||||
import javax.imageio.ImageIO;
|
||||
import javax.swing.ImageIcon;
|
||||
import javax.swing.JFrame;
|
||||
import javax.swing.JLabel;
|
||||
|
@ -82,9 +82,9 @@ public class App {
|
|||
|
||||
private static final int X_DIM = 20 ;
|
||||
private static final int Y_DIM = 20;
|
||||
private static final int CHANNELS = 3;
|
||||
private static final int batchSize = 50;
|
||||
private static final int INPUT = 128;
|
||||
private static final int CHANNELS = 1;
|
||||
private static final int batchSize = 1;
|
||||
private static final int INPUT = 10;
|
||||
|
||||
private static final int OUTPUT_PER_PANEL = 16;
|
||||
|
||||
|
@ -96,6 +96,8 @@ public class App {
|
|||
private static JPanel panel;
|
||||
private static JPanel panel2;
|
||||
|
||||
private static final String OUTPUT_DIR = "C:/temp/output/";
|
||||
|
||||
private static LayerConfiguration[] genLayers() {
|
||||
return new LayerConfiguration[] {
|
||||
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
|
||||
|
@ -103,6 +105,7 @@ public class App {
|
|||
|
||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DropoutLayer.builder(1 - 0.5).build(),
|
||||
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
|
||||
|
@ -207,6 +210,12 @@ public class App {
|
|||
|
||||
@Test
|
||||
public void runTest() throws Exception {
|
||||
if(! log.isDebugEnabled()) {
|
||||
log.info("Logging is not set to DEBUG");
|
||||
}
|
||||
else {
|
||||
log.info("Logging is set to DEBUG");
|
||||
}
|
||||
main();
|
||||
}
|
||||
|
||||
|
@ -240,25 +249,25 @@ public class App {
|
|||
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
|
||||
gen.init(); log.debug("Generator network: {}", gen);
|
||||
dis.init(); log.debug("Discriminator network: {}", dis);
|
||||
gan.init(); log.debug("Complete GAN network: {}", gan);
|
||||
gan.init(); log.info("Complete GAN network: {}", gan);
|
||||
|
||||
|
||||
copyParams(gen, dis, gan);
|
||||
|
||||
gen.addTrainingListeners(new PerformanceListener(15, true));
|
||||
//dis.addTrainingListeners(new PerformanceListener(10, true));
|
||||
//gan.addTrainingListeners(new PerformanceListener(10, true));
|
||||
//gen.addTrainingListeners(new PerformanceListener(15, true, "GEN"));
|
||||
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
|
||||
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
|
||||
//gan.addTrainingListeners(new ScoreToChartListener("gan"));
|
||||
//dis.setListeners(new ScoreToChartListener("dis"));
|
||||
|
||||
System.out.println(gan.toString());
|
||||
gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
|
||||
//System.out.println(gan.toString());
|
||||
//gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
|
||||
|
||||
//gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
|
||||
//trainData.reset();
|
||||
|
||||
int j = 0;
|
||||
for (int i = 0; i < 201; i++) { //epoch
|
||||
for (int i = 0; i < 51; i++) { //epoch
|
||||
while (trainData.hasNext()) {
|
||||
j++;
|
||||
|
||||
|
@ -282,6 +291,9 @@ public class App {
|
|||
// int batchSize = (int) real.shape()[0];
|
||||
|
||||
INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
|
||||
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
|
||||
|
||||
|
||||
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
||||
fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
|
||||
|
||||
|
@ -299,11 +311,11 @@ public class App {
|
|||
updateGan(gen, dis, gan);
|
||||
|
||||
//gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
|
||||
gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)));
|
||||
gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1)));
|
||||
|
||||
//Visualize and reporting
|
||||
if (j % 10 == 1) {
|
||||
System.out.println("Iteration " + j + " Visualizing...");
|
||||
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
|
||||
INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
|
||||
|
||||
|
||||
|
@ -330,11 +342,12 @@ public class App {
|
|||
}
|
||||
// Copy the GANs generator to gen.
|
||||
updateGen(gen, gan);
|
||||
gen.save(new File("mnist-mlp-generator.dlj"));
|
||||
}
|
||||
|
||||
|
||||
|
||||
gen.save(new File("mnist-mlp-generator.dlj"));
|
||||
|
||||
}
|
||||
|
||||
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||
|
@ -342,10 +355,10 @@ public class App {
|
|||
for (int i = 0; i < gan.getLayers().length; i++) {
|
||||
if (i < genLayerCount) {
|
||||
if(gan.getLayer(i).getParams() != null)
|
||||
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||
gan.getLayer(i).setParams(gen.getLayer(i).getParams());
|
||||
} else {
|
||||
if(gan.getLayer(i).getParams() != null)
|
||||
dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
|
||||
gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -411,14 +424,50 @@ public class App {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
ImageIcon orig = new ImageIcon(bi);
|
||||
|
||||
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
|
||||
|
||||
ImageIcon scaled = new ImageIcon(imageScaled);
|
||||
|
||||
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
|
||||
return new JLabel(scaled);
|
||||
|
||||
}
|
||||
|
||||
private static void saveImage(Image image, int batchElement, boolean isOrig) {
|
||||
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
|
||||
|
||||
try {
|
||||
// Save the images to disk
|
||||
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
|
||||
|
||||
log.debug("Images saved successfully.");
|
||||
} catch (IOException e) {
|
||||
log.error("Error saving the images: {}", e.getMessage());
|
||||
}
|
||||
}
|
||||
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
|
||||
File directory = new File(outputDirectory);
|
||||
if (!directory.exists()) {
|
||||
directory.mkdir();
|
||||
}
|
||||
|
||||
File outputFile = new File(directory, fileName);
|
||||
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
|
||||
}
|
||||
|
||||
public static BufferedImage imageToBufferedImage(Image image) {
|
||||
if (image instanceof BufferedImage) {
|
||||
return (BufferedImage) image;
|
||||
}
|
||||
|
||||
// Create a buffered image with the same dimensions and transparency as the original image
|
||||
BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
|
||||
|
||||
// Draw the original image onto the buffered image
|
||||
Graphics2D g2d = bufferedImage.createGraphics();
|
||||
g2d.drawImage(image, 0, 0, null);
|
||||
g2d.dispose();
|
||||
|
||||
return bufferedImage;
|
||||
}
|
||||
|
||||
}
|
|
@ -25,7 +25,7 @@
|
|||
# Default logging detail level for all instances of SimpleLogger.
|
||||
# Must be one of ("trace", "debug", "info", "warn", or "error").
|
||||
# If not specified, defaults to "info".
|
||||
org.slf4j.simpleLogger.defaultLogLevel=trace
|
||||
org.slf4j.simpleLogger.defaultLogLevel=debug
|
||||
|
||||
# Logging detail level for a SimpleLogger instance named "xxxxx".
|
||||
# Must be one of ("trace", "debug", "info", "warn", or "error").
|
||||
|
@ -42,8 +42,8 @@ org.slf4j.simpleLogger.defaultLogLevel=trace
|
|||
# If the format is not specified or is invalid, the default format is used.
|
||||
# The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
|
||||
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
|
||||
org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
|
||||
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
|
||||
|
||||
# Set to true if you want to output the current thread name.
|
||||
# Defaults to true.
|
||||
org.slf4j.simpleLogger.showThreadName=true
|
||||
#org.slf4j.simpleLogger.showThreadName=true
|
|
@ -23,6 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations;
|
|||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||
|
|
|
@ -55,6 +55,8 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali
|
|||
private boolean reportEtl = true;
|
||||
private boolean reportTime = true;
|
||||
|
||||
private final String name;
|
||||
|
||||
|
||||
|
||||
public PerformanceListener(int frequency) {
|
||||
|
@ -66,14 +68,22 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali
|
|||
}
|
||||
|
||||
public PerformanceListener(int frequency, boolean reportScore, boolean reportGC) {
|
||||
this(frequency, reportScore, reportGC, "no-name");
|
||||
}
|
||||
public PerformanceListener(int frequency, boolean reportScore, boolean reportGC, String name) {
|
||||
Preconditions.checkArgument(frequency > 0, "Invalid frequency, must be > 0: Got " + frequency);
|
||||
this.frequency = frequency;
|
||||
this.reportScore = reportScore;
|
||||
this.reportGC = reportGC;
|
||||
this.name = name;
|
||||
|
||||
lastTime.set(System.currentTimeMillis());
|
||||
}
|
||||
|
||||
public PerformanceListener(int frequency, boolean reportScore, String name) {
|
||||
this(frequency, reportScore, false, name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void iterationDone(IModel model, int iteration, int epoch) {
|
||||
// we update lastTime on every iteration
|
||||
|
@ -116,6 +126,7 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali
|
|||
|
||||
|
||||
StringBuilder builder = new StringBuilder();
|
||||
builder.append("Name: '"+name+"'");
|
||||
|
||||
if (Nd4j.getAffinityManager().getNumberOfDevices() > 1)
|
||||
builder.append("Device: [").append(Nd4j.getAffinityManager().getDeviceForCurrentThread()).append("]; ");
|
||||
|
|
Binary file not shown.
|
@ -1,5 +1,5 @@
|
|||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.5.1-bin.zip
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-8.2.1-bin.zip
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
|
|
|
@ -116,7 +116,6 @@ esac
|
|||
|
||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||
|
||||
|
||||
# Determine the Java command to use to start the JVM.
|
||||
if [ -n "$JAVA_HOME" ]; then
|
||||
if [ -x "$JAVA_HOME/jre/sh/java" ]; then
|
||||
|
@ -145,12 +144,14 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
|
|||
max*)
|
||||
MAX_FD=$(ulimit -H -n) ||
|
||||
warn "Could not query maximum file descriptor limit"
|
||||
;;
|
||||
esac
|
||||
case $MAX_FD in #(
|
||||
'' | soft) : ;; #(
|
||||
*)
|
||||
ulimit -n "$MAX_FD" ||
|
||||
warn "Could not set maximum file descriptor limit to $MAX_FD"
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
|
@ -170,12 +171,14 @@ if "$cygwin" || "$msys" ; then
|
|||
JAVACMD=$(cygpath --unix "$JAVACMD")
|
||||
|
||||
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||
for arg do
|
||||
for arg; do
|
||||
if
|
||||
case $arg in #(
|
||||
-*) false ;; # don't mess with options #(
|
||||
/?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
|
||||
[ -e "$t" ] ;; #(
|
||||
/?*)
|
||||
t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
|
||||
[ -e "$t" ]
|
||||
;; #(
|
||||
*) false ;;
|
||||
esac
|
||||
then
|
||||
|
@ -205,6 +208,11 @@ set -- \
|
|||
org.gradle.wrapper.GradleWrapperMain \
|
||||
"$@"
|
||||
|
||||
# Stop when "xargs" is not available.
|
||||
if ! command -v xargs >/dev/null 2>&1; then
|
||||
die "xargs is not available"
|
||||
fi
|
||||
|
||||
# Use "xargs" to parse quoted args.
|
||||
#
|
||||
# With -n1 it outputs one arg per line, with the quotes and backslashes removed.
|
||||
|
|
|
@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
|
|||
|
||||
set JAVA_EXE=java.exe
|
||||
%JAVA_EXE% -version >NUL 2>&1
|
||||
if "%ERRORLEVEL%" == "0" goto execute
|
||||
if %ERRORLEVEL% equ 0 goto execute
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
|
@ -75,13 +75,15 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
|||
|
||||
:end
|
||||
@rem End local scope for the variables with windows NT shell
|
||||
if "%ERRORLEVEL%"=="0" goto mainEnd
|
||||
if %ERRORLEVEL% equ 0 goto mainEnd
|
||||
|
||||
:fail
|
||||
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
||||
rem the _cmd.exe /c_ return code!
|
||||
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
|
||||
exit /b 1
|
||||
set EXIT_CODE=%ERRORLEVEL%
|
||||
if %EXIT_CODE% equ 0 set EXIT_CODE=1
|
||||
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
|
||||
exit /b %EXIT_CODE%
|
||||
|
||||
:mainEnd
|
||||
if "%OS%"=="Windows_NT" endlocal
|
||||
|
|
|
@ -70,7 +70,7 @@ apply from: "chooseBackend.gradle"
|
|||
rootProject.name = "Cavis"
|
||||
|
||||
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
|
||||
enableFeaturePreview("VERSION_CATALOGS")
|
||||
//enableFeaturePreview("VERSION_CATALOGS") //only needed for gradle <8
|
||||
|
||||
|
||||
sourceControl {
|
||||
|
|
Loading…
Reference in New Issue