Compare commits
	
		
			20 Commits
		
	
	
		
			master
			...
			enhance-bu
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| dd151aec3f | |||
| 1b3338f809 | |||
| 3d949c5348 | |||
| 6930116c18 | |||
| e27fb8422f | |||
| d0342fc939 | |||
| b34b96d929 | |||
| 8f51471a31 | |||
| dc5de40620 | |||
| e834407b6e | |||
| 4dc5a116b6 | |||
| 997143b9dd | |||
| 0bed17c97f | |||
| 8d73a7a410 | |||
| c758cf918f | |||
| 2c8c6d9624 | |||
| 0ba049885f | |||
| 345f55a003 | |||
| 1c39dbee52 | |||
| ea504bff41 | 
@ -1,4 +1,4 @@
 | 
			
		||||
FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04
 | 
			
		||||
FROM nvidia/cuda:11.4.3-cudnn8-devel-ubuntu20.04
 | 
			
		||||
 | 
			
		||||
RUN apt-get update &&  \
 | 
			
		||||
    DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
 | 
			
		||||
@ -11,10 +11,5 @@ RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.
 | 
			
		||||
    rm cmake-3.24.2-linux-x86_64.sh
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
RUN echo "/usr/local/cuda/compat/" >> /etc/ld.so.conf.d/cuda-driver.conf
 | 
			
		||||
 | 
			
		||||
RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf
 | 
			
		||||
 | 
			
		||||
RUN ldconfig -p | grep cuda
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -65,7 +65,7 @@ pipeline {
 | 
			
		||||
        }*/
 | 
			
		||||
        stage('publish-linux-cpu') {
 | 
			
		||||
            environment {
 | 
			
		||||
                MAVEN = credentials('Internal_Archiva')
 | 
			
		||||
                MAVEN = credentials('Internal Archiva')
 | 
			
		||||
                OSSRH = credentials('OSSRH')
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
@ -79,9 +79,4 @@ pipeline {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    post {
 | 
			
		||||
        always {
 | 
			
		||||
            junit '**/build/test-results/**/*.xml'
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -59,10 +59,4 @@ pipeline {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    post {
 | 
			
		||||
        always {
 | 
			
		||||
            junit '**/build/test-results/**/*.xml'
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -85,9 +85,4 @@ pipeline {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    post {
 | 
			
		||||
        always {
 | 
			
		||||
            junit '**/build/test-results/**/*.xml'
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -26,7 +26,7 @@ pipeline {
 | 
			
		||||
            dir '.docker'
 | 
			
		||||
            label 'linux && docker && cuda'
 | 
			
		||||
            //additionalBuildArgs  '--build-arg version=1.0.2'
 | 
			
		||||
            args '--gpus all' //needed for test only, you can build without GPU
 | 
			
		||||
            //args '--gpus all' --needed for test only, you can build without GPU
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -72,10 +72,4 @@ pipeline {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    post {
 | 
			
		||||
        always {
 | 
			
		||||
            junit '**/build/test-results/**/*.xml'
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -36,9 +36,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
 | 
			
		||||
public class LoadBackendTests {
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void loadBackend() throws NoSuchFieldException, IllegalAccessException {
 | 
			
		||||
    public void loadBackend() throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
 | 
			
		||||
        // check if Nd4j is there
 | 
			
		||||
        Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path"));
 | 
			
		||||
        //Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path"));
 | 
			
		||||
        final Field sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
 | 
			
		||||
        sysPathsField.setAccessible(true);
 | 
			
		||||
        sysPathsField.set(null, null);
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,6 @@ import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
 | 
			
		||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
			
		||||
import org.deeplearning4j.nn.weights.WeightInit;
 | 
			
		||||
import org.deeplearning4j.optimize.listeners.PerformanceListener;
 | 
			
		||||
import org.junit.jupiter.api.Tag;
 | 
			
		||||
import org.junit.jupiter.api.Test;
 | 
			
		||||
import org.nd4j.linalg.activations.Activation;
 | 
			
		||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
 | 
			
		||||
@ -123,7 +122,7 @@ public class App {
 | 
			
		||||
        return conf;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test @Tag("long-running")
 | 
			
		||||
    @Test
 | 
			
		||||
    public void runTest() throws Exception {
 | 
			
		||||
        App.main(null);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -37,8 +37,6 @@ import org.datavec.image.loader.NativeImageLoader;
 | 
			
		||||
import org.datavec.image.recordreader.ImageRecordReader;
 | 
			
		||||
import org.datavec.image.transform.*;
 | 
			
		||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 | 
			
		||||
import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator;
 | 
			
		||||
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
 | 
			
		||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 | 
			
		||||
import org.deeplearning4j.nn.conf.GradientNormalization;
 | 
			
		||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 | 
			
		||||
@ -46,29 +44,25 @@ import org.deeplearning4j.nn.conf.layers.*;
 | 
			
		||||
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
 | 
			
		||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
			
		||||
import org.deeplearning4j.optimize.listeners.PerformanceListener;
 | 
			
		||||
import org.junit.jupiter.api.Tag;
 | 
			
		||||
import org.junit.jupiter.api.Test;
 | 
			
		||||
import org.nd4j.evaluation.classification.Evaluation;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.dataset.DataSet;
 | 
			
		||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static net.brutex.gan.App2Config.BATCHSIZE;
 | 
			
		||||
 | 
			
		||||
@Slf4j
 | 
			
		||||
public class App2 {
 | 
			
		||||
 | 
			
		||||
    final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
 | 
			
		||||
 | 
			
		||||
    static final float COLORSPACE = 255f;
 | 
			
		||||
    static final int DIMENSIONS = 28;
 | 
			
		||||
    static final int CHANNELS = 1;
 | 
			
		||||
    final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
 | 
			
		||||
 | 
			
		||||
    final int OUTPUT_PER_PANEL = 10;
 | 
			
		||||
 | 
			
		||||
    final boolean BIAS = true;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    static final int BATCHSIZE=128;
 | 
			
		||||
 | 
			
		||||
    private JFrame frame2, frame;
 | 
			
		||||
    static final String OUTPUT_DIR = "d:/out/";
 | 
			
		||||
@ -76,12 +70,12 @@ public class App2 {
 | 
			
		||||
    final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
 | 
			
		||||
    final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);
 | 
			
		||||
 | 
			
		||||
    @Test @Tag("long-running")
 | 
			
		||||
    @Test
 | 
			
		||||
    void runTest() throws IOException {
 | 
			
		||||
        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
 | 
			
		||||
 | 
			
		||||
        MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
 | 
			
		||||
        FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans3"), NativeImageLoader.getALLOWED_FORMATS());
 | 
			
		||||
        FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
 | 
			
		||||
        ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
 | 
			
		||||
        ImageTransform transform2 = new ShowImageTransform("Tester", 30);
 | 
			
		||||
        ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
 | 
			
		||||
@ -134,94 +128,12 @@ public class App2 {
 | 
			
		||||
 | 
			
		||||
        log.info("Generator Summary:\n{}", gen.summary());
 | 
			
		||||
        log.info("GAN Summary:\n{}", gan.summary());
 | 
			
		||||
        dis.addTrainingListeners(new PerformanceListener(3, true, "DIS"));
 | 
			
		||||
        //gen.addTrainingListeners(new PerformanceListener(3, true, "GEN")); //is never trained separately from GAN
 | 
			
		||||
        gan.addTrainingListeners(new PerformanceListener(3, true, "GAN"));
 | 
			
		||||
/*
 | 
			
		||||
    Thread vt =
 | 
			
		||||
        new Thread(
 | 
			
		||||
            new Runnable() {
 | 
			
		||||
              @Override
 | 
			
		||||
              public void run() {
 | 
			
		||||
                while (true) {
 | 
			
		||||
                  visualize(0, 0, gen);
 | 
			
		||||
                    try {
 | 
			
		||||
                        Thread.sleep(10000);
 | 
			
		||||
                    } catch (InterruptedException e) {
 | 
			
		||||
                        throw new RuntimeException(e);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
              }
 | 
			
		||||
            });
 | 
			
		||||
        vt.start();
 | 
			
		||||
*/
 | 
			
		||||
        dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
 | 
			
		||||
        gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
 | 
			
		||||
        gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
 | 
			
		||||
 | 
			
		||||
        App2Display display = new App2Display();
 | 
			
		||||
        //Repack training data with new fake/real label. Original MNist has 10 labels, one for each digit
 | 
			
		||||
        DataSet data = null;
 | 
			
		||||
        int j =0;
 | 
			
		||||
        for(int i=0;i<App2Config.EPOCHS;i++) {
 | 
			
		||||
            log.info("Epoch {}", i);
 | 
			
		||||
            data = new DataSet(Nd4j.rand(BATCHSIZE, 784), label_fake);
 | 
			
		||||
            while (trainData.hasNext()) {
 | 
			
		||||
                j++;
 | 
			
		||||
                INDArray real = trainData.next().getFeatures();
 | 
			
		||||
                INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
 | 
			
		||||
 | 
			
		||||
                INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1,
 | 
			
		||||
                        Nd4j.rand(BATCHSIZE, App2Config.INPUT));
 | 
			
		||||
                //sigmoid output is -1 to 1
 | 
			
		||||
                fake.addi(1f).divi(2f);
 | 
			
		||||
 | 
			
		||||
                if (j % 50 == 1) {
 | 
			
		||||
                  display.visualize(new INDArray[] {fake}, App2Config.OUTPUT_PER_PANEL, false);
 | 
			
		||||
                  display.visualize(new INDArray[] {real}, App2Config.OUTPUT_PER_PANEL, true);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
                DataSet realSet = new DataSet(real, label_real);
 | 
			
		||||
                DataSet fakeSet = new DataSet(fake, label_fake);
 | 
			
		||||
 | 
			
		||||
                //start next round if there are not enough images left to have a full batchsize dataset
 | 
			
		||||
                if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
 | 
			
		||||
                    log.warn("Your total number of input images is not a multiple of {}, "
 | 
			
		||||
                            + "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
 | 
			
		||||
                    break;
 | 
			
		||||
                }
 | 
			
		||||
                //if(real.length()/BATCHSIZE!=784) break;
 | 
			
		||||
                data = DataSet.merge(Arrays.asList(data, realSet, fakeSet));
 | 
			
		||||
 | 
			
		||||
            }
 | 
			
		||||
            //fit the discriminator
 | 
			
		||||
            dis.fit(data);
 | 
			
		||||
            dis.fit(data);
 | 
			
		||||
            // Update the discriminator in the GAN network
 | 
			
		||||
            updateGan(gen, dis, gan);
 | 
			
		||||
 | 
			
		||||
            //reset the training data and fit the complete GAN
 | 
			
		||||
            if (trainData.resetSupported()) {
 | 
			
		||||
                trainData.reset();
 | 
			
		||||
            } else {
 | 
			
		||||
                log.error("Trainingdata {} does not support reset.", trainData.toString());
 | 
			
		||||
            }
 | 
			
		||||
            gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_real));
 | 
			
		||||
 | 
			
		||||
            if (trainData.resetSupported()) {
 | 
			
		||||
                trainData.reset();
 | 
			
		||||
            } else {
 | 
			
		||||
                log.error("Trainingdata {} does not support reset.", trainData.toString());
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            log.info("Updated GAN's generator from gen.");
 | 
			
		||||
            updateGen(gen, gan);
 | 
			
		||||
            gen.save(new File("mnist-mlp-generator.dlj"));
 | 
			
		||||
        }
 | 
			
		||||
        //vt.stop();
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
        int j;
 | 
			
		||||
        for (int i = 0; i < App2Config.EPOCHS; i++) { //epoch
 | 
			
		||||
            j=0;
 | 
			
		||||
        int j = 0;
 | 
			
		||||
        for (int i = 0; i < 51; i++) { //epoch
 | 
			
		||||
            while (trainData.hasNext()) {
 | 
			
		||||
                j++;
 | 
			
		||||
                DataSet next = trainData.next();
 | 
			
		||||
@ -299,8 +211,6 @@ public class App2 {
 | 
			
		||||
            log.info("Updated GAN's generator from gen.");
 | 
			
		||||
            gen.save(new File("mnist-mlp-generator.dlj"));
 | 
			
		||||
        }
 | 
			
		||||
        */
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -313,11 +223,110 @@ public class App2 {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
 | 
			
		||||
        if (isOrig) {
 | 
			
		||||
            frame.setTitle("Viz Original");
 | 
			
		||||
        } else {
 | 
			
		||||
            frame.setTitle("Generated");
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
 | 
			
		||||
        frame.setLayout(new BorderLayout());
 | 
			
		||||
 | 
			
		||||
        JPanel panelx = new JPanel();
 | 
			
		||||
 | 
			
		||||
        panelx.setLayout(new GridLayout(4, 4, 8, 8));
 | 
			
		||||
        for (INDArray sample : samples) {
 | 
			
		||||
            for(int i = 0; i<batchElements; i++) {
 | 
			
		||||
                panelx.add(getImage(sample, i, isOrig));
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        frame.add(panelx, BorderLayout.CENTER);
 | 
			
		||||
        frame.setVisible(true);
 | 
			
		||||
 | 
			
		||||
        frame.revalidate();
 | 
			
		||||
        frame.setMinimumSize(new Dimension(300, 20));
 | 
			
		||||
        frame.pack();
 | 
			
		||||
        return frame;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
 | 
			
		||||
        final BufferedImage bi;
 | 
			
		||||
        if(CHANNELS >1) {
 | 
			
		||||
            bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
 | 
			
		||||
        } else {
 | 
			
		||||
            bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
 | 
			
		||||
        }
 | 
			
		||||
        final int imageSize = DIMENSIONS * DIMENSIONS;
 | 
			
		||||
        final int offset = batchElement * imageSize;
 | 
			
		||||
        int pxl = offset * CHANNELS; //where to start in the INDArray
 | 
			
		||||
 | 
			
		||||
        //Image in NCHW - channels first format
 | 
			
		||||
        for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
 | 
			
		||||
            for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x
 | 
			
		||||
                for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y
 | 
			
		||||
                    float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
 | 
			
		||||
                    if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl);
 | 
			
		||||
                    bi.getRaster().setSample(x, y, c, f_pxl);
 | 
			
		||||
                    pxl++; //next item in INDArray
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        ImageIcon orig = new ImageIcon(bi);
 | 
			
		||||
        Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT);
 | 
			
		||||
        ImageIcon scaled = new ImageIcon(imageScaled);
 | 
			
		||||
        if(! isOrig)  saveImage(imageScaled, batchElement, isOrig);
 | 
			
		||||
        return new JLabel(scaled);
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static void saveImage(Image image, int batchElement, boolean isOrig) {
 | 
			
		||||
        String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
 | 
			
		||||
 | 
			
		||||
        try {
 | 
			
		||||
            // Save the images to disk
 | 
			
		||||
            saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
 | 
			
		||||
 | 
			
		||||
            log.debug("Images saved successfully.");
 | 
			
		||||
        } catch (IOException e) {
 | 
			
		||||
            log.error("Error saving the images: {}", e.getMessage());
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
 | 
			
		||||
        File directory = new File(outputDirectory);
 | 
			
		||||
        if (!directory.exists()) {
 | 
			
		||||
            directory.mkdir();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        File outputFile = new File(directory, fileName);
 | 
			
		||||
        ImageIO.write(imageToBufferedImage(image), "png", outputFile);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static BufferedImage imageToBufferedImage(Image image) {
 | 
			
		||||
        if (image instanceof BufferedImage) {
 | 
			
		||||
            return (BufferedImage) image;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    // Create a buffered image with the same dimensions and transparency as the original image
 | 
			
		||||
        BufferedImage bufferedImage;
 | 
			
		||||
    if (CHANNELS > 1) {
 | 
			
		||||
      bufferedImage =
 | 
			
		||||
          new BufferedImage(
 | 
			
		||||
              image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
 | 
			
		||||
        } else {
 | 
			
		||||
        bufferedImage =
 | 
			
		||||
                new BufferedImage(
 | 
			
		||||
                        image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Draw the original image onto the buffered image
 | 
			
		||||
        Graphics2D g2d = bufferedImage.createGraphics();
 | 
			
		||||
        g2d.drawImage(image, 0, 0, null);
 | 
			
		||||
        g2d.dispose();
 | 
			
		||||
 | 
			
		||||
        return bufferedImage;
 | 
			
		||||
    }
 | 
			
		||||
    private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
 | 
			
		||||
        for (int i = 0; i < gen.getLayers().length; i++) {
 | 
			
		||||
            gen.getLayer(i).setParams(gan.getLayer(i).getParams());
 | 
			
		||||
@ -331,41 +340,4 @@ public class App2 {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    void testDiskriminator() throws IOException {
 | 
			
		||||
        MultiLayerNetwork net = new MultiLayerNetwork(App2Config.discriminator());
 | 
			
		||||
        net.init();
 | 
			
		||||
        net.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
 | 
			
		||||
        DataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
 | 
			
		||||
 | 
			
		||||
        DataSet data = null;
 | 
			
		||||
        for(int i=0;i<App2Config.EPOCHS;i++) {
 | 
			
		||||
            log.info("Epoch {}", i);
 | 
			
		||||
            data = new DataSet(Nd4j.rand(BATCHSIZE, 784), label_fake);
 | 
			
		||||
            while (trainData.hasNext()) {
 | 
			
		||||
                INDArray real = trainData.next().getFeatures();
 | 
			
		||||
                long[] l = new long[]{BATCHSIZE, real.length() / BATCHSIZE};
 | 
			
		||||
                INDArray fake = Nd4j.rand(l );
 | 
			
		||||
 | 
			
		||||
                DataSet realSet = new DataSet(real, label_real);
 | 
			
		||||
                DataSet fakeSet = new DataSet(fake, label_fake);
 | 
			
		||||
                if(real.length()/BATCHSIZE!=784) break;
 | 
			
		||||
                data = DataSet.merge(Arrays.asList(data, realSet, fakeSet));
 | 
			
		||||
 | 
			
		||||
            }
 | 
			
		||||
            net.fit(data);
 | 
			
		||||
            trainData.reset();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        long[] l = new long[]{BATCHSIZE, 784};
 | 
			
		||||
        INDArray fake = Nd4j.rand(l );
 | 
			
		||||
        DataSet fakeSet = new DataSet(fake, label_fake);
 | 
			
		||||
        data = DataSet.merge(Arrays.asList(data, fakeSet));
 | 
			
		||||
        ExistingDataSetIterator iter = new ExistingDataSetIterator(data);
 | 
			
		||||
        Evaluation eval = net.evaluate(iter);
 | 
			
		||||
        log.info( "\n" + eval.confusionMatrix());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -36,17 +36,10 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
 | 
			
		||||
public class App2Config {
 | 
			
		||||
 | 
			
		||||
  public static final int INPUT = 100;
 | 
			
		||||
  public static final int BATCHSIZE=150;
 | 
			
		||||
  public static final int X_DIM = 28;
 | 
			
		||||
  public static final int Y_DIM = 28;
 | 
			
		||||
  public static final int y_DIM = 28;
 | 
			
		||||
  public static final int CHANNELS = 1;
 | 
			
		||||
  public static final int EPOCHS = 50;
 | 
			
		||||
  public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
 | 
			
		||||
  public static final IUpdater UPDATER_DIS = Adam.builder().learningRate(0.02).beta1(0.5).build();
 | 
			
		||||
  public static final boolean SHOW_GENERATED = true;
 | 
			
		||||
  public static final float COLORSPACE = 255f;
 | 
			
		||||
 | 
			
		||||
  final static int OUTPUT_PER_PANEL = 10;
 | 
			
		||||
 | 
			
		||||
  static LayerConfiguration[] genLayerConfig() {
 | 
			
		||||
    return new LayerConfiguration[] {
 | 
			
		||||
@ -165,7 +158,7 @@ public class App2Config {
 | 
			
		||||
                    .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
 | 
			
		||||
                    .gradientNormalizationThreshold(100)
 | 
			
		||||
                    .seed(42)
 | 
			
		||||
                    .updater(UPDATER_DIS)
 | 
			
		||||
                    .updater(UPDATER)
 | 
			
		||||
                    .weightInit(WeightInit.XAVIER)
 | 
			
		||||
                    // .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
 | 
			
		||||
                    .weightNoise(null)
 | 
			
		||||
 | 
			
		||||
@ -1,160 +0,0 @@
 | 
			
		||||
/*
 | 
			
		||||
 *
 | 
			
		||||
 *    ******************************************************************************
 | 
			
		||||
 *    *
 | 
			
		||||
 *    * This program and the accompanying materials are made available under the
 | 
			
		||||
 *    * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 *    * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *    *
 | 
			
		||||
 *    *  See the NOTICE file distributed with this work for additional
 | 
			
		||||
 *    *  information regarding copyright ownership.
 | 
			
		||||
 *    * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 *    * License for the specific language governing permissions and limitations
 | 
			
		||||
 *    * under the License.
 | 
			
		||||
 *    *
 | 
			
		||||
 *    * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 *    *****************************************************************************
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package net.brutex.gan;
 | 
			
		||||
 | 
			
		||||
import com.google.inject.Singleton;
 | 
			
		||||
import lombok.extern.slf4j.Slf4j;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
 | 
			
		||||
import javax.imageio.ImageIO;
 | 
			
		||||
import javax.swing.*;
 | 
			
		||||
import java.awt.*;
 | 
			
		||||
import java.awt.color.ColorSpace;
 | 
			
		||||
import java.awt.image.BufferedImage;
 | 
			
		||||
import java.io.File;
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
import java.util.UUID;
 | 
			
		||||
 | 
			
		||||
import static net.brutex.gan.App2.OUTPUT_DIR;
 | 
			
		||||
import static net.brutex.gan.App2Config.*;
 | 
			
		||||
@Slf4j
 | 
			
		||||
@Singleton
 | 
			
		||||
public class App2Display {
 | 
			
		||||
 | 
			
		||||
    private final JFrame frame = new JFrame();
 | 
			
		||||
  private final App2GUI display = new App2GUI();
 | 
			
		||||
 | 
			
		||||
  private final JPanel real_panel;
 | 
			
		||||
  private final JPanel fake_panel;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    public App2Display() {
 | 
			
		||||
        frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
 | 
			
		||||
        frame.setContentPane(display.getOverall_panel());
 | 
			
		||||
        frame.setMinimumSize(new Dimension(300, 20));
 | 
			
		||||
        frame.pack();
 | 
			
		||||
        frame.setVisible(true);
 | 
			
		||||
        real_panel = display.getReal_panel();
 | 
			
		||||
        fake_panel = display.getGen_panel();
 | 
			
		||||
        real_panel.setLayout(new GridLayout(4, 4, 8, 8));
 | 
			
		||||
        fake_panel.setLayout(new GridLayout(4, 4, 8, 8));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public void visualize(INDArray[] samples, int batchElements, boolean isOrig) {
 | 
			
		||||
        for (INDArray sample : samples) {
 | 
			
		||||
            for(int i = 0; i<batchElements; i++) {
 | 
			
		||||
                final Image img = this.getImage(sample, i, isOrig);
 | 
			
		||||
                final ImageIcon icon = new ImageIcon(img);
 | 
			
		||||
                if(isOrig) {
 | 
			
		||||
                    if(real_panel.getComponents().length>=OUTPUT_PER_PANEL) {
 | 
			
		||||
                        real_panel.remove(0);
 | 
			
		||||
                    }
 | 
			
		||||
                    real_panel.add(new JLabel(icon));
 | 
			
		||||
                } else {
 | 
			
		||||
                        if(fake_panel.getComponents().length>=OUTPUT_PER_PANEL) {
 | 
			
		||||
                            fake_panel.remove(0);
 | 
			
		||||
                        }
 | 
			
		||||
                    fake_panel.add(new JLabel(icon));
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        frame.pack();
 | 
			
		||||
        frame.repaint();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public Image getImage(INDArray tensor, int batchElement, boolean isOrig) {
 | 
			
		||||
        final BufferedImage bi;
 | 
			
		||||
        if(CHANNELS >1) {
 | 
			
		||||
            bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
 | 
			
		||||
        } else {
 | 
			
		||||
            bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
 | 
			
		||||
        }
 | 
			
		||||
        final int imageSize = X_DIM * Y_DIM;
 | 
			
		||||
        final int offset = batchElement * imageSize;
 | 
			
		||||
        int pxl = offset * CHANNELS; //where to start in the INDArray
 | 
			
		||||
 | 
			
		||||
        //Image in NCHW - channels first format
 | 
			
		||||
        for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
 | 
			
		||||
            for (int y = 0; y < X_DIM; y++) { // step through the columns x
 | 
			
		||||
                for (int x = 0; x < Y_DIM; x++) { //step through the rows y
 | 
			
		||||
                    float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
 | 
			
		||||
                    if(isOrig) log.trace("'{}.'{} Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, isOrig ? "Real" : "Fake", x, y, c, pxl, f_pxl);
 | 
			
		||||
                    bi.getRaster().setSample(x, y, c, f_pxl);
 | 
			
		||||
                    pxl++; //next item in INDArray
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        ImageIcon orig = new ImageIcon(bi);
 | 
			
		||||
        Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
 | 
			
		||||
        ImageIcon scaled = new ImageIcon(imageScaled);
 | 
			
		||||
        //if(! isOrig)  saveImage(imageScaled, batchElement, isOrig);
 | 
			
		||||
        return imageScaled;
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static void saveImage(Image image, int batchElement, boolean isOrig) {
 | 
			
		||||
        String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
 | 
			
		||||
 | 
			
		||||
        try {
 | 
			
		||||
            // Save the images to disk
 | 
			
		||||
            saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
 | 
			
		||||
 | 
			
		||||
            log.debug("Images saved successfully.");
 | 
			
		||||
        } catch (IOException e) {
 | 
			
		||||
            log.error("Error saving the images: {}", e.getMessage());
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
 | 
			
		||||
        File directory = new File(outputDirectory);
 | 
			
		||||
        if (!directory.exists()) {
 | 
			
		||||
            directory.mkdir();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        File outputFile = new File(directory, fileName);
 | 
			
		||||
        ImageIO.write(imageToBufferedImage(image), "png", outputFile);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static BufferedImage imageToBufferedImage(Image image) {
 | 
			
		||||
        if (image instanceof BufferedImage) {
 | 
			
		||||
            return (BufferedImage) image;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Create a buffered image with the same dimensions and transparency as the original image
 | 
			
		||||
        BufferedImage bufferedImage;
 | 
			
		||||
        if (CHANNELS > 1) {
 | 
			
		||||
            bufferedImage =
 | 
			
		||||
                    new BufferedImage(
 | 
			
		||||
                            image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
 | 
			
		||||
        } else {
 | 
			
		||||
            bufferedImage =
 | 
			
		||||
                    new BufferedImage(
 | 
			
		||||
                            image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Draw the original image onto the buffered image
 | 
			
		||||
        Graphics2D g2d = bufferedImage.createGraphics();
 | 
			
		||||
        g2d.drawImage(image, 0, 0, null);
 | 
			
		||||
        g2d.dispose();
 | 
			
		||||
 | 
			
		||||
        return bufferedImage;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -1,61 +0,0 @@
 | 
			
		||||
package net.brutex.gan;
 | 
			
		||||
 | 
			
		||||
import javax.swing.JPanel;
 | 
			
		||||
import javax.swing.JSplitPane;
 | 
			
		||||
import javax.swing.JLabel;
 | 
			
		||||
import java.awt.BorderLayout;
 | 
			
		||||
 | 
			
		||||
public class App2GUI extends JPanel {
 | 
			
		||||
 | 
			
		||||
	/**
 | 
			
		||||
	 * 
 | 
			
		||||
	 */
 | 
			
		||||
	private static final long serialVersionUID = 1L;
 | 
			
		||||
	private JPanel overall_panel;
 | 
			
		||||
	private JPanel real_panel;
 | 
			
		||||
	private JPanel gen_panel;
 | 
			
		||||
 | 
			
		||||
	/**
 | 
			
		||||
	 * Create the panel.
 | 
			
		||||
	 */
 | 
			
		||||
	public App2GUI() {
 | 
			
		||||
		
 | 
			
		||||
		overall_panel = new JPanel();
 | 
			
		||||
		add(overall_panel);
 | 
			
		||||
		
 | 
			
		||||
		JSplitPane splitPane = new JSplitPane();
 | 
			
		||||
		overall_panel.add(splitPane);
 | 
			
		||||
		
 | 
			
		||||
		JPanel p1 = new JPanel();
 | 
			
		||||
		splitPane.setLeftComponent(p1);
 | 
			
		||||
		p1.setLayout(new BorderLayout(0, 0));
 | 
			
		||||
		
 | 
			
		||||
		JLabel lblNewLabel = new JLabel("Generator");
 | 
			
		||||
		p1.add(lblNewLabel, BorderLayout.NORTH);
 | 
			
		||||
		
 | 
			
		||||
		gen_panel = new JPanel();
 | 
			
		||||
		p1.add(gen_panel, BorderLayout.SOUTH);
 | 
			
		||||
		
 | 
			
		||||
		JPanel p2 = new JPanel();
 | 
			
		||||
		splitPane.setRightComponent(p2);
 | 
			
		||||
		p2.setLayout(new BorderLayout(0, 0));
 | 
			
		||||
		
 | 
			
		||||
		JLabel lblNewLabel_1 = new JLabel("Real");
 | 
			
		||||
		p2.add(lblNewLabel_1, BorderLayout.NORTH);
 | 
			
		||||
		
 | 
			
		||||
		real_panel = new JPanel();
 | 
			
		||||
		p2.add(real_panel, BorderLayout.SOUTH);
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
	public JPanel getOverall_panel() {
 | 
			
		||||
		return overall_panel;
 | 
			
		||||
	}
 | 
			
		||||
	public JPanel getReal_panel() {
 | 
			
		||||
		return real_panel;
 | 
			
		||||
	}
 | 
			
		||||
	public JPanel getGen_panel() {
 | 
			
		||||
		return gen_panel;
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -31,7 +31,6 @@ import org.deeplearning4j.nn.conf.layers.DropoutLayer;
 | 
			
		||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
 | 
			
		||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
			
		||||
import org.deeplearning4j.nn.weights.WeightInit;
 | 
			
		||||
import org.junit.jupiter.api.Tag;
 | 
			
		||||
import org.junit.jupiter.api.Test;
 | 
			
		||||
import org.nd4j.linalg.activations.Activation;
 | 
			
		||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
 | 
			
		||||
@ -101,7 +100,7 @@ public class MnistSimpleGAN {
 | 
			
		||||
 | 
			
		||||
    return new MultiLayerNetwork(discConf);
 | 
			
		||||
  }
 | 
			
		||||
  @Test @Tag("long-running")
 | 
			
		||||
  @Test
 | 
			
		||||
  public void runTest() throws Exception {
 | 
			
		||||
    main(null);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -20,7 +20,7 @@ ext {
 | 
			
		||||
 | 
			
		||||
    def javacv = [version:"1.5.7"]
 | 
			
		||||
    def opencv = [version: "4.5.5"]
 | 
			
		||||
    def leptonica = [version: "1.83.0"] //fix, only in javacpp 1.5.9
 | 
			
		||||
    def leptonica = [version: "1.82.0"]
 | 
			
		||||
    def junit = [version: "5.9.1"]
 | 
			
		||||
 | 
			
		||||
    def flatbuffers = [version: "1.10.0"]
 | 
			
		||||
@ -118,8 +118,7 @@ dependencies {
 | 
			
		||||
        api "org.bytedeco:javacv:${javacv.version}"
 | 
			
		||||
        api "org.bytedeco:opencv:${opencv.version}-${javacpp.presetsVersion}"
 | 
			
		||||
        api "org.bytedeco:openblas:${openblas.version}-${javacpp.presetsVersion}"
 | 
			
		||||
        api "org.bytedeco:leptonica-platform:${leptonica.version}-1.5.9"
 | 
			
		||||
        api "org.bytedeco:leptonica:${leptonica.version}-1.5.9"
 | 
			
		||||
        api "org.bytedeco:leptonica-platform:${leptonica.version}-${javacpp.presetsVersion}"
 | 
			
		||||
        api "org.bytedeco:hdf5-platform:${hdf5.version}-${javacpp.presetsVersion}"
 | 
			
		||||
        api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}"
 | 
			
		||||
        api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:${javacppPlatform}"
 | 
			
		||||
@ -130,7 +129,6 @@ dependencies {
 | 
			
		||||
        api "org.bytedeco:cuda:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
 | 
			
		||||
        api "org.bytedeco:cuda-platform-redist:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
 | 
			
		||||
        api "org.bytedeco:mkl-dnn:0.21.5-${javacpp.presetsVersion}"
 | 
			
		||||
        api "org.bytedeco:mkl:2022.0-${javacpp.presetsVersion}"
 | 
			
		||||
        api "org.bytedeco:tensorflow:${tensorflow.version}-${javacpp.presetsVersion}"
 | 
			
		||||
        api "org.bytedeco:cpython:${cpython.version}-${javacpp.presetsVersion}:${javacppPlatform}"
 | 
			
		||||
        api "org.bytedeco:numpy:${numpy.version}-${javacpp.presetsVersion}:${javacppPlatform}"
 | 
			
		||||
 | 
			
		||||
@ -28,8 +28,7 @@ dependencies {
 | 
			
		||||
    implementation "org.bytedeco:javacv"
 | 
			
		||||
    implementation "org.bytedeco:opencv"
 | 
			
		||||
    implementation group: "org.bytedeco", name: "opencv", classifier: buildTarget
 | 
			
		||||
    //implementation "org.bytedeco:leptonica-platform"
 | 
			
		||||
    implementation group: "org.bytedeco", name: "leptonica", classifier: buildTarget
 | 
			
		||||
    implementation "org.bytedeco:leptonica-platform"
 | 
			
		||||
    implementation "org.bytedeco:hdf5-platform"
 | 
			
		||||
 | 
			
		||||
    implementation "commons-io:commons-io"
 | 
			
		||||
 | 
			
		||||
@ -46,7 +46,7 @@ import java.nio.ByteOrder;
 | 
			
		||||
import org.bytedeco.leptonica.*;
 | 
			
		||||
import org.bytedeco.opencv.opencv_core.*;
 | 
			
		||||
 | 
			
		||||
import static org.bytedeco.leptonica.global.leptonica.*;
 | 
			
		||||
import static org.bytedeco.leptonica.global.lept.*;
 | 
			
		||||
import static org.bytedeco.opencv.global.opencv_core.*;
 | 
			
		||||
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
 | 
			
		||||
import static org.bytedeco.opencv.global.opencv_imgproc.*;
 | 
			
		||||
 | 
			
		||||
@ -52,9 +52,10 @@ import java.io.InputStream;
 | 
			
		||||
import java.lang.reflect.Field;
 | 
			
		||||
import java.nio.file.Path;
 | 
			
		||||
import java.util.Random;
 | 
			
		||||
import java.util.stream.IntStream;
 | 
			
		||||
import java.util.stream.Stream;
 | 
			
		||||
 | 
			
		||||
import static org.bytedeco.leptonica.global.leptonica.*;
 | 
			
		||||
import static org.bytedeco.leptonica.global.lept.*;
 | 
			
		||||
import static org.bytedeco.opencv.global.opencv_core.*;
 | 
			
		||||
import static org.junit.jupiter.api.Assertions.*;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -78,7 +78,7 @@ class dnnTest {
 | 
			
		||||
     * DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH)
 | 
			
		||||
     */
 | 
			
		||||
    NeuralNetConfiguration network =
 | 
			
		||||
        NN.nn()
 | 
			
		||||
        NN.net()
 | 
			
		||||
            .seed(42)
 | 
			
		||||
            .updater(Adam.builder().learningRate(0.0002).beta1(0.5).build())
 | 
			
		||||
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
 | 
			
		||||
 | 
			
		||||
@ -31,7 +31,7 @@ dependencies {
 | 
			
		||||
    implementation projects.cavisDnn.cavisDnnData.cavisDnnDataDatasets
 | 
			
		||||
    implementation projects.cavisDnn.cavisDnnData.cavisDnnDataDatavecIterators
 | 
			
		||||
    implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators
 | 
			
		||||
    implementation "org.apache.hadoop:hadoop-common:3.2.4"
 | 
			
		||||
    implementation "org.apache.hadoop:hadoop-common:3.2.0"
 | 
			
		||||
    implementation "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml"
 | 
			
		||||
    implementation projects.cavisDatavec.cavisDatavecApi
 | 
			
		||||
    implementation projects.cavisDatavec.cavisDatavecSpark.cavisDatavecSparkCore
 | 
			
		||||
 | 
			
		||||
@ -52,8 +52,8 @@ buildscript {
 | 
			
		||||
        classpath platform(project(":cavis-common-platform"))
 | 
			
		||||
        classpath group: "org.bytedeco", name: "openblas"
 | 
			
		||||
        classpath group: "org.bytedeco", name: "openblas", classifier: "${javacppPlatform}"
 | 
			
		||||
        classpath group: "org.bytedeco", name:"mkl"
 | 
			
		||||
        classpath group: "org.bytedeco", name:"mkl", classifier: "${javacppPlatform}"
 | 
			
		||||
        classpath group: "org.bytedeco", name:"mkl-dnn"
 | 
			
		||||
        classpath group: "org.bytedeco", name:"mkl-dnn", classifier: "${javacppPlatform}"
 | 
			
		||||
        classpath group: "org.bytedeco", name: "javacpp"
 | 
			
		||||
        classpath group: "org.bytedeco", name: "javacpp", classifier: "${javacppPlatform}"
 | 
			
		||||
    }
 | 
			
		||||
@ -64,7 +64,7 @@ buildscript {
 | 
			
		||||
 | 
			
		||||
plugins {
 | 
			
		||||
    id 'java-library'
 | 
			
		||||
    id 'org.bytedeco.gradle-javacpp-build' version "1.5.7"
 | 
			
		||||
    id 'org.bytedeco.gradle-javacpp-build' version "1.5.9"
 | 
			
		||||
    id 'maven-publish'
 | 
			
		||||
    id 'signing'
 | 
			
		||||
}
 | 
			
		||||
@ -336,12 +336,11 @@ chipList.each { thisChip ->
 | 
			
		||||
                && !project.getProperty("skip-native").equals("true") && !VISUAL_STUDIO_INSTALL_DIR.isEmpty()) {
 | 
			
		||||
            def proc = ["cmd.exe", "/c", "${VISUAL_STUDIO_VCVARS_CMD} > nul && where.exe cl.exe"].execute()
 | 
			
		||||
            def outp = proc.text
 | 
			
		||||
            def cl = "\"" + outp.replace("\\", "\\\\").trim() + "\""
 | 
			
		||||
            def cl = outp.replace("\\", "\\\\").trim()
 | 
			
		||||
            def currentCompiler = ""
 | 
			
		||||
            doFirst{
 | 
			
		||||
                currentCompiler = System.getProperty("org.bytedeco.javacpp.platform.compiler")
 | 
			
		||||
                System.setProperty("org.bytedeco.javacpp.platform.compiler", cl)
 | 
			
		||||
                System.setProperty("platform.compiler.cpp11", cl)
 | 
			
		||||
                logger.quiet("Task ${thisTask.name} overrides compiler '${currentCompiler}' with '${cl}'.")
 | 
			
		||||
            }
 | 
			
		||||
            doLast {
 | 
			
		||||
 | 
			
		||||
@ -102,18 +102,16 @@ ENDIF()
 | 
			
		||||
 | 
			
		||||
IF(${SD_EXTENSION} MATCHES "avx2")
 | 
			
		||||
    message("Extension AVX2 enabled.")
 | 
			
		||||
    #-mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mprefetchwt1
 | 
			
		||||
    set(ARCH_TUNE "${ARCH_TUNE} -DSD_F16C=true -DF_AVX2=true")
 | 
			
		||||
    set(ARCH_TUNE "${ARCH_TUNE} -mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mprefetchwt1 -DSD_F16C=true -DF_AVX2=true")
 | 
			
		||||
ELSEIF(${SD_EXTENSION} MATCHES "avx512")
 | 
			
		||||
        message("Extension AVX512 enabled.")
 | 
			
		||||
        # we need to set flag here, that we can use hardware f16 conversion + tell that cpu features should be tracked
 | 
			
		||||
        #-mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512vl -mavx512bw -mavx512dq  -mavx512cd -mbmi -mbmi2 -mprefetchwt1 -mclflushopt -mxsavec -mxsaves
 | 
			
		||||
        set(ARCH_TUNE "${ARCH_TUNE} -DSD_F16C=true -DF_AVX512=true")
 | 
			
		||||
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512vl -mavx512bw -mavx512dq  -mavx512cd -mbmi -mbmi2 -mprefetchwt1 -mclflushopt -mxsavec -mxsaves -DSD_F16C=true -DF_AVX512=true")
 | 
			
		||||
ENDIF()
 | 
			
		||||
 | 
			
		||||
if (NOT WIN32)
 | 
			
		||||
        # we don't want this definition for msvc
 | 
			
		||||
        set(ARCH_TUNE "${ARCH_TUNE} -march=${SD_ARCH} -mtune=${ARCH_TYPE}")
 | 
			
		||||
        set(ARCH_TUNE "-march=${SD_ARCH} -mtune=${ARCH_TYPE}")
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang" AND SD_X86_BUILD)
 | 
			
		||||
 | 
			
		||||
@ -28,7 +28,7 @@ dependencies {
 | 
			
		||||
    implementation group: "org.bytedeco", name: "tensorflow"
 | 
			
		||||
    testRuntimeOnly group: "org.bytedeco", name: "tensorflow", classifier: buildTarget
 | 
			
		||||
    if(buildTarget.contains("windows") || buildTarget.contains("linux")) {
 | 
			
		||||
        testRuntimeOnly group: "org.bytedeco", name: 'tensorflow', classifier: "${buildTarget}-gpu", version: ''
 | 
			
		||||
        testRuntimeOnly group: "org.bytedeco", name: "tensorflow", classifier: "${buildTarget}-gpu"
 | 
			
		||||
    }
 | 
			
		||||
    implementation "commons-io:commons-io"
 | 
			
		||||
    implementation "com.google.code.gson:gson"
 | 
			
		||||
 | 
			
		||||
@ -87,7 +87,7 @@ ext {
 | 
			
		||||
            cudaTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget
 | 
			
		||||
            cudaTestRuntime group: "org.bytedeco", name: "cuda"
 | 
			
		||||
            cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: buildTarget
 | 
			
		||||
            //cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
 | 
			
		||||
            cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
 | 
			
		||||
            cudaTestRuntime (project( path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportRuntimeElements"))
 | 
			
		||||
            /*
 | 
			
		||||
            cudaTestRuntime(project(":cavis-native:cavis-native-lib")) {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user