diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java index 4243c46e2..696f9241b 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java @@ -21,14 +21,14 @@ package net.brutex.gan; -import java.awt.BorderLayout; -import java.awt.Dimension; -import java.awt.GridLayout; -import java.awt.Image; +import java.awt.*; import java.awt.image.BufferedImage; import java.io.File; +import java.io.IOException; import java.util.Arrays; import java.util.Random; +import java.util.UUID; +import javax.imageio.ImageIO; import javax.swing.ImageIcon; import javax.swing.JFrame; import javax.swing.JLabel; @@ -82,9 +82,9 @@ public class App { private static final int X_DIM = 20 ; private static final int Y_DIM = 20; - private static final int CHANNELS = 3; - private static final int batchSize = 50; - private static final int INPUT = 128; + private static final int CHANNELS = 1; + private static final int batchSize = 1; + private static final int INPUT = 10; private static final int OUTPUT_PER_PANEL = 16; @@ -96,6 +96,8 @@ public class App { private static JPanel panel; private static JPanel panel2; + private static final String OUTPUT_DIR = "C:/temp/output/"; + private static LayerConfiguration[] genLayers() { return new LayerConfiguration[] { DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(), @@ -103,6 +105,7 @@ public class App { DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), ActivationLayer.builder(new ActivationLReLU(0.2)).build(), + DropoutLayer.builder(1 - 0.5).build(), DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(), ActivationLayer.builder(new ActivationLReLU(0.2)).build(), @@ -207,6 +210,12 @@ public class App { @Test public void runTest() throws Exception { + if(! log.isDebugEnabled()) { + log.info("Logging is not set to DEBUG"); + } + else { + log.info("Logging is set to DEBUG"); + } main(); } @@ -240,25 +249,25 @@ public class App { MultiLayerNetwork gan = new MultiLayerNetwork(gan()); gen.init(); log.debug("Generator network: {}", gen); dis.init(); log.debug("Discriminator network: {}", dis); - gan.init(); log.debug("Complete GAN network: {}", gan); + gan.init(); log.info("Complete GAN network: {}", gan); copyParams(gen, dis, gan); - gen.addTrainingListeners(new PerformanceListener(15, true)); - //dis.addTrainingListeners(new PerformanceListener(10, true)); - //gan.addTrainingListeners(new PerformanceListener(10, true)); + //gen.addTrainingListeners(new PerformanceListener(15, true, "GEN")); + dis.addTrainingListeners(new PerformanceListener(10, true, "DIS")); + gan.addTrainingListeners(new PerformanceListener(10, true, "GAN")); //gan.addTrainingListeners(new ScoreToChartListener("gan")); //dis.setListeners(new ScoreToChartListener("dis")); - System.out.println(gan.toString()); - gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)); + //System.out.println(gan.toString()); + //gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)); //gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1))); //trainData.reset(); int j = 0; - for (int i = 0; i < 201; i++) { //epoch + for (int i = 0; i < 51; i++) { //epoch while (trainData.hasNext()) { j++; @@ -282,6 +291,9 @@ public class App { // int batchSize = (int) real.shape()[0]; INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM); + //INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise + + INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn); fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM); @@ -299,11 +311,11 @@ public class App { updateGan(gen, dis, gan); //gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1))); - gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1))); + gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.ones(batchSize, 1))); //Visualize and reporting if (j % 10 == 1) { - System.out.println("Iteration " + j + " Visualizing..."); + System.out.println("Epoch " + i + " Iteration " + j + " Visualizing..."); INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize]; @@ -330,11 +342,12 @@ public class App { } // Copy the GANs generator to gen. updateGen(gen, gan); + gen.save(new File("mnist-mlp-generator.dlj")); } - gen.save(new File("mnist-mlp-generator.dlj")); + } private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) { @@ -342,10 +355,10 @@ public class App { for (int i = 0; i < gan.getLayers().length; i++) { if (i < genLayerCount) { if(gan.getLayer(i).getParams() != null) - gen.getLayer(i).setParams(gan.getLayer(i).getParams()); + gan.getLayer(i).setParams(gen.getLayer(i).getParams()); } else { if(gan.getLayer(i).getParams() != null) - dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams()); + gan.getLayer(i ).setParams(dis.getLayer(i- genLayerCount).getParams()); } } } @@ -411,14 +424,50 @@ public class App { } } } - ImageIcon orig = new ImageIcon(bi); - Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT); - ImageIcon scaled = new ImageIcon(imageScaled); - + if(! isOrig) saveImage(imageScaled, batchElement, isOrig); return new JLabel(scaled); } + + private static void saveImage(Image image, int batchElement, boolean isOrig) { + String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved + + try { + // Save the images to disk + saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png"); + + log.debug("Images saved successfully."); + } catch (IOException e) { + log.error("Error saving the images: {}", e.getMessage()); + } +} + private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException { + File directory = new File(outputDirectory); + if (!directory.exists()) { + directory.mkdir(); + } + + File outputFile = new File(directory, fileName); + ImageIO.write(imageToBufferedImage(image), "png", outputFile); + } + + public static BufferedImage imageToBufferedImage(Image image) { + if (image instanceof BufferedImage) { + return (BufferedImage) image; + } + + // Create a buffered image with the same dimensions and transparency as the original image + BufferedImage bufferedImage = new BufferedImage(image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB); + + // Draw the original image onto the buffered image + Graphics2D g2d = bufferedImage.createGraphics(); + g2d.drawImage(image, 0, 0, null); + g2d.dispose(); + + return bufferedImage; + } + } \ No newline at end of file diff --git a/brutex-extended-tests/src/test/resources/simplelogger.properties b/brutex-extended-tests/src/test/resources/simplelogger.properties index 711590236..6baf46f1b 100644 --- a/brutex-extended-tests/src/test/resources/simplelogger.properties +++ b/brutex-extended-tests/src/test/resources/simplelogger.properties @@ -25,7 +25,7 @@ # Default logging detail level for all instances of SimpleLogger. # Must be one of ("trace", "debug", "info", "warn", or "error"). # If not specified, defaults to "info". -org.slf4j.simpleLogger.defaultLogLevel=trace +org.slf4j.simpleLogger.defaultLogLevel=debug # Logging detail level for a SimpleLogger instance named "xxxxx". # Must be one of ("trace", "debug", "info", "warn", or "error"). @@ -42,8 +42,8 @@ org.slf4j.simpleLogger.defaultLogLevel=trace # If the format is not specified or is invalid, the default format is used. # The default format is yyyy-MM-dd HH:mm:ss:SSS Z. #org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z -org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss +#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss # Set to true if you want to output the current thread name. # Defaults to true. -org.slf4j.simpleLogger.showThreadName=true \ No newline at end of file +#org.slf4j.simpleLogger.showThreadName=true \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java index 93a691c98..8fd8be54a 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java @@ -23,6 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.nn.conf.serde.CavisMapper; import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java index 4ac091ecd..6af935f41 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java @@ -55,6 +55,8 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali private boolean reportEtl = true; private boolean reportTime = true; + private final String name; + public PerformanceListener(int frequency) { @@ -66,14 +68,22 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali } public PerformanceListener(int frequency, boolean reportScore, boolean reportGC) { + this(frequency, reportScore, reportGC, "no-name"); + } + public PerformanceListener(int frequency, boolean reportScore, boolean reportGC, String name) { Preconditions.checkArgument(frequency > 0, "Invalid frequency, must be > 0: Got " + frequency); this.frequency = frequency; this.reportScore = reportScore; this.reportGC = reportGC; + this.name = name; lastTime.set(System.currentTimeMillis()); } + public PerformanceListener(int frequency, boolean reportScore, String name) { + this(frequency, reportScore, false, name); + } + @Override public void iterationDone(IModel model, int iteration, int epoch) { // we update lastTime on every iteration @@ -116,6 +126,7 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali StringBuilder builder = new StringBuilder(); + builder.append("Name: '"+name+"'"); if (Nd4j.getAffinityManager().getNumberOfDevices() > 1) builder.append("Device: [").append(Nd4j.getAffinityManager().getDeviceForCurrentThread()).append("]; "); diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 7454180f2..249e5832f 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index ae04661ee..84a0b92f9 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.5.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.2.1-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 1b6c78733..64c63a782 100644 --- a/gradlew +++ b/gradlew @@ -69,18 +69,18 @@ app_path=$0 # Need this for daisy-chained symlinks. while - APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path - [ -h "$app_path" ] + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] do - ls=$( ls -ld "$app_path" ) - link=${ls#*' -> '} - case $link in #( - /*) app_path=$link ;; #( - *) app_path=$APP_HOME$link ;; - esac + ls=$(ls -ld "$app_path") + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac done -APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit +APP_HOME=$(cd "${APP_HOME:-./}" && pwd -P) || exit APP_NAME="Gradle" APP_BASE_NAME=${0##*/} @@ -91,15 +91,15 @@ DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD=maximum -warn () { - echo "$*" +warn() { + echo "$*" } >&2 -die () { - echo - echo "$*" - echo - exit 1 +die() { + echo + echo "$*" + echo + exit 1 } >&2 # OS specific support (must be 'true' or 'false'). @@ -107,51 +107,52 @@ cygwin=false msys=false darwin=false nonstop=false -case "$( uname )" in #( - CYGWIN* ) cygwin=true ;; #( - Darwin* ) darwin=true ;; #( - MSYS* | MINGW* ) msys=true ;; #( - NONSTOP* ) nonstop=true ;; +case "$(uname)" in #( +CYGWIN*) cygwin=true ;; #( +Darwin*) darwin=true ;; #( +MSYS* | MINGW*) msys=true ;; #( +NONSTOP*) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar - # Determine the Java command to use to start the JVM. -if [ -n "$JAVA_HOME" ] ; then - if [ -x "$JAVA_HOME/jre/sh/java" ] ; then - # IBM's JDK on AIX uses strange locations for the executables - JAVACMD=$JAVA_HOME/jre/sh/java - else - JAVACMD=$JAVA_HOME/bin/java - fi - if [ ! -x "$JAVACMD" ] ; then - die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME +if [ -n "$JAVA_HOME" ]; then + if [ -x "$JAVA_HOME/jre/sh/java" ]; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ]; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME Please set the JAVA_HOME variable in your environment to match the location of your Java installation." - fi + fi else - JAVACMD=java - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + JAVACMD=java + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi # Increase the maximum file descriptors if we can. -if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then - case $MAX_FD in #( - max*) - MAX_FD=$( ulimit -H -n ) || - warn "Could not query maximum file descriptor limit" - esac - case $MAX_FD in #( - '' | soft) :;; #( - *) - ulimit -n "$MAX_FD" || - warn "Could not set maximum file descriptor limit to $MAX_FD" - esac +if ! "$cygwin" && ! "$darwin" && ! "$nonstop"; then + case $MAX_FD in #( + max*) + MAX_FD=$(ulimit -H -n) || + warn "Could not query maximum file descriptor limit" + ;; + esac + case $MAX_FD in #( + '' | soft) : ;; #( + *) + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + ;; + esac fi # Collect all arguments for the java command, stacking in reverse order: @@ -163,34 +164,36 @@ fi # * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. # For Cygwin or MSYS, switch paths to Windows format before running java -if "$cygwin" || "$msys" ; then - APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) - CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) +if "$cygwin" || "$msys"; then + APP_HOME=$(cygpath --path --mixed "$APP_HOME") + CLASSPATH=$(cygpath --path --mixed "$CLASSPATH") - JAVACMD=$( cygpath --unix "$JAVACMD" ) + JAVACMD=$(cygpath --unix "$JAVACMD") - # Now convert the arguments - kludge to limit ourselves to /bin/sh - for arg do - if - case $arg in #( - -*) false ;; # don't mess with options #( - /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath - [ -e "$t" ] ;; #( - *) false ;; - esac - then - arg=$( cygpath --path --ignore --mixed "$arg" ) - fi - # Roll the args list around exactly as many times as the number of - # args, so each arg winds up back in the position where it started, but - # possibly modified. - # - # NB: a `for` loop captures its iteration list before it begins, so - # changing the positional parameters here affects neither the number of - # iterations, nor the values presented in `arg`. - shift # remove old arg - set -- "$@" "$arg" # push replacement arg - done + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg; do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) + t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] + ;; #( + *) false ;; + esac + then + arg=$(cygpath --path --ignore --mixed "$arg") + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done fi # Collect all arguments for the java command; @@ -200,10 +203,15 @@ fi # * put everything else in single quotes, so that it's not re-expanded. set -- \ - "-Dorg.gradle.appname=$APP_BASE_NAME" \ - -classpath "$CLASSPATH" \ - org.gradle.wrapper.GradleWrapperMain \ - "$@" + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1; then + die "xargs is not available" +fi # Use "xargs" to parse quoted args. # @@ -225,10 +233,10 @@ set -- \ # eval "set -- $( - printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | - xargs -n1 | - sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | - tr '\n' ' ' - )" '"$@"' + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' +)" '"$@"' exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index 107acd32c..f127cfd49 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -14,7 +14,7 @@ @rem limitations under the License. @rem -@if "%DEBUG%" == "" @echo off +@if "%DEBUG%"=="" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @@ -25,7 +25,7 @@ if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. +if "%DIRNAME%"=="" set DIRNAME=. set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto execute +if %ERRORLEVEL% equ 0 goto execute echo. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. @@ -75,13 +75,15 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar :end @rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd +if %ERRORLEVEL% equ 0 goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% :mainEnd if "%OS%"=="Windows_NT" endlocal diff --git a/settings.gradle b/settings.gradle index d7875c751..edd780a86 100644 --- a/settings.gradle +++ b/settings.gradle @@ -70,7 +70,7 @@ apply from: "chooseBackend.gradle" rootProject.name = "Cavis" enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS") -enableFeaturePreview("VERSION_CATALOGS") +//enableFeaturePreview("VERSION_CATALOGS") //only needed for gradle <8 sourceControl {