Compare commits

..

20 Commits

Author SHA1 Message Date
Brian Rosenberger dd151aec3f gan example
Signed-off-by: brian <brian@brutex.de>
2023-08-07 10:32:39 +02:00
Brian Rosenberger 1b3338f809 add test stage to linux cuda on docker build
Signed-off-by: brian <brian@brutex.de>
2023-07-27 22:15:16 +02:00
Brian Rosenberger 3d949c5348 Update lombok
Signed-off-by: brian <brian@brutex.de>
2023-07-27 22:02:30 +02:00
Brian Rosenberger 6930116c18 Downgrade gradle wrapper to 7.4.2 and upgrade javacpp-gradle plugin to 1.5.9
Signed-off-by: brian <brian@brutex.de>
2023-07-27 10:05:01 +02:00
Brian Rosenberger e27fb8422f Fixed missing imports
Signed-off-by: brian <brian@brutex.de>
2023-07-27 09:03:58 +02:00
Brian Rosenberger d0342fc939 Change jenkins pipeline credentials id for MAVEN
Signed-off-by: brian <brian@brutex.de>
2023-07-26 11:51:13 +02:00
Brian Rosenberger b34b96d929 Change jenkins pipeline credentials id for MAVEN
Signed-off-by: brian <brian@brutex.de>
2023-07-26 11:49:35 +02:00
Brian Rosenberger 8f51471a31 Change jenkins pipeline credentials id for MAVEN
Signed-off-by: brian <brian@brutex.de>
2023-07-26 11:14:04 +02:00
Brian Rosenberger dc5de40620 Fix build docker image to use CUDA 11.4.3 (was 11.4.0)
Signed-off-by: brian <brian@brutex.de>
2023-07-26 11:01:50 +02:00
Brian Rosenberger e834407b6e Fix build docker image to use CUDA 11.4.3 (was 11.4.0)
Signed-off-by: brian <brian@brutex.de>
2023-07-26 10:52:14 +02:00
Brian Rosenberger 4dc5a116b6 Fixing tests
Signed-off-by: brian <brian@brutex.de>
2023-07-25 10:59:46 +02:00
Brian Rosenberger 997143b9dd Fixing tests
Signed-off-by: brian <brian@brutex.de>
2023-05-17 09:12:47 +02:00
Brian Rosenberger 0bed17c97f Fixing tests
Signed-off-by: brian <brian@brutex.de>
2023-05-15 14:24:01 +02:00
Brian Rosenberger 8d73a7a410 Fixing tests
Signed-off-by: brian <brian@brutex.de>
2023-05-15 10:37:48 +02:00
Brian Rosenberger c758cf918f Fixing tests
Signed-off-by: brian <brian@brutex.de>
2023-05-08 19:12:46 +02:00
Brian Rosenberger 2c8c6d9624 Fixing tests
Signed-off-by: brian <brian@brutex.de>
2023-05-08 16:46:18 +02:00
Brian Rosenberger 0ba049885f Fixing tests
Signed-off-by: brian <brian@brutex.de>
2023-05-08 12:59:22 +02:00
Brian Rosenberger 345f55a003 Fixing tests
Signed-off-by: brian <brian@brutex.de>
2023-05-08 12:48:42 +02:00
Brian Rosenberger 1c39dbee52 Fixing tests
Signed-off-by: brian <brian@brutex.de>
2023-05-08 12:45:48 +02:00
Brian Rosenberger ea504bff41 Fixing tests
Signed-off-by: brian <brian@brutex.de>
2023-05-08 09:34:44 +02:00
23 changed files with 135 additions and 425 deletions

View File

@@ -1,4 +1,4 @@
-FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04
+FROM nvidia/cuda:11.4.3-cudnn8-devel-ubuntu20.04
 RUN apt-get update && \
 DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
@@ -11,10 +11,5 @@ RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.
 rm cmake-3.24.2-linux-x86_64.sh
-RUN echo "/usr/local/cuda/compat/" >> /etc/ld.so.conf.d/cuda-driver.conf
 RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf
-RUN ldconfig -p | grep cuda

View File

@@ -65,7 +65,7 @@ pipeline {
 }*/
 stage('publish-linux-cpu') {
 environment {
-MAVEN = credentials('Internal_Archiva')
+MAVEN = credentials('Internal Archiva')
 OSSRH = credentials('OSSRH')
 }
@@ -79,9 +79,4 @@ pipeline {
 }
 }
 }
-post {
-always {
-junit '**/build/test-results/**/*.xml'
-}
-}
 }

View File

@@ -59,10 +59,4 @@ pipeline {
 }
 }
 }
-post {
-always {
-junit '**/build/test-results/**/*.xml'
-}
-}
 }

View File

@@ -85,9 +85,4 @@ pipeline {
 }
 }
 }
-post {
-always {
-junit '**/build/test-results/**/*.xml'
-}
-}
 }

View File

@@ -26,7 +26,7 @@ pipeline {
 dir '.docker'
 label 'linux && docker && cuda'
 //additionalBuildArgs '--build-arg version=1.0.2'
-args '--gpus all' //needed for test only, you can build without GPU
+//args '--gpus all' --needed for test only, you can build without GPU
 }
 }
@@ -72,10 +72,4 @@ pipeline {
 }
 }
 }
-post {
-always {
-junit '**/build/test-results/**/*.xml'
-}
-}
 }
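
Note: the Jenkinsfile hunks above all touch the same two spots of a declarative pipeline: the Docker agent's GPU flag and the junit result collection in the post block. For orientation, a minimal pipeline skeleton of that shape is sketched below (a hedged sketch: dir, label, args and the junit pattern come from the diff, the build step itself is illustrative):

pipeline {
    agent {
        dockerfile {
            dir '.docker'
            label 'linux && docker && cuda'
            // '--gpus all' is only required to run the CUDA tests;
            // the image itself can be built without a GPU.
            args '--gpus all'
        }
    }
    stages {
        stage('build-and-test') {
            steps {
                sh './gradlew build test'   // illustrative build step
            }
        }
    }
    post {
        always {
            // publish the JUnit XML reports produced by the Gradle test task
            junit '**/build/test-results/**/*.xml'
        }
    }
}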

View File

@@ -36,9 +36,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
 public class LoadBackendTests {
 @Test
-public void loadBackend() throws NoSuchFieldException, IllegalAccessException {
+public void loadBackend() throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
 // check if Nd4j is there
-Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path"));
+//Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path"));
 final Field sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
 sysPathsField.setAccessible(true);
 sysPathsField.set(null, null);
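
The test above appears to rely on a reflection trick: the JVM caches the parsed java.library.path in the static ClassLoader field sys_paths, so nulling that field forces a re-read of the property before the ND4J backend loads its native libraries. A standalone sketch of the same pattern (hedged: it assumes JDK 11, where sys_paths is still reachable via reflection; the class and method names are illustrative):

import java.lang.reflect.Field;

public class LibraryPathReset {

    // Force the JVM to re-parse java.library.path on the next System.loadLibrary call.
    static void resetLibraryPath(String newPath) throws NoSuchFieldException, IllegalAccessException {
        System.setProperty("java.library.path", newPath);
        // ClassLoader caches the parsed search path in its static sys_paths field;
        // setting it to null makes the next native-library load rebuild the cache.
        Field sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
        sysPathsField.setAccessible(true);
        sysPathsField.set(null, null);
    }
}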

View File

@@ -16,7 +16,6 @@ import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.optimize.listeners.PerformanceListener;
-import org.junit.jupiter.api.Tag;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.impl.ActivationLReLU;
@@ -123,7 +122,7 @@ public class App {
 return conf;
 }
-@Test @Tag("long-running")
+@Test
 public void runTest() throws Exception {
 App.main(null);
 }

View File

@@ -37,8 +37,6 @@ import org.datavec.image.loader.NativeImageLoader;
 import org.datavec.image.recordreader.ImageRecordReader;
 import org.datavec.image.transform.*;
 import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
-import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator;
-import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -46,29 +44,25 @@ import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.optimize.listeners.PerformanceListener;
-import org.junit.jupiter.api.Tag;
 import org.junit.jupiter.api.Test;
-import org.nd4j.evaluation.classification.Evaluation;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
-import static net.brutex.gan.App2Config.BATCHSIZE;
 @Slf4j
 public class App2 {
 final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
+static final float COLORSPACE = 255f;
 static final int DIMENSIONS = 28;
 static final int CHANNELS = 1;
 final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
+final int OUTPUT_PER_PANEL = 10;
 final boolean BIAS = true;
+static final int BATCHSIZE=128;
 private JFrame frame2, frame;
 static final String OUTPUT_DIR = "d:/out/";
@@ -76,12 +70,12 @@ public class App2 {
 final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
 final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);
-@Test @Tag("long-running")
+@Test
 void runTest() throws IOException {
 Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
 MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
-FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans3"), NativeImageLoader.getALLOWED_FORMATS());
+FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
 ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
 ImageTransform transform2 = new ShowImageTransform("Tester", 30);
 ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
@@ -134,94 +128,12 @@ public class App2 {
 log.info("Generator Summary:\n{}", gen.summary());
 log.info("GAN Summary:\n{}", gan.summary());
-dis.addTrainingListeners(new PerformanceListener(3, true, "DIS"));
+dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
-//gen.addTrainingListeners(new PerformanceListener(3, true, "GEN")); //is never trained separately from GAN
+gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
-gan.addTrainingListeners(new PerformanceListener(3, true, "GAN"));
+gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
-/*
-Thread vt =
-new Thread(
-new Runnable() {
-@Override
-public void run() {
-while (true) {
-visualize(0, 0, gen);
-try {
-Thread.sleep(10000);
-} catch (InterruptedException e) {
-throw new RuntimeException(e);
-}
-}
-}
-});
-vt.start();
-*/
-App2Display display = new App2Display();
-//Repack training data with new fake/real label. Original MNist has 10 labels, one for each digit
-DataSet data = null;
 int j = 0;
-for(int i=0;i<App2Config.EPOCHS;i++) {
+for (int i = 0; i < 51; i++) { //epoch
-log.info("Epoch {}", i);
-data = new DataSet(Nd4j.rand(BATCHSIZE, 784), label_fake);
-while (trainData.hasNext()) {
-j++;
-INDArray real = trainData.next().getFeatures();
-INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
-INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1,
-Nd4j.rand(BATCHSIZE, App2Config.INPUT));
-//sigmoid output is -1 to 1
-fake.addi(1f).divi(2f);
-if (j % 50 == 1) {
-display.visualize(new INDArray[] {fake}, App2Config.OUTPUT_PER_PANEL, false);
-display.visualize(new INDArray[] {real}, App2Config.OUTPUT_PER_PANEL, true);
-}
-DataSet realSet = new DataSet(real, label_real);
-DataSet fakeSet = new DataSet(fake, label_fake);
-//start next round if there are not enough images left to have a full batchsize dataset
-if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
-log.warn("Your total number of input images is not a multiple of {}, "
-+ "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
-break;
-}
-//if(real.length()/BATCHSIZE!=784) break;
-data = DataSet.merge(Arrays.asList(data, realSet, fakeSet));
-}
-//fit the discriminator
-dis.fit(data);
-dis.fit(data);
-// Update the discriminator in the GAN network
-updateGan(gen, dis, gan);
-//reset the training data and fit the complete GAN
-if (trainData.resetSupported()) {
-trainData.reset();
-} else {
-log.error("Trainingdata {} does not support reset.", trainData.toString());
-}
-gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_real));
-if (trainData.resetSupported()) {
-trainData.reset();
-} else {
-log.error("Trainingdata {} does not support reset.", trainData.toString());
-}
-log.info("Updated GAN's generator from gen.");
-updateGen(gen, gan);
-gen.save(new File("mnist-mlp-generator.dlj"));
-}
-//vt.stop();
-/*
-int j;
-for (int i = 0; i < App2Config.EPOCHS; i++) { //epoch
-j=0;
 while (trainData.hasNext()) {
 j++;
 DataSet next = trainData.next();
@@ -299,8 +211,6 @@ public class App2 {
 log.info("Updated GAN's generator from gen.");
 gen.save(new File("mnist-mlp-generator.dlj"));
 }
-*/
 }
@@ -313,11 +223,110 @@ public class App2 {
+private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
+if (isOrig) {
+frame.setTitle("Viz Original");
+} else {
+frame.setTitle("Generated");
+}
+frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
+frame.setLayout(new BorderLayout());
+JPanel panelx = new JPanel();
+panelx.setLayout(new GridLayout(4, 4, 8, 8));
+for (INDArray sample : samples) {
+for(int i = 0; i<batchElements; i++) {
+panelx.add(getImage(sample, i, isOrig));
+}
+}
+frame.add(panelx, BorderLayout.CENTER);
+frame.setVisible(true);
+frame.revalidate();
+frame.setMinimumSize(new Dimension(300, 20));
+frame.pack();
+return frame;
+}
+private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
+final BufferedImage bi;
+if(CHANNELS >1) {
+bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
+} else {
+bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
+}
+final int imageSize = DIMENSIONS * DIMENSIONS;
+final int offset = batchElement * imageSize;
+int pxl = offset * CHANNELS; //where to start in the INDArray
+//Image in NCHW - channels first format
+for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
+for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x
+for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y
+float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
+if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl);
+bi.getRaster().setSample(x, y, c, f_pxl);
+pxl++; //next item in INDArray
+}
+}
+}
+ImageIcon orig = new ImageIcon(bi);
+Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT);
+ImageIcon scaled = new ImageIcon(imageScaled);
+if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
+return new JLabel(scaled);
+}
+private static void saveImage(Image image, int batchElement, boolean isOrig) {
+String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
+try {
+// Save the images to disk
+saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
+log.debug("Images saved successfully.");
+} catch (IOException e) {
+log.error("Error saving the images: {}", e.getMessage());
+}
+}
+private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
+File directory = new File(outputDirectory);
+if (!directory.exists()) {
+directory.mkdir();
+}
+File outputFile = new File(directory, fileName);
+ImageIO.write(imageToBufferedImage(image), "png", outputFile);
+}
+public static BufferedImage imageToBufferedImage(Image image) {
+if (image instanceof BufferedImage) {
+return (BufferedImage) image;
+}
+// Create a buffered image with the same dimensions and transparency as the original image
+BufferedImage bufferedImage;
+if (CHANNELS > 1) {
+bufferedImage =
+new BufferedImage(
+image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
+} else {
+bufferedImage =
+new BufferedImage(
+image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
+}
+// Draw the original image onto the buffered image
+Graphics2D g2d = bufferedImage.createGraphics();
+g2d.drawImage(image, 0, 0, null);
+g2d.dispose();
+return bufferedImage;
+}
 private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
 for (int i = 0; i < gen.getLayers().length; i++) {
 gen.getLayer(i).setParams(gan.getLayer(i).getParams());
@@ -331,41 +340,4 @@ public class App2 {
 }
 }
-@Test
-void testDiskriminator() throws IOException {
-MultiLayerNetwork net = new MultiLayerNetwork(App2Config.discriminator());
-net.init();
-net.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
-DataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
-DataSet data = null;
-for(int i=0;i<App2Config.EPOCHS;i++) {
-log.info("Epoch {}", i);
-data = new DataSet(Nd4j.rand(BATCHSIZE, 784), label_fake);
-while (trainData.hasNext()) {
-INDArray real = trainData.next().getFeatures();
-long[] l = new long[]{BATCHSIZE, real.length() / BATCHSIZE};
-INDArray fake = Nd4j.rand(l );
-DataSet realSet = new DataSet(real, label_real);
-DataSet fakeSet = new DataSet(fake, label_fake);
-if(real.length()/BATCHSIZE!=784) break;
-data = DataSet.merge(Arrays.asList(data, realSet, fakeSet));
-}
-net.fit(data);
-trainData.reset();
-}
-long[] l = new long[]{BATCHSIZE, 784};
-INDArray fake = Nd4j.rand(l );
-DataSet fakeSet = new DataSet(fake, label_fake);
-data = DataSet.merge(Arrays.asList(data, fakeSet));
-ExistingDataSetIterator iter = new ExistingDataSetIterator(data);
-Evaluation eval = net.evaluate(iter);
-log.info( "\n" + eval.confusionMatrix());
-}
 }
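
Stripped of the Swing and visualization code, both versions of runTest in the hunks above implement the same alternating GAN update. A condensed sketch assembled only from calls visible in the diff (not runnable on its own; dis, gen, gan, trainData, label_real, label_fake, updateGan and updateGen are the names used in the diff):

// One training pass, condensed from the diff: alternate discriminator and GAN updates.
while (trainData.hasNext()) {
    INDArray real = trainData.next().getFeatures();
    // run only the generator layers of the stacked network to produce fake samples
    INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1,
            Nd4j.rand(BATCHSIZE, App2Config.INPUT));
    // fit the discriminator on labelled real and fake batches
    dis.fit(DataSet.merge(Arrays.asList(new DataSet(real, label_real),
                                        new DataSet(fake, label_fake))));
    // copy the updated discriminator weights into the stacked GAN ...
    updateGan(gen, dis, gan);
    // ... then fit the GAN on noise with "real" labels so only the generator is pushed
    gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_real));
    // and mirror the trained generator layers back into the standalone generator
    updateGen(gen, gan);
}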

View File

@@ -36,17 +36,10 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
 public class App2Config {
 public static final int INPUT = 100;
-public static final int BATCHSIZE=150;
 public static final int X_DIM = 28;
-public static final int Y_DIM = 28;
+public static final int y_DIM = 28;
 public static final int CHANNELS = 1;
-public static final int EPOCHS = 50;
 public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
-public static final IUpdater UPDATER_DIS = Adam.builder().learningRate(0.02).beta1(0.5).build();
-public static final boolean SHOW_GENERATED = true;
-public static final float COLORSPACE = 255f;
-final static int OUTPUT_PER_PANEL = 10;
 static LayerConfiguration[] genLayerConfig() {
 return new LayerConfiguration[] {
@@ -165,7 +158,7 @@ public class App2Config {
 .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
 .gradientNormalizationThreshold(100)
 .seed(42)
-.updater(UPDATER_DIS)
+.updater(UPDATER)
 .weightInit(WeightInit.XAVIER)
 // .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
 .weightNoise(null)

View File

@@ -1,160 +0,0 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import com.google.inject.Singleton;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.ndarray.INDArray;
import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.color.ColorSpace;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.UUID;
import static net.brutex.gan.App2.OUTPUT_DIR;
import static net.brutex.gan.App2Config.*;
@Slf4j
@Singleton
public class App2Display {
private final JFrame frame = new JFrame();
private final App2GUI display = new App2GUI();
private final JPanel real_panel;
private final JPanel fake_panel;
public App2Display() {
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setContentPane(display.getOverall_panel());
frame.setMinimumSize(new Dimension(300, 20));
frame.pack();
frame.setVisible(true);
real_panel = display.getReal_panel();
fake_panel = display.getGen_panel();
real_panel.setLayout(new GridLayout(4, 4, 8, 8));
fake_panel.setLayout(new GridLayout(4, 4, 8, 8));
}
public void visualize(INDArray[] samples, int batchElements, boolean isOrig) {
for (INDArray sample : samples) {
for(int i = 0; i<batchElements; i++) {
final Image img = this.getImage(sample, i, isOrig);
final ImageIcon icon = new ImageIcon(img);
if(isOrig) {
if(real_panel.getComponents().length>=OUTPUT_PER_PANEL) {
real_panel.remove(0);
}
real_panel.add(new JLabel(icon));
} else {
if(fake_panel.getComponents().length>=OUTPUT_PER_PANEL) {
fake_panel.remove(0);
}
fake_panel.add(new JLabel(icon));
}
}
}
frame.pack();
frame.repaint();
}
public Image getImage(INDArray tensor, int batchElement, boolean isOrig) {
final BufferedImage bi;
if(CHANNELS >1) {
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
} else {
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
}
final int imageSize = X_DIM * Y_DIM;
final int offset = batchElement * imageSize;
int pxl = offset * CHANNELS; //where to start in the INDArray
//Image in NCHW - channels first format
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
for (int y = 0; y < X_DIM; y++) { // step through the columns x
for (int x = 0; x < Y_DIM; x++) { //step through the rows y
float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
if(isOrig) log.trace("'{}.'{} Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, isOrig ? "Real" : "Fake", x, y, c, pxl, f_pxl);
bi.getRaster().setSample(x, y, c, f_pxl);
pxl++; //next item in INDArray
}
}
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
ImageIcon scaled = new ImageIcon(imageScaled);
//if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
return imageScaled;
}
private static void saveImage(Image image, int batchElement, boolean isOrig) {
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
try {
// Save the images to disk
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
log.debug("Images saved successfully.");
} catch (IOException e) {
log.error("Error saving the images: {}", e.getMessage());
}
}
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
File directory = new File(outputDirectory);
if (!directory.exists()) {
directory.mkdir();
}
File outputFile = new File(directory, fileName);
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
}
public static BufferedImage imageToBufferedImage(Image image) {
if (image instanceof BufferedImage) {
return (BufferedImage) image;
}
// Create a buffered image with the same dimensions and transparency as the original image
BufferedImage bufferedImage;
if (CHANNELS > 1) {
bufferedImage =
new BufferedImage(
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
} else {
bufferedImage =
new BufferedImage(
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
}
// Draw the original image onto the buffered image
Graphics2D g2d = bufferedImage.createGraphics();
g2d.drawImage(image, 0, 0, null);
g2d.dispose();
return bufferedImage;
}
}

View File

@@ -1,61 +0,0 @@
package net.brutex.gan;
import javax.swing.JPanel;
import javax.swing.JSplitPane;
import javax.swing.JLabel;
import java.awt.BorderLayout;
public class App2GUI extends JPanel {
/**
*
*/
private static final long serialVersionUID = 1L;
private JPanel overall_panel;
private JPanel real_panel;
private JPanel gen_panel;
/**
* Create the panel.
*/
public App2GUI() {
overall_panel = new JPanel();
add(overall_panel);
JSplitPane splitPane = new JSplitPane();
overall_panel.add(splitPane);
JPanel p1 = new JPanel();
splitPane.setLeftComponent(p1);
p1.setLayout(new BorderLayout(0, 0));
JLabel lblNewLabel = new JLabel("Generator");
p1.add(lblNewLabel, BorderLayout.NORTH);
gen_panel = new JPanel();
p1.add(gen_panel, BorderLayout.SOUTH);
JPanel p2 = new JPanel();
splitPane.setRightComponent(p2);
p2.setLayout(new BorderLayout(0, 0));
JLabel lblNewLabel_1 = new JLabel("Real");
p2.add(lblNewLabel_1, BorderLayout.NORTH);
real_panel = new JPanel();
p2.add(real_panel, BorderLayout.SOUTH);
}
public JPanel getOverall_panel() {
return overall_panel;
}
public JPanel getReal_panel() {
return real_panel;
}
public JPanel getGen_panel() {
return gen_panel;
}
}

View File

@@ -31,7 +31,6 @@ import org.deeplearning4j.nn.conf.layers.DropoutLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
-import org.junit.jupiter.api.Tag;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.impl.ActivationLReLU;
@@ -101,7 +100,7 @@ public class MnistSimpleGAN {
 return new MultiLayerNetwork(discConf);
 }
-@Test @Tag("long-running")
+@Test
 public void runTest() throws Exception {
 main(null);
 }

View File

@@ -20,7 +20,7 @@ ext {
 def javacv = [version:"1.5.7"]
 def opencv = [version: "4.5.5"]
-def leptonica = [version:"1.83.0"] //fix, only in javacpp 1.5.9
+def leptonica = [version: "1.82.0"]
 def junit = [version: "5.9.1"]
 def flatbuffers = [version: "1.10.0"]
@@ -118,8 +118,7 @@ dependencies {
 api "org.bytedeco:javacv:${javacv.version}"
 api "org.bytedeco:opencv:${opencv.version}-${javacpp.presetsVersion}"
 api "org.bytedeco:openblas:${openblas.version}-${javacpp.presetsVersion}"
-api "org.bytedeco:leptonica-platform:${leptonica.version}-1.5.9"
+api "org.bytedeco:leptonica-platform:${leptonica.version}-${javacpp.presetsVersion}"
-api "org.bytedeco:leptonica:${leptonica.version}-1.5.9"
 api "org.bytedeco:hdf5-platform:${hdf5.version}-${javacpp.presetsVersion}"
 api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}"
 api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:${javacppPlatform}"
@@ -130,7 +129,6 @@ dependencies {
 api "org.bytedeco:cuda:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
 api "org.bytedeco:cuda-platform-redist:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
 api "org.bytedeco:mkl-dnn:0.21.5-${javacpp.presetsVersion}"
-api "org.bytedeco:mkl:2022.0-${javacpp.presetsVersion}"
 api "org.bytedeco:tensorflow:${tensorflow.version}-${javacpp.presetsVersion}"
 api "org.bytedeco:cpython:${cpython.version}-${javacpp.presetsVersion}:${javacppPlatform}"
 api "org.bytedeco:numpy:${numpy.version}-${javacpp.presetsVersion}:${javacppPlatform}"

View File

@@ -28,8 +28,7 @@ dependencies {
 implementation "org.bytedeco:javacv"
 implementation "org.bytedeco:opencv"
 implementation group: "org.bytedeco", name: "opencv", classifier: buildTarget
-//implementation "org.bytedeco:leptonica-platform"
+implementation "org.bytedeco:leptonica-platform"
-implementation group: "org.bytedeco", name: "leptonica", classifier: buildTarget
 implementation "org.bytedeco:hdf5-platform"
 implementation "commons-io:commons-io"

View File

@@ -46,7 +46,7 @@ import java.nio.ByteOrder;
 import org.bytedeco.leptonica.*;
 import org.bytedeco.opencv.opencv_core.*;
-import static org.bytedeco.leptonica.global.leptonica.*;
+import static org.bytedeco.leptonica.global.lept.*;
 import static org.bytedeco.opencv.global.opencv_core.*;
 import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
 import static org.bytedeco.opencv.global.opencv_imgproc.*;

View File

@@ -52,9 +52,10 @@ import java.io.InputStream;
 import java.lang.reflect.Field;
 import java.nio.file.Path;
 import java.util.Random;
+import java.util.stream.IntStream;
 import java.util.stream.Stream;
-import static org.bytedeco.leptonica.global.leptonica.*;
+import static org.bytedeco.leptonica.global.lept.*;
 import static org.bytedeco.opencv.global.opencv_core.*;
 import static org.junit.jupiter.api.Assertions.*;

View File

@@ -78,7 +78,7 @@ class dnnTest {
 * DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH)
 */
 NeuralNetConfiguration network =
-NN.nn()
+NN.net()
 .seed(42)
 .updater(Adam.builder().learningRate(0.0002).beta1(0.5).build())
 .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)

View File

@@ -31,7 +31,7 @@ dependencies {
 implementation projects.cavisDnn.cavisDnnData.cavisDnnDataDatasets
 implementation projects.cavisDnn.cavisDnnData.cavisDnnDataDatavecIterators
 implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators
-implementation "org.apache.hadoop:hadoop-common:3.2.4"
+implementation "org.apache.hadoop:hadoop-common:3.2.0"
 implementation "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml"
 implementation projects.cavisDatavec.cavisDatavecApi
 implementation projects.cavisDatavec.cavisDatavecSpark.cavisDatavecSparkCore

View File

@@ -52,8 +52,8 @@ buildscript {
 classpath platform(project(":cavis-common-platform"))
 classpath group: "org.bytedeco", name: "openblas"
 classpath group: "org.bytedeco", name: "openblas", classifier: "${javacppPlatform}"
-classpath group: "org.bytedeco", name:"mkl"
+classpath group: "org.bytedeco", name:"mkl-dnn"
-classpath group: "org.bytedeco", name:"mkl", classifier: "${javacppPlatform}"
+classpath group: "org.bytedeco", name:"mkl-dnn", classifier: "${javacppPlatform}"
 classpath group: "org.bytedeco", name: "javacpp"
 classpath group: "org.bytedeco", name: "javacpp", classifier: "${javacppPlatform}"
 }
@@ -64,7 +64,7 @@ buildscript {
 plugins {
 id 'java-library'
-id 'org.bytedeco.gradle-javacpp-build' version "1.5.7"
+id 'org.bytedeco.gradle-javacpp-build' version "1.5.9"
 id 'maven-publish'
 id 'signing'
 }
@@ -336,12 +336,11 @@ chipList.each { thisChip ->
 && !project.getProperty("skip-native").equals("true") && !VISUAL_STUDIO_INSTALL_DIR.isEmpty()) {
 def proc = ["cmd.exe", "/c", "${VISUAL_STUDIO_VCVARS_CMD} > nul && where.exe cl.exe"].execute()
 def outp = proc.text
-def cl = "\"" + outp.replace("\\", "\\\\").trim() + "\""
+def cl = outp.replace("\\", "\\\\").trim()
 def currentCompiler = ""
 doFirst{
 currentCompiler = System.getProperty("org.bytedeco.javacpp.platform.compiler")
 System.setProperty("org.bytedeco.javacpp.platform.compiler", cl)
-System.setProperty("platform.compiler.cpp11", cl)
 logger.quiet("Task ${thisTask.name} overrides compiler '${currentCompiler}' with '${cl}'.")
 }
 doLast {

View File

@@ -102,18 +102,16 @@ ENDIF()
 IF(${SD_EXTENSION} MATCHES "avx2")
 message("Extension AVX2 enabled.")
-#-mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mprefetchwt1
-set(ARCH_TUNE "${ARCH_TUNE} -DSD_F16C=true -DF_AVX2=true")
+set(ARCH_TUNE "${ARCH_TUNE} -mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mprefetchwt1 -DSD_F16C=true -DF_AVX2=true")
 ELSEIF(${SD_EXTENSION} MATCHES "avx512")
 message("Extension AVX512 enabled.")
 # we need to set flag here, that we can use hardware f16 conversion + tell that cpu features should be tracked
-#-mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512vl -mavx512bw -mavx512dq -mavx512cd -mbmi -mbmi2 -mprefetchwt1 -mclflushopt -mxsavec -mxsaves
-set(ARCH_TUNE "${ARCH_TUNE} -DSD_F16C=true -DF_AVX512=true")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512vl -mavx512bw -mavx512dq -mavx512cd -mbmi -mbmi2 -mprefetchwt1 -mclflushopt -mxsavec -mxsaves -DSD_F16C=true -DF_AVX512=true")
 ENDIF()
 if (NOT WIN32)
 # we don't want this definition for msvc
-set(ARCH_TUNE "${ARCH_TUNE} -march=${SD_ARCH} -mtune=${ARCH_TYPE}")
+set(ARCH_TUNE "-march=${SD_ARCH} -mtune=${ARCH_TYPE}")
 endif()
 if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang" AND SD_X86_BUILD)

View File

@@ -28,7 +28,7 @@ dependencies {
 implementation group: "org.bytedeco", name: "tensorflow"
 testRuntimeOnly group: "org.bytedeco", name: "tensorflow", classifier: buildTarget
 if(buildTarget.contains("windows") || buildTarget.contains("linux")) {
-testRuntimeOnly group: "org.bytedeco", name: 'tensorflow', classifier: "${buildTarget}-gpu", version: ''
+testRuntimeOnly group: "org.bytedeco", name: "tensorflow", classifier: "${buildTarget}-gpu"
 }
 implementation "commons-io:commons-io"
 implementation "com.google.code.gson:gson"

View File

@@ -87,7 +87,7 @@ ext {
 cudaTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget
 cudaTestRuntime group: "org.bytedeco", name: "cuda"
 cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: buildTarget
-//cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
+cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
 cudaTestRuntime (project( path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportRuntimeElements"))
 /*
 cudaTestRuntime(project(":cavis-native:cavis-native-lib")) {