diff --git a/brutex-extended-tests/src/test/java/net/brutex/ai/nd4j/tests/LoadBackendTests.java b/brutex-extended-tests/src/test/java/net/brutex/ai/nd4j/tests/LoadBackendTests.java index 4ce2844d5..4420a2b4a 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/ai/nd4j/tests/LoadBackendTests.java +++ b/brutex-extended-tests/src/test/java/net/brutex/ai/nd4j/tests/LoadBackendTests.java @@ -36,9 +36,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class LoadBackendTests { @Test - public void loadBackend() throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException { + public void loadBackend() throws NoSuchFieldException, IllegalAccessException { // check if Nd4j is there - //Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path")); + Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path")); final Field sysPathsField = ClassLoader.class.getDeclaredField("sys_paths"); sysPathsField.setAccessible(true); sysPathsField.set(null, null); diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App2.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App2.java index 73e271fe1..1ed50b048 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/App2.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App2.java @@ -37,6 +37,8 @@ import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.image.transform.*; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; +import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator; +import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -46,24 +48,27 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.listeners.PerformanceListener; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; +import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; +import static net.brutex.gan.App2Config.BATCHSIZE; + @Slf4j public class App2 { final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS; - static final float COLORSPACE = 255f; + static final int DIMENSIONS = 28; static final int CHANNELS = 1; final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS; - final int OUTPUT_PER_PANEL = 10; + final boolean BIAS = true; - static final int BATCHSIZE=128; + private JFrame frame2, frame; static final String OUTPUT_DIR = "d:/out/"; @@ -76,7 +81,7 @@ public class App2 { Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000); MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200); - FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS()); + FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans3"), NativeImageLoader.getALLOWED_FORMATS()); ImageTransform transform = new ColorConversionTransform(new Random(42), 7 ); ImageTransform transform2 = new ShowImageTransform("Tester", 30); ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS); @@ -129,12 +134,94 @@ public class App2 { log.info("Generator Summary:\n{}", gen.summary()); log.info("GAN Summary:\n{}", gan.summary()); - dis.addTrainingListeners(new PerformanceListener(10, true, "DIS")); - gen.addTrainingListeners(new PerformanceListener(10, true, "GEN")); - gan.addTrainingListeners(new PerformanceListener(10, true, "GAN")); + dis.addTrainingListeners(new PerformanceListener(3, true, "DIS")); + //gen.addTrainingListeners(new PerformanceListener(3, true, "GEN")); //is never trained separately from GAN + gan.addTrainingListeners(new PerformanceListener(3, true, "GAN")); +/* + Thread vt = + new Thread( + new Runnable() { + @Override + public void run() { + while (true) { + visualize(0, 0, gen); + try { + Thread.sleep(10000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + }); + vt.start(); +*/ - int j = 0; - for (int i = 0; i < 51; i++) { //epoch + App2Display display = new App2Display(); + //Repack training data with new fake/real label. Original MNist has 10 labels, one for each digit + DataSet data = null; + int j =0; + for(int i=0;i1) { - bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels - } else { - bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels - } - final int imageSize = DIMENSIONS * DIMENSIONS; - final int offset = batchElement * imageSize; - int pxl = offset * CHANNELS; //where to start in the INDArray - - //Image in NCHW - channels first format - for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel - for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x - for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y - float f_pxl = tensor.getFloat(pxl) * COLORSPACE; - if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl); - bi.getRaster().setSample(x, y, c, f_pxl); - pxl++; //next item in INDArray - } - } - } - ImageIcon orig = new ImageIcon(bi); - Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT); - ImageIcon scaled = new ImageIcon(imageScaled); - if(! isOrig) saveImage(imageScaled, batchElement, isOrig); - return new JLabel(scaled); + */ } - private static void saveImage(Image image, int batchElement, boolean isOrig) { - String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved - try { - // Save the images to disk - saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png"); - log.debug("Images saved successfully."); - } catch (IOException e) { - log.error("Error saving the images: {}", e.getMessage()); - } - } - private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException { - File directory = new File(outputDirectory); - if (!directory.exists()) { - directory.mkdir(); - } - File outputFile = new File(directory, fileName); - ImageIO.write(imageToBufferedImage(image), "png", outputFile); - } - public static BufferedImage imageToBufferedImage(Image image) { - if (image instanceof BufferedImage) { - return (BufferedImage) image; - } - // Create a buffered image with the same dimensions and transparency as the original image - BufferedImage bufferedImage; - if (CHANNELS > 1) { - bufferedImage = - new BufferedImage( - image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB); - } else { - bufferedImage = - new BufferedImage( - image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY); - } - // Draw the original image onto the buffered image - Graphics2D g2d = bufferedImage.createGraphics(); - g2d.drawImage(image, 0, 0, null); - g2d.dispose(); - return bufferedImage; - } + + + + + + + + private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) { for (int i = 0; i < gen.getLayers().length; i++) { gen.getLayer(i).setParams(gan.getLayer(i).getParams()); @@ -341,4 +331,41 @@ public class App2 { } } + + @Test + void testDiskriminator() throws IOException { + MultiLayerNetwork net = new MultiLayerNetwork(App2Config.discriminator()); + net.init(); + net.addTrainingListeners(new PerformanceListener(10, true, "DIS")); + DataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42); + + DataSet data = null; + for(int i=0;i=OUTPUT_PER_PANEL) { + real_panel.remove(0); + } + real_panel.add(new JLabel(icon)); + } else { + if(fake_panel.getComponents().length>=OUTPUT_PER_PANEL) { + fake_panel.remove(0); + } + fake_panel.add(new JLabel(icon)); + } + } + } + frame.pack(); + frame.repaint(); + } + + public Image getImage(INDArray tensor, int batchElement, boolean isOrig) { + final BufferedImage bi; + if(CHANNELS >1) { + bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels + } else { + bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels + } + final int imageSize = X_DIM * Y_DIM; + final int offset = batchElement * imageSize; + int pxl = offset * CHANNELS; //where to start in the INDArray + + //Image in NCHW - channels first format + for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel + for (int y = 0; y < X_DIM; y++) { // step through the columns x + for (int x = 0; x < Y_DIM; x++) { //step through the rows y + float f_pxl = tensor.getFloat(pxl) * COLORSPACE; + if(isOrig) log.trace("'{}.'{} Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, isOrig ? "Real" : "Fake", x, y, c, pxl, f_pxl); + bi.getRaster().setSample(x, y, c, f_pxl); + pxl++; //next item in INDArray + } + } + } + ImageIcon orig = new ImageIcon(bi); + Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT); + ImageIcon scaled = new ImageIcon(imageScaled); + //if(! isOrig) saveImage(imageScaled, batchElement, isOrig); + return imageScaled; + + } + + private static void saveImage(Image image, int batchElement, boolean isOrig) { + String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved + + try { + // Save the images to disk + saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png"); + + log.debug("Images saved successfully."); + } catch (IOException e) { + log.error("Error saving the images: {}", e.getMessage()); + } + } + private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException { + File directory = new File(outputDirectory); + if (!directory.exists()) { + directory.mkdir(); + } + + File outputFile = new File(directory, fileName); + ImageIO.write(imageToBufferedImage(image), "png", outputFile); + } + + public static BufferedImage imageToBufferedImage(Image image) { + if (image instanceof BufferedImage) { + return (BufferedImage) image; + } + + // Create a buffered image with the same dimensions and transparency as the original image + BufferedImage bufferedImage; + if (CHANNELS > 1) { + bufferedImage = + new BufferedImage( + image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB); + } else { + bufferedImage = + new BufferedImage( + image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY); + } + + // Draw the original image onto the buffered image + Graphics2D g2d = bufferedImage.createGraphics(); + g2d.drawImage(image, 0, 0, null); + g2d.dispose(); + + return bufferedImage; + } +} diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App2GUI.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App2GUI.java new file mode 100644 index 000000000..8478ab3e6 --- /dev/null +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App2GUI.java @@ -0,0 +1,61 @@ +package net.brutex.gan; + +import javax.swing.JPanel; +import javax.swing.JSplitPane; +import javax.swing.JLabel; +import java.awt.BorderLayout; + +public class App2GUI extends JPanel { + + /** + * + */ + private static final long serialVersionUID = 1L; + private JPanel overall_panel; + private JPanel real_panel; + private JPanel gen_panel; + + /** + * Create the panel. + */ + public App2GUI() { + + overall_panel = new JPanel(); + add(overall_panel); + + JSplitPane splitPane = new JSplitPane(); + overall_panel.add(splitPane); + + JPanel p1 = new JPanel(); + splitPane.setLeftComponent(p1); + p1.setLayout(new BorderLayout(0, 0)); + + JLabel lblNewLabel = new JLabel("Generator"); + p1.add(lblNewLabel, BorderLayout.NORTH); + + gen_panel = new JPanel(); + p1.add(gen_panel, BorderLayout.SOUTH); + + JPanel p2 = new JPanel(); + splitPane.setRightComponent(p2); + p2.setLayout(new BorderLayout(0, 0)); + + JLabel lblNewLabel_1 = new JLabel("Real"); + p2.add(lblNewLabel_1, BorderLayout.NORTH); + + real_panel = new JPanel(); + p2.add(real_panel, BorderLayout.SOUTH); + + } + + + public JPanel getOverall_panel() { + return overall_panel; + } + public JPanel getReal_panel() { + return real_panel; + } + public JPanel getGen_panel() { + return gen_panel; + } +} diff --git a/brutex-extended-tests/src/test/resources/simplelogger.properties b/brutex-extended-tests/src/test/java/net/brutex/gan/simplelogger.properties similarity index 100% rename from brutex-extended-tests/src/test/resources/simplelogger.properties rename to brutex-extended-tests/src/test/java/net/brutex/gan/simplelogger.properties