Compare commits
7 Commits
Author | SHA1 | Date |
---|---|---|
Brian Rosenberger | 1c1ec071ef | |
Brian Rosenberger | 74ad5087c1 | |
Brian Rosenberger | acae3944ec | |
Brian Rosenberger | be7cd6b930 | |
Brian Rosenberger | 99aed71ffa | |
Brian Rosenberger | 2df8ea06e0 | |
Brian Rosenberger | 090c5ab2eb |
|
@ -11,5 +11,10 @@ RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.
|
||||||
rm cmake-3.24.2-linux-x86_64.sh
|
rm cmake-3.24.2-linux-x86_64.sh
|
||||||
|
|
||||||
|
|
||||||
|
RUN echo "/usr/local/cuda/compat/" >> /etc/ld.so.conf.d/cuda-driver.conf
|
||||||
|
|
||||||
RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf
|
RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf
|
||||||
|
|
||||||
|
RUN ldconfig -p | grep cuda
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ pipeline {
|
||||||
dir '.docker'
|
dir '.docker'
|
||||||
label 'linux && docker && cuda'
|
label 'linux && docker && cuda'
|
||||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
//additionalBuildArgs '--build-arg version=1.0.2'
|
||||||
//args '--gpus all' --needed for test only, you can build without GPU
|
args '--gpus all' //needed for test only, you can build without GPU
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,9 +36,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
public class LoadBackendTests {
|
public class LoadBackendTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void loadBackend() throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
|
public void loadBackend() throws NoSuchFieldException, IllegalAccessException {
|
||||||
// check if Nd4j is there
|
// check if Nd4j is there
|
||||||
//Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path"));
|
Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path"));
|
||||||
final Field sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
|
final Field sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
|
||||||
sysPathsField.setAccessible(true);
|
sysPathsField.setAccessible(true);
|
||||||
sysPathsField.set(null, null);
|
sysPathsField.set(null, null);
|
||||||
|
|
|
@ -37,6 +37,8 @@ import org.datavec.image.loader.NativeImageLoader;
|
||||||
import org.datavec.image.recordreader.ImageRecordReader;
|
import org.datavec.image.recordreader.ImageRecordReader;
|
||||||
import org.datavec.image.transform.*;
|
import org.datavec.image.transform.*;
|
||||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||||
|
import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator;
|
||||||
|
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -46,24 +48,27 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static net.brutex.gan.App2Config.BATCHSIZE;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class App2 {
|
public class App2 {
|
||||||
|
|
||||||
final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
|
final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
|
||||||
static final float COLORSPACE = 255f;
|
|
||||||
static final int DIMENSIONS = 28;
|
static final int DIMENSIONS = 28;
|
||||||
static final int CHANNELS = 1;
|
static final int CHANNELS = 1;
|
||||||
final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
|
final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
|
||||||
final int OUTPUT_PER_PANEL = 10;
|
|
||||||
|
|
||||||
final boolean BIAS = true;
|
final boolean BIAS = true;
|
||||||
|
|
||||||
static final int BATCHSIZE=128;
|
|
||||||
|
|
||||||
private JFrame frame2, frame;
|
private JFrame frame2, frame;
|
||||||
static final String OUTPUT_DIR = "d:/out/";
|
static final String OUTPUT_DIR = "d:/out/";
|
||||||
|
@ -76,7 +81,7 @@ public class App2 {
|
||||||
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
||||||
|
|
||||||
MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
|
MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
|
||||||
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
|
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans3"), NativeImageLoader.getALLOWED_FORMATS());
|
||||||
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
|
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
|
||||||
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
|
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
|
||||||
ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
|
ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
|
||||||
|
@ -129,12 +134,94 @@ public class App2 {
|
||||||
|
|
||||||
log.info("Generator Summary:\n{}", gen.summary());
|
log.info("Generator Summary:\n{}", gen.summary());
|
||||||
log.info("GAN Summary:\n{}", gan.summary());
|
log.info("GAN Summary:\n{}", gan.summary());
|
||||||
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
|
dis.addTrainingListeners(new PerformanceListener(3, true, "DIS"));
|
||||||
gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
|
//gen.addTrainingListeners(new PerformanceListener(3, true, "GEN")); //is never trained separately from GAN
|
||||||
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
|
gan.addTrainingListeners(new PerformanceListener(3, true, "GAN"));
|
||||||
|
/*
|
||||||
|
Thread vt =
|
||||||
|
new Thread(
|
||||||
|
new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
while (true) {
|
||||||
|
visualize(0, 0, gen);
|
||||||
|
try {
|
||||||
|
Thread.sleep(10000);
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
vt.start();
|
||||||
|
*/
|
||||||
|
|
||||||
int j = 0;
|
App2Display display = new App2Display();
|
||||||
for (int i = 0; i < 51; i++) { //epoch
|
//Repack training data with new fake/real label. Original MNist has 10 labels, one for each digit
|
||||||
|
DataSet data = null;
|
||||||
|
int j =0;
|
||||||
|
for(int i=0;i<App2Config.EPOCHS;i++) {
|
||||||
|
log.info("Epoch {}", i);
|
||||||
|
data = new DataSet(Nd4j.rand(BATCHSIZE, 784), label_fake);
|
||||||
|
while (trainData.hasNext()) {
|
||||||
|
j++;
|
||||||
|
INDArray real = trainData.next().getFeatures();
|
||||||
|
INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
|
||||||
|
|
||||||
|
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1,
|
||||||
|
Nd4j.rand(BATCHSIZE, App2Config.INPUT));
|
||||||
|
//sigmoid output is -1 to 1
|
||||||
|
fake.addi(1f).divi(2f);
|
||||||
|
|
||||||
|
if (j % 50 == 1) {
|
||||||
|
display.visualize(new INDArray[] {fake}, App2Config.OUTPUT_PER_PANEL, false);
|
||||||
|
display.visualize(new INDArray[] {real}, App2Config.OUTPUT_PER_PANEL, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
DataSet realSet = new DataSet(real, label_real);
|
||||||
|
DataSet fakeSet = new DataSet(fake, label_fake);
|
||||||
|
|
||||||
|
//start next round if there are not enough images left to have a full batchsize dataset
|
||||||
|
if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
|
||||||
|
log.warn("Your total number of input images is not a multiple of {}, "
|
||||||
|
+ "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
//if(real.length()/BATCHSIZE!=784) break;
|
||||||
|
data = DataSet.merge(Arrays.asList(data, realSet, fakeSet));
|
||||||
|
|
||||||
|
}
|
||||||
|
//fit the discriminator
|
||||||
|
dis.fit(data);
|
||||||
|
dis.fit(data);
|
||||||
|
// Update the discriminator in the GAN network
|
||||||
|
updateGan(gen, dis, gan);
|
||||||
|
|
||||||
|
//reset the training data and fit the complete GAN
|
||||||
|
if (trainData.resetSupported()) {
|
||||||
|
trainData.reset();
|
||||||
|
} else {
|
||||||
|
log.error("Trainingdata {} does not support reset.", trainData.toString());
|
||||||
|
}
|
||||||
|
gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_real));
|
||||||
|
|
||||||
|
if (trainData.resetSupported()) {
|
||||||
|
trainData.reset();
|
||||||
|
} else {
|
||||||
|
log.error("Trainingdata {} does not support reset.", trainData.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info("Updated GAN's generator from gen.");
|
||||||
|
updateGen(gen, gan);
|
||||||
|
gen.save(new File("mnist-mlp-generator.dlj"));
|
||||||
|
}
|
||||||
|
//vt.stop();
|
||||||
|
|
||||||
|
/*
|
||||||
|
int j;
|
||||||
|
for (int i = 0; i < App2Config.EPOCHS; i++) { //epoch
|
||||||
|
j=0;
|
||||||
while (trainData.hasNext()) {
|
while (trainData.hasNext()) {
|
||||||
j++;
|
j++;
|
||||||
DataSet next = trainData.next();
|
DataSet next = trainData.next();
|
||||||
|
@ -212,122 +299,25 @@ public class App2 {
|
||||||
log.info("Updated GAN's generator from gen.");
|
log.info("Updated GAN's generator from gen.");
|
||||||
gen.save(new File("mnist-mlp-generator.dlj"));
|
gen.save(new File("mnist-mlp-generator.dlj"));
|
||||||
}
|
}
|
||||||
}
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
|
|
||||||
if (isOrig) {
|
|
||||||
frame.setTitle("Viz Original");
|
|
||||||
} else {
|
|
||||||
frame.setTitle("Generated");
|
|
||||||
}
|
|
||||||
|
|
||||||
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
|
||||||
frame.setLayout(new BorderLayout());
|
|
||||||
|
|
||||||
JPanel panelx = new JPanel();
|
|
||||||
|
|
||||||
panelx.setLayout(new GridLayout(4, 4, 8, 8));
|
|
||||||
for (INDArray sample : samples) {
|
|
||||||
for(int i = 0; i<batchElements; i++) {
|
|
||||||
panelx.add(getImage(sample, i, isOrig));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
frame.add(panelx, BorderLayout.CENTER);
|
|
||||||
frame.setVisible(true);
|
|
||||||
|
|
||||||
frame.revalidate();
|
|
||||||
frame.setMinimumSize(new Dimension(300, 20));
|
|
||||||
frame.pack();
|
|
||||||
return frame;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
|
|
||||||
final BufferedImage bi;
|
|
||||||
if(CHANNELS >1) {
|
|
||||||
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
|
|
||||||
} else {
|
|
||||||
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
|
|
||||||
}
|
|
||||||
final int imageSize = DIMENSIONS * DIMENSIONS;
|
|
||||||
final int offset = batchElement * imageSize;
|
|
||||||
int pxl = offset * CHANNELS; //where to start in the INDArray
|
|
||||||
|
|
||||||
//Image in NCHW - channels first format
|
|
||||||
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
|
|
||||||
for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x
|
|
||||||
for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y
|
|
||||||
float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
|
|
||||||
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl);
|
|
||||||
bi.getRaster().setSample(x, y, c, f_pxl);
|
|
||||||
pxl++; //next item in INDArray
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ImageIcon orig = new ImageIcon(bi);
|
|
||||||
Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT);
|
|
||||||
ImageIcon scaled = new ImageIcon(imageScaled);
|
|
||||||
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
|
|
||||||
return new JLabel(scaled);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void saveImage(Image image, int batchElement, boolean isOrig) {
|
|
||||||
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
|
|
||||||
|
|
||||||
try {
|
|
||||||
// Save the images to disk
|
|
||||||
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
|
|
||||||
|
|
||||||
log.debug("Images saved successfully.");
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.error("Error saving the images: {}", e.getMessage());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
|
|
||||||
File directory = new File(outputDirectory);
|
|
||||||
if (!directory.exists()) {
|
|
||||||
directory.mkdir();
|
|
||||||
}
|
|
||||||
|
|
||||||
File outputFile = new File(directory, fileName);
|
|
||||||
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static BufferedImage imageToBufferedImage(Image image) {
|
|
||||||
if (image instanceof BufferedImage) {
|
|
||||||
return (BufferedImage) image;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a buffered image with the same dimensions and transparency as the original image
|
|
||||||
BufferedImage bufferedImage;
|
|
||||||
if (CHANNELS > 1) {
|
|
||||||
bufferedImage =
|
|
||||||
new BufferedImage(
|
|
||||||
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
|
|
||||||
} else {
|
|
||||||
bufferedImage =
|
|
||||||
new BufferedImage(
|
|
||||||
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Draw the original image onto the buffered image
|
|
||||||
Graphics2D g2d = bufferedImage.createGraphics();
|
|
||||||
g2d.drawImage(image, 0, 0, null);
|
|
||||||
g2d.dispose();
|
|
||||||
|
|
||||||
return bufferedImage;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
|
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
|
||||||
for (int i = 0; i < gen.getLayers().length; i++) {
|
for (int i = 0; i < gen.getLayers().length; i++) {
|
||||||
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||||
|
@ -341,4 +331,41 @@ public class App2 {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testDiskriminator() throws IOException {
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(App2Config.discriminator());
|
||||||
|
net.init();
|
||||||
|
net.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
|
||||||
|
DataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
|
||||||
|
|
||||||
|
DataSet data = null;
|
||||||
|
for(int i=0;i<App2Config.EPOCHS;i++) {
|
||||||
|
log.info("Epoch {}", i);
|
||||||
|
data = new DataSet(Nd4j.rand(BATCHSIZE, 784), label_fake);
|
||||||
|
while (trainData.hasNext()) {
|
||||||
|
INDArray real = trainData.next().getFeatures();
|
||||||
|
long[] l = new long[]{BATCHSIZE, real.length() / BATCHSIZE};
|
||||||
|
INDArray fake = Nd4j.rand(l );
|
||||||
|
|
||||||
|
DataSet realSet = new DataSet(real, label_real);
|
||||||
|
DataSet fakeSet = new DataSet(fake, label_fake);
|
||||||
|
if(real.length()/BATCHSIZE!=784) break;
|
||||||
|
data = DataSet.merge(Arrays.asList(data, realSet, fakeSet));
|
||||||
|
|
||||||
|
}
|
||||||
|
net.fit(data);
|
||||||
|
trainData.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
long[] l = new long[]{BATCHSIZE, 784};
|
||||||
|
INDArray fake = Nd4j.rand(l );
|
||||||
|
DataSet fakeSet = new DataSet(fake, label_fake);
|
||||||
|
data = DataSet.merge(Arrays.asList(data, fakeSet));
|
||||||
|
ExistingDataSetIterator iter = new ExistingDataSetIterator(data);
|
||||||
|
Evaluation eval = net.evaluate(iter);
|
||||||
|
log.info( "\n" + eval.confusionMatrix());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,10 +36,17 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
public class App2Config {
|
public class App2Config {
|
||||||
|
|
||||||
public static final int INPUT = 100;
|
public static final int INPUT = 100;
|
||||||
|
public static final int BATCHSIZE=150;
|
||||||
public static final int X_DIM = 28;
|
public static final int X_DIM = 28;
|
||||||
public static final int y_DIM = 28;
|
public static final int Y_DIM = 28;
|
||||||
public static final int CHANNELS = 1;
|
public static final int CHANNELS = 1;
|
||||||
|
public static final int EPOCHS = 50;
|
||||||
public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
|
public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
|
||||||
|
public static final IUpdater UPDATER_DIS = Adam.builder().learningRate(0.02).beta1(0.5).build();
|
||||||
|
public static final boolean SHOW_GENERATED = true;
|
||||||
|
public static final float COLORSPACE = 255f;
|
||||||
|
|
||||||
|
final static int OUTPUT_PER_PANEL = 10;
|
||||||
|
|
||||||
static LayerConfiguration[] genLayerConfig() {
|
static LayerConfiguration[] genLayerConfig() {
|
||||||
return new LayerConfiguration[] {
|
return new LayerConfiguration[] {
|
||||||
|
@ -158,7 +165,7 @@ public class App2Config {
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
.gradientNormalizationThreshold(100)
|
.gradientNormalizationThreshold(100)
|
||||||
.seed(42)
|
.seed(42)
|
||||||
.updater(UPDATER)
|
.updater(UPDATER_DIS)
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
// .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
// .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
||||||
.weightNoise(null)
|
.weightNoise(null)
|
||||||
|
|
|
@ -0,0 +1,160 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.gan;
|
||||||
|
|
||||||
|
import com.google.inject.Singleton;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
import javax.imageio.ImageIO;
|
||||||
|
import javax.swing.*;
|
||||||
|
import java.awt.*;
|
||||||
|
import java.awt.color.ColorSpace;
|
||||||
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
import static net.brutex.gan.App2.OUTPUT_DIR;
|
||||||
|
import static net.brutex.gan.App2Config.*;
|
||||||
|
@Slf4j
|
||||||
|
@Singleton
|
||||||
|
public class App2Display {
|
||||||
|
|
||||||
|
private final JFrame frame = new JFrame();
|
||||||
|
private final App2GUI display = new App2GUI();
|
||||||
|
|
||||||
|
private final JPanel real_panel;
|
||||||
|
private final JPanel fake_panel;
|
||||||
|
|
||||||
|
|
||||||
|
public App2Display() {
|
||||||
|
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||||
|
frame.setContentPane(display.getOverall_panel());
|
||||||
|
frame.setMinimumSize(new Dimension(300, 20));
|
||||||
|
frame.pack();
|
||||||
|
frame.setVisible(true);
|
||||||
|
real_panel = display.getReal_panel();
|
||||||
|
fake_panel = display.getGen_panel();
|
||||||
|
real_panel.setLayout(new GridLayout(4, 4, 8, 8));
|
||||||
|
fake_panel.setLayout(new GridLayout(4, 4, 8, 8));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void visualize(INDArray[] samples, int batchElements, boolean isOrig) {
|
||||||
|
for (INDArray sample : samples) {
|
||||||
|
for(int i = 0; i<batchElements; i++) {
|
||||||
|
final Image img = this.getImage(sample, i, isOrig);
|
||||||
|
final ImageIcon icon = new ImageIcon(img);
|
||||||
|
if(isOrig) {
|
||||||
|
if(real_panel.getComponents().length>=OUTPUT_PER_PANEL) {
|
||||||
|
real_panel.remove(0);
|
||||||
|
}
|
||||||
|
real_panel.add(new JLabel(icon));
|
||||||
|
} else {
|
||||||
|
if(fake_panel.getComponents().length>=OUTPUT_PER_PANEL) {
|
||||||
|
fake_panel.remove(0);
|
||||||
|
}
|
||||||
|
fake_panel.add(new JLabel(icon));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
frame.pack();
|
||||||
|
frame.repaint();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Image getImage(INDArray tensor, int batchElement, boolean isOrig) {
|
||||||
|
final BufferedImage bi;
|
||||||
|
if(CHANNELS >1) {
|
||||||
|
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
|
||||||
|
} else {
|
||||||
|
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
|
||||||
|
}
|
||||||
|
final int imageSize = X_DIM * Y_DIM;
|
||||||
|
final int offset = batchElement * imageSize;
|
||||||
|
int pxl = offset * CHANNELS; //where to start in the INDArray
|
||||||
|
|
||||||
|
//Image in NCHW - channels first format
|
||||||
|
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
|
||||||
|
for (int y = 0; y < X_DIM; y++) { // step through the columns x
|
||||||
|
for (int x = 0; x < Y_DIM; x++) { //step through the rows y
|
||||||
|
float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
|
||||||
|
if(isOrig) log.trace("'{}.'{} Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, isOrig ? "Real" : "Fake", x, y, c, pxl, f_pxl);
|
||||||
|
bi.getRaster().setSample(x, y, c, f_pxl);
|
||||||
|
pxl++; //next item in INDArray
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ImageIcon orig = new ImageIcon(bi);
|
||||||
|
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
|
||||||
|
ImageIcon scaled = new ImageIcon(imageScaled);
|
||||||
|
//if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
|
||||||
|
return imageScaled;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void saveImage(Image image, int batchElement, boolean isOrig) {
|
||||||
|
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Save the images to disk
|
||||||
|
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
|
||||||
|
|
||||||
|
log.debug("Images saved successfully.");
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("Error saving the images: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
|
||||||
|
File directory = new File(outputDirectory);
|
||||||
|
if (!directory.exists()) {
|
||||||
|
directory.mkdir();
|
||||||
|
}
|
||||||
|
|
||||||
|
File outputFile = new File(directory, fileName);
|
||||||
|
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static BufferedImage imageToBufferedImage(Image image) {
|
||||||
|
if (image instanceof BufferedImage) {
|
||||||
|
return (BufferedImage) image;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a buffered image with the same dimensions and transparency as the original image
|
||||||
|
BufferedImage bufferedImage;
|
||||||
|
if (CHANNELS > 1) {
|
||||||
|
bufferedImage =
|
||||||
|
new BufferedImage(
|
||||||
|
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
|
||||||
|
} else {
|
||||||
|
bufferedImage =
|
||||||
|
new BufferedImage(
|
||||||
|
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw the original image onto the buffered image
|
||||||
|
Graphics2D g2d = bufferedImage.createGraphics();
|
||||||
|
g2d.drawImage(image, 0, 0, null);
|
||||||
|
g2d.dispose();
|
||||||
|
|
||||||
|
return bufferedImage;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
package net.brutex.gan;
|
||||||
|
|
||||||
|
import javax.swing.JPanel;
|
||||||
|
import javax.swing.JSplitPane;
|
||||||
|
import javax.swing.JLabel;
|
||||||
|
import java.awt.BorderLayout;
|
||||||
|
|
||||||
|
public class App2GUI extends JPanel {
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
private JPanel overall_panel;
|
||||||
|
private JPanel real_panel;
|
||||||
|
private JPanel gen_panel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create the panel.
|
||||||
|
*/
|
||||||
|
public App2GUI() {
|
||||||
|
|
||||||
|
overall_panel = new JPanel();
|
||||||
|
add(overall_panel);
|
||||||
|
|
||||||
|
JSplitPane splitPane = new JSplitPane();
|
||||||
|
overall_panel.add(splitPane);
|
||||||
|
|
||||||
|
JPanel p1 = new JPanel();
|
||||||
|
splitPane.setLeftComponent(p1);
|
||||||
|
p1.setLayout(new BorderLayout(0, 0));
|
||||||
|
|
||||||
|
JLabel lblNewLabel = new JLabel("Generator");
|
||||||
|
p1.add(lblNewLabel, BorderLayout.NORTH);
|
||||||
|
|
||||||
|
gen_panel = new JPanel();
|
||||||
|
p1.add(gen_panel, BorderLayout.SOUTH);
|
||||||
|
|
||||||
|
JPanel p2 = new JPanel();
|
||||||
|
splitPane.setRightComponent(p2);
|
||||||
|
p2.setLayout(new BorderLayout(0, 0));
|
||||||
|
|
||||||
|
JLabel lblNewLabel_1 = new JLabel("Real");
|
||||||
|
p2.add(lblNewLabel_1, BorderLayout.NORTH);
|
||||||
|
|
||||||
|
real_panel = new JPanel();
|
||||||
|
p2.add(real_panel, BorderLayout.SOUTH);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public JPanel getOverall_panel() {
|
||||||
|
return overall_panel;
|
||||||
|
}
|
||||||
|
public JPanel getReal_panel() {
|
||||||
|
return real_panel;
|
||||||
|
}
|
||||||
|
public JPanel getGen_panel() {
|
||||||
|
return gen_panel;
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,22 +8,19 @@ ext {
|
||||||
javacppPlatform = osdetector.classifier
|
javacppPlatform = osdetector.classifier
|
||||||
}
|
}
|
||||||
|
|
||||||
def javacpp = [version: "1.5.9", presetsVersion: "1.5.9"]
|
def javacpp = [version: "1.5.7", presetsVersion: "1.5.7"]
|
||||||
def hdf5 = [version: "1.14.1"]
|
def hdf5 = [version: "1.12.1"]
|
||||||
def jackson = [version: "2.13.4"]
|
def jackson = [version: "2.13.4"]
|
||||||
def cuda = [version: "12.1"]
|
def cuda = [version: "11.6"]
|
||||||
def cudnn = [version: "8.9"]
|
def cudnn = [version: "8.3"]
|
||||||
def openblas = [version: "0.3.23"]
|
def openblas = [version: "0.3.19"]
|
||||||
def numpy = [version: "1.24.3"]
|
def numpy = [version: "1.22.2"]
|
||||||
def tensorflow_lite = [version: "2.12.0"]
|
|
||||||
def tensorflow = [version: "1.15.5"]
|
def tensorflow = [version: "1.15.5"]
|
||||||
def tensorrt = [version: "8.6.1.6"]
|
def cpython = [version: "3.10.2"]
|
||||||
def cpython = [version: "3.11.3"]
|
|
||||||
def mkl = [version:"2023.1"]
|
|
||||||
|
|
||||||
def javacv = [version:"1.5.9"]
|
def javacv = [version:"1.5.7"]
|
||||||
def opencv = [version: "4.7.0"]
|
def opencv = [version: "4.5.5"]
|
||||||
def leptonica = [version: "1.83.0"]
|
def leptonica = [version: "1.83.0"] //fix, only in javacpp 1.5.9
|
||||||
def junit = [version: "5.9.1"]
|
def junit = [version: "5.9.1"]
|
||||||
|
|
||||||
def flatbuffers = [version: "1.10.0"]
|
def flatbuffers = [version: "1.10.0"]
|
||||||
|
@ -44,6 +41,7 @@ dependencies {
|
||||||
|
|
||||||
api enforcedPlatform("io.netty:netty-bom:${netty.version}")
|
api enforcedPlatform("io.netty:netty-bom:${netty.version}")
|
||||||
api enforcedPlatform("com.fasterxml.jackson:jackson-bom:${jackson.version}")
|
api enforcedPlatform("com.fasterxml.jackson:jackson-bom:${jackson.version}")
|
||||||
|
//api enforcedPlatform("com.fasterxml.jackson.core:jackson-annotations:${jackson.version}")
|
||||||
api enforcedPlatform("com.squareup.okhttp3:okhttp-bom:${okhttp3.version}")
|
api enforcedPlatform("com.squareup.okhttp3:okhttp-bom:${okhttp3.version}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,6 +49,9 @@ dependencies {
|
||||||
api enforcedPlatform("io.netty:netty-bom:${netty.version}")
|
api enforcedPlatform("io.netty:netty-bom:${netty.version}")
|
||||||
api enforcedPlatform("com.fasterxml.jackson:jackson-bom:${jackson.version}")
|
api enforcedPlatform("com.fasterxml.jackson:jackson-bom:${jackson.version}")
|
||||||
api enforcedPlatform("com.squareup.okhttp3:okhttp-bom:${okhttp3.version}")
|
api enforcedPlatform("com.squareup.okhttp3:okhttp-bom:${okhttp3.version}")
|
||||||
|
//api enforcedPlatform("com.fasterxml.jackson.core:jackson-annotations:${jackson.version}")
|
||||||
|
//api "com.squareup.okhttp3:okhttp:${okhttp3}.version"
|
||||||
|
//api "com.squareup.okhttp3:logging-interceptor:${okhttp3}.version"
|
||||||
|
|
||||||
api 'com.google.guava:guava:30.1-jre'
|
api 'com.google.guava:guava:30.1-jre'
|
||||||
api "com.google.protobuf:protobuf-java:3.15.6"
|
api "com.google.protobuf:protobuf-java:3.15.6"
|
||||||
|
@ -58,6 +59,18 @@ dependencies {
|
||||||
api "com.google.protobuf:protobuf-java-util:3.15.6"
|
api "com.google.protobuf:protobuf-java-util:3.15.6"
|
||||||
api "com.google.flatbuffers:flatbuffers-java:${flatbuffers.version}"
|
api "com.google.flatbuffers:flatbuffers-java:${flatbuffers.version}"
|
||||||
|
|
||||||
|
/*
|
||||||
|
api "com.fasterxml.jackson.core:jackson-core:${jackson.version}"
|
||||||
|
api "com.fasterxml.jackson.core:jackson-databind:${jackson.version}"
|
||||||
|
api "com.fasterxml.jackson.core:jackson-annotations:${jackson.version}"
|
||||||
|
|
||||||
|
api "com.fasterxml.jackson.dataformat:jackson-dataformat-xml:${jackson.version}"
|
||||||
|
*/
|
||||||
|
// api "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:${jackson.version}"
|
||||||
|
// api "com.fasterxml.jackson.datatype:jackson-datatype-joda:${jackson.version}"
|
||||||
|
// api "com.fasterxml.jackson.module:jackson-module-scala_${scalaVersion}"
|
||||||
|
|
||||||
|
|
||||||
api "org.projectlombok:lombok:1.18.28"
|
api "org.projectlombok:lombok:1.18.28"
|
||||||
|
|
||||||
/*Logging*/
|
/*Logging*/
|
||||||
|
@ -68,7 +81,7 @@ dependencies {
|
||||||
api "ch.qos.logback:logback-classic:1.2.3"
|
api "ch.qos.logback:logback-classic:1.2.3"
|
||||||
api 'ch.qos.logback:logback-core:1.2.3'
|
api 'ch.qos.logback:logback-core:1.2.3'
|
||||||
|
|
||||||
/* commons */
|
|
||||||
api 'commons-io:commons-io:2.5'
|
api 'commons-io:commons-io:2.5'
|
||||||
api 'commons-codec:commons-codec:1.11'
|
api 'commons-codec:commons-codec:1.11'
|
||||||
api 'commons-net:commons-net:3.6'
|
api 'commons-net:commons-net:3.6'
|
||||||
|
@ -105,22 +118,24 @@ dependencies {
|
||||||
api "org.bytedeco:javacv:${javacv.version}"
|
api "org.bytedeco:javacv:${javacv.version}"
|
||||||
api "org.bytedeco:opencv:${opencv.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:opencv:${opencv.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:openblas:${openblas.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:openblas:${openblas.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:leptonica-platform:${leptonica.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:leptonica-platform:${leptonica.version}-1.5.9"
|
||||||
|
api "org.bytedeco:leptonica:${leptonica.version}-1.5.9"
|
||||||
api "org.bytedeco:hdf5-platform:${hdf5.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:hdf5-platform:${hdf5.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:${javacppPlatform}"
|
api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:${javacppPlatform}"
|
||||||
|
//api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:linux-x86_64"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
api "org.bytedeco:cuda:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:cuda:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:cuda-platform-redist:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:cuda-platform-redist:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:mkl:${mkl.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:mkl-dnn:0.21.5-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:tensorflow:${tensorflow.version}-1.5.8" //not available for javacpp 1.5.9 ?
|
api "org.bytedeco:mkl:2022.0-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:tensorflow-platform:${tensorflow.version}-1.5.8"
|
api "org.bytedeco:tensorflow:${tensorflow.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:tensorflow-lite:${tensorflow_lite.version}-${javacpp.presetsVersion}"
|
|
||||||
api "org.bytedeco:tensorflow-lite-platform:${tensorflow_lite.version}-${javacpp.presetsVersion}"
|
|
||||||
api "org.bytedeco:tensorrt:${tensorrt.version}-${javacpp.presetsVersion}"
|
|
||||||
api "org.bytedeco:tensorrt-platform:${tensorrt.version}-${javacpp.presetsVersion}"
|
|
||||||
api "org.bytedeco:cpython:${cpython.version}-${javacpp.presetsVersion}:${javacppPlatform}"
|
api "org.bytedeco:cpython:${cpython.version}-${javacpp.presetsVersion}:${javacppPlatform}"
|
||||||
api "org.bytedeco:numpy:${numpy.version}-${javacpp.presetsVersion}:${javacppPlatform}"
|
api "org.bytedeco:numpy:${numpy.version}-${javacpp.presetsVersion}:${javacppPlatform}"
|
||||||
|
//implementation "org.bytedeco:cpython-platform:3.9.6-1.5.6"
|
||||||
|
//implementation "org.bytedeco:numpy-platform:1.21.1-1.5.6"
|
||||||
|
|
||||||
/* Apache Spark */
|
/* Apache Spark */
|
||||||
api "org.apache.spark:spark-core_${scalaVersion}:${spark.version}"
|
api "org.apache.spark:spark-core_${scalaVersion}:${spark.version}"
|
||||||
|
@ -154,6 +169,16 @@ dependencies {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
publishing {
|
||||||
|
publications {
|
||||||
|
myPlatform(MavenPublication) {
|
||||||
|
from components.javaPlatform
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
tasks.withType(GenerateModuleMetadata).configureEach {
|
tasks.withType(GenerateModuleMetadata).configureEach {
|
||||||
// The value 'enforced-platform' is provided in the validation
|
// The value 'enforced-platform' is provided in the validation
|
||||||
// error message you got
|
// error message you got
|
||||||
|
|
|
@ -64,7 +64,7 @@ buildscript {
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id 'java-library'
|
id 'java-library'
|
||||||
id 'org.bytedeco.gradle-javacpp-build' version "1.5.9"
|
id 'org.bytedeco.gradle-javacpp-build' version "1.5.7"
|
||||||
id 'maven-publish'
|
id 'maven-publish'
|
||||||
id 'signing'
|
id 'signing'
|
||||||
}
|
}
|
||||||
|
@ -336,11 +336,12 @@ chipList.each { thisChip ->
|
||||||
&& !project.getProperty("skip-native").equals("true") && !VISUAL_STUDIO_INSTALL_DIR.isEmpty()) {
|
&& !project.getProperty("skip-native").equals("true") && !VISUAL_STUDIO_INSTALL_DIR.isEmpty()) {
|
||||||
def proc = ["cmd.exe", "/c", "${VISUAL_STUDIO_VCVARS_CMD} > nul && where.exe cl.exe"].execute()
|
def proc = ["cmd.exe", "/c", "${VISUAL_STUDIO_VCVARS_CMD} > nul && where.exe cl.exe"].execute()
|
||||||
def outp = proc.text
|
def outp = proc.text
|
||||||
def cl = outp.replace("\\", "\\\\").trim()
|
def cl = "\"" + outp.replace("\\", "\\\\").trim() + "\""
|
||||||
def currentCompiler = ""
|
def currentCompiler = ""
|
||||||
doFirst{
|
doFirst{
|
||||||
currentCompiler = System.getProperty("org.bytedeco.javacpp.platform.compiler")
|
currentCompiler = System.getProperty("org.bytedeco.javacpp.platform.compiler")
|
||||||
System.setProperty("org.bytedeco.javacpp.platform.compiler", cl)
|
System.setProperty("org.bytedeco.javacpp.platform.compiler", cl)
|
||||||
|
System.setProperty("platform.compiler.cpp11", cl)
|
||||||
logger.quiet("Task ${thisTask.name} overrides compiler '${currentCompiler}' with '${cl}'.")
|
logger.quiet("Task ${thisTask.name} overrides compiler '${currentCompiler}' with '${cl}'.")
|
||||||
}
|
}
|
||||||
doLast {
|
doLast {
|
||||||
|
|
|
@ -102,16 +102,18 @@ ENDIF()
|
||||||
|
|
||||||
IF(${SD_EXTENSION} MATCHES "avx2")
|
IF(${SD_EXTENSION} MATCHES "avx2")
|
||||||
message("Extension AVX2 enabled.")
|
message("Extension AVX2 enabled.")
|
||||||
set(ARCH_TUNE "${ARCH_TUNE} -mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mprefetchwt1 -DSD_F16C=true -DF_AVX2=true")
|
#-mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mprefetchwt1
|
||||||
|
set(ARCH_TUNE "${ARCH_TUNE} -DSD_F16C=true -DF_AVX2=true")
|
||||||
ELSEIF(${SD_EXTENSION} MATCHES "avx512")
|
ELSEIF(${SD_EXTENSION} MATCHES "avx512")
|
||||||
message("Extension AVX512 enabled.")
|
message("Extension AVX512 enabled.")
|
||||||
# we need to set flag here, that we can use hardware f16 conversion + tell that cpu features should be tracked
|
# we need to set flag here, that we can use hardware f16 conversion + tell that cpu features should be tracked
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512vl -mavx512bw -mavx512dq -mavx512cd -mbmi -mbmi2 -mprefetchwt1 -mclflushopt -mxsavec -mxsaves -DSD_F16C=true -DF_AVX512=true")
|
#-mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512vl -mavx512bw -mavx512dq -mavx512cd -mbmi -mbmi2 -mprefetchwt1 -mclflushopt -mxsavec -mxsaves
|
||||||
|
set(ARCH_TUNE "${ARCH_TUNE} -DSD_F16C=true -DF_AVX512=true")
|
||||||
ENDIF()
|
ENDIF()
|
||||||
|
|
||||||
if (NOT WIN32)
|
if (NOT WIN32)
|
||||||
# we don't want this definition for msvc
|
# we don't want this definition for msvc
|
||||||
set(ARCH_TUNE "-march=${SD_ARCH} -mtune=${ARCH_TYPE}")
|
set(ARCH_TUNE "${ARCH_TUNE} -march=${SD_ARCH} -mtune=${ARCH_TYPE}")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang" AND SD_X86_BUILD)
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang" AND SD_X86_BUILD)
|
||||||
|
|
|
@ -87,7 +87,7 @@ ext {
|
||||||
cudaTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget
|
cudaTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget
|
||||||
cudaTestRuntime group: "org.bytedeco", name: "cuda"
|
cudaTestRuntime group: "org.bytedeco", name: "cuda"
|
||||||
cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: buildTarget
|
cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: buildTarget
|
||||||
cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
|
//cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
|
||||||
cudaTestRuntime (project( path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportRuntimeElements"))
|
cudaTestRuntime (project( path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportRuntimeElements"))
|
||||||
/*
|
/*
|
||||||
cudaTestRuntime(project(":cavis-native:cavis-native-lib")) {
|
cudaTestRuntime(project(":cavis-native:cavis-native-lib")) {
|
||||||
|
|
Loading…
Reference in New Issue