Compare commits
109 Commits
Author | SHA1 | Date |
---|---|---|
Brian Rosenberger | 035c196dfb | |
Brian Rosenberger | 1601a7189b | |
Brian Rosenberger | 927aa54563 | |
Brian Rosenberger | 06b1e4ab7a | |
Brian Rosenberger | acb098e8d7 | |
Brian Rosenberger | c37412cf34 | |
Brian Rosenberger | ae6f7f3e31 | |
Brian Rosenberger | 9740eb2566 | |
Brian Rosenberger | e2cbfacce5 | |
Brian Rosenberger | a1f5bba4ee | |
Brian Rosenberger | 9db53b96e1 | |
Brian Rosenberger | 5b6e37c721 | |
Brian Rosenberger | 94299b56be | |
Brian Rosenberger | 143b316755 | |
Brian Rosenberger | 91ce34cd77 | |
Brian Rosenberger | 8a98975252 | |
Brian Rosenberger | 235a8037ce | |
Brian Rosenberger | d5d7c5b6d3 | |
Brian Rosenberger | 0d0dc2755d | |
Brian Rosenberger | f8f9829308 | |
Brian Rosenberger | f9cef691fc | |
Brian Rosenberger | d8b6be8e66 | |
Brian Rosenberger | 6bb03b49a6 | |
Brian Rosenberger | 4e8a92b80f | |
Brian Rosenberger | 00128a11c2 | |
Brian Rosenberger | 8afae7a7f8 | |
Brian Rosenberger | 6f60f122cb | |
Brian Rosenberger | 8f0187c12d | |
Brian Rosenberger | d398ac64c8 | |
Brian Rosenberger | a2cc2c2263 | |
Brian Rosenberger | d28df16edf | |
Brian Rosenberger | 923b70edf8 | |
Brian Rosenberger | 5f2258b710 | |
Brian Rosenberger | d6dc72fc67 | |
Brian Rosenberger | 45933c6008 | |
Brian Rosenberger | 0d97ce3222 | |
Brian Rosenberger | d6a821b5e8 | |
Brian Rosenberger | 6509eaecf1 | |
Brian Rosenberger | 1575a27192 | |
Brian Rosenberger | 8e2f95c8fa | |
Brian Rosenberger | fad3057408 | |
Brian Rosenberger | aec8fb21ca | |
Brian Rosenberger | 8a303fe478 | |
Brian Rosenberger | df89eaf45a | |
Brian Rosenberger | 153d3fc674 | |
Brian Rosenberger | ab11499c76 | |
Brian Rosenberger | 33d855303b | |
Brian Rosenberger | 50eb2915bc | |
Brian Rosenberger | 1090aed6a2 | |
Brian Rosenberger | 87485c2d37 | |
Brian Rosenberger | 67d14b7ea8 | |
Brian Rosenberger | d44ddcacba | |
Brian Rosenberger | 1438c1fdae | |
Brian Rosenberger | 17c2306701 | |
Brian Rosenberger | 7b48bf1afb | |
Brian Rosenberger | 94da6843cd | |
Brian Rosenberger | 4793864178 | |
Brian Rosenberger | cefc2b9ea1 | |
Brian Rosenberger | 48ec7311bb | |
Brian Rosenberger | fd4a00e050 | |
Brian Rosenberger | 6256061378 | |
Brian Rosenberger | 1ff151d89a | |
Brian Rosenberger | 57f493f245 | |
Brian Rosenberger | 9d1fb9a279 | |
Brian Rosenberger | 7b73b05002 | |
Brian Rosenberger | 3e123cb4b8 | |
Brian Rosenberger | faa8ee5bc4 | |
Brian Rosenberger | 68e778bed0 | |
Brian Rosenberger | cbffab0a26 | |
Brian Rosenberger | 0525ea8f06 | |
Brian Rosenberger | 47771d5509 | |
Brian Rosenberger | 7edbe140ea | |
Brian Rosenberger | 48f20f1f27 | |
Brian Rosenberger | 54efcb8d47 | |
Brian Rosenberger | 667000df5b | |
Brian Rosenberger | 4d582263f0 | |
Brian Rosenberger | 3d29f98246 | |
Brian Rosenberger | dc2917857b | |
Brian Rosenberger | 97bcda699d | |
Brian Rosenberger | 0eb56ef45f | |
Brian Rosenberger | 2c0c3d01a0 | |
Brian Rosenberger | 43abd20b91 | |
Brian Rosenberger | 16e2e727e0 | |
Brian Rosenberger | c29d7172d3 | |
Brian Rosenberger | 97bf5b9baa | |
Brian Rosenberger | 1553f6ec78 | |
Brian Rosenberger | 6ef841e882 | |
Brian Rosenberger | 0d06e739ed | |
Brian Rosenberger | 150133602b | |
Brian Rosenberger | a63bee1b94 | |
Brian Rosenberger | 242cda372c | |
Brian Rosenberger | 3463b81d37 | |
Brian Rosenberger | d5eda7d4de | |
Brian Rosenberger | b33f5ea960 | |
Brian Rosenberger | ace9f74c31 | |
Brian Rosenberger | 9f1611609f | |
Brian Rosenberger | 4e4265c5c9 | |
Brian Rosenberger | ccba08e03f | |
Brian Rosenberger | 6e3fef4eb2 | |
Brian Rosenberger | 289305775c | |
Brian Rosenberger | cdd7eff0cf | |
Brian Rosenberger | 6da3d34fea | |
Brian Rosenberger | 796e3a6be0 | |
Brian Rosenberger | 9aa56f27f1 | |
Brian Rosenberger | 1dd926f8ec | |
Brian Rosenberger | 73d82f2b3a | |
Brian Rosenberger | b0b19107ed | |
Brian Rosenberger | 500a31d051 | |
Brian Rosenberger | 460ff4720d |
@@ -1,20 +1,37 @@
-FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04
+FROM nvidia/cuda:12.2.0-devel-ubuntu20.04
 
-RUN apt-get update && \
-    DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
-#Build cmake version from source \
+ENV OS=ubuntu2004
+ENV cudnn_version=8.9.4.25
+ENV cuda_version=cuda12.2
+ENV CMAKE_VER=3.27.4
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y wget software-properties-common
+
+RUN wget https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
+
+RUN mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
+RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/7fa2af80.pub
+RUN add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/ /"
+
+RUN apt-get update && apt-get upgrade -y && \
+    DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk \
+    build-essential checkinstall zlib1g-dev libssl-dev git libpthread-stubs0-dev \
+    libcudnn8=${cudnn_version}-1+${cuda_version} libcudnn8-dev=${cudnn_version}-1+${cuda_version} \
+    cuda-drivers
+
+#RUN apt-get install libcudnn8-samples=${cudnn_version}-1+${cuda_version}
+#Build cmake version from source \
 #RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2.tar.gz && \
 #    tar -xvf cmake-3.24.2.tar.gz && cd cmake-3.24.2 && \
 #    ./bootstrap && make && make install
-RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2-linux-x86_64.sh && \
-    mkdir /opt/cmake && sh ./cmake-3.24.2-linux-x86_64.sh --skip-license --prefix=/opt/cmake && ln -s /opt/cmake/bin/cmake /usr/bin/cmake && \
-    rm cmake-3.24.2-linux-x86_64.sh
+RUN wget -nv https://github.com/Kitware/CMake/releases/download/v${CMAKE_VER}/cmake-${CMAKE_VER}-linux-x86_64.sh && \
+    mkdir -p /opt/cmake && sh ./cmake-${CMAKE_VER}-linux-x86_64.sh --skip-license --prefix=/opt/cmake && ln -s /opt/cmake/bin/cmake /usr/bin/cmake && \
+    rm cmake-${CMAKE_VER}-linux-x86_64.sh
+RUN ln -s /usr/bin/make /usr/bin/gmake
 
-RUN echo "/usr/local/cuda/compat/" >> /etc/ld.so.conf.d/cuda-driver.conf
-RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf
-RUN ldconfig -p | grep cuda
+#RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf
+#RUN echo "nameserver 9.9.9.9" >> /etc/resolv.conf
@@ -0,0 +1,69 @@
+name: Gitea Actions Demo
+run-name: ${{ gitea.actor }} is testing out Gitea Actions 🚀
+on: [push]
+
+env:
+  OS: windows
+  cudnn_version: 8.9.4.25
+  cuda_version: cuda12.2
+  CMAKE_VER: 3.27.4
+
+jobs:
+  Explore-Gitea-Actions:
+    runs-on: windows
+    #container:
+    defaults:
+      run:
+        shell: msys2 {0}
+
+    steps:
+      - name: Check out repository code
+        uses: actions/checkout@v3
+
+      - name: Install MSYS2
+        uses: msys2/setup-msys2@v2
+        with:
+          msystem: UCRT64
+          update: true
+          install: git tar gzip mingw-w64-ucrt-x86_64-gcc
+
+      # - run: Set-ExecutionPolicy -Scope CurrentUser -ExecutionPolicy Unrestricted
+      # - name: Check for CUDA
+      #   run: |
+      #     echo "Path: $env:PATH"
+
+      # - name: Install CUDA
+      #   uses: Jimver/cuda-toolkit@v0.2.11
+      #   id: cuda-toolkit
+      #   with:
+      #     cuda: '12.2.0'
+      # - run: nvcc -V
+
+      - name: Install CMake and Ninja
+        uses: lukka/get-cmake@dev/fix91
+        with:
+          useLocalCache: false
+          useCloudCache: false
+          #cmakeVersion: "~3.27.0"
+          cmakeVersion: latest
+          ninjaVersion: latest
+
+      - name: Execute Gradle build
+        run: |
+          cmd.exe /C ./gradlew.bat build \
+            --stacktrace \
+            -Pmavenuser=${{ secrets.MAVENUSER }} \
+            -Pmavenpass=${{ secrets.MAVENPASS }} \
+            -PossrhUsername=${{ secrets.OSSRHUSERNAME }} \
+            -PossrhPassword=${{ secrets.OSSRHPASSWORD }} \
+            -PCAVIS_CHIP=cpu,cuda -Pskip-native=false \
+            >buildlog.txt 2>&1
+
+      - name: Upload log
+        uses: actions/upload-artifact@v3
+        if: success() || failure() # run this step even if previous step failed
+        with:
+          name: Build-Log
+          path: buildlog.txt
+
+      - run: echo "This job's status is ${{ job.status }}."
@@ -0,0 +1,79 @@
+name: Gitea Actions Demo
+run-name: ${{ gitea.actor }} is testing out Gitea Actions 🚀
+on: [push]
+
+env:
+  OS: ubuntu2004
+  cudnn_version: 8.9.4.25
+  cuda_version: cuda12.2
+  CMAKE_VER: 3.27.4
+
+jobs:
+  Explore-Gitea-Actions:
+    runs-on: ubuntu-20.04:docker://nvidia/cuda:12.2.0-devel-ubuntu20.04
+    steps:
+      - run: echo "The job was automatically triggered by a ${{ gitea.event_name }} event."
+      - run: echo "This job is now running on a ${{ runner.os }} server hosted by Gitea!"
+      - run: echo "The name of your branch is ${{ gitea.ref }} and your repository is ${{ gitea.repository }}."
+
+      - name: Check out repository code
+        uses: actions/checkout@v3
+      - run: echo "💡 The ${{ gitea.repository }} repository has been cloned to the runner."
+      - run: echo "🖥️ The workflow is now ready to test your code on the runner."
+
+      - name: List files in the repository
+        run: |
+          ls ${{ gitea.workspace }}
+
+      - name: Update initial docker image with apt-get
+        run: |
+          apt-get -qq update && DEBIAN_FRONTEND=noninteractive apt-get -qq install -y wget software-properties-common && \
+
+          wget https://developer.download.nvidia.com/compute/cuda/repos/$OS/x86_64/cuda-$OS.pin && \
+          mv cuda-$OS.pin /etc/apt/preferences.d/cuda-repository-pin-600 && \
+          apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/$OS/x86_64/7fa2af80.pub && \
+          apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/$OS/x86_64/3bf863cc.pub && \
+          add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/$OS/x86_64/ /" && \
+
+          apt-get -qq update && apt-get -qq upgrade -y && \
+          DEBIAN_FRONTEND=noninteractive apt-get -qq install -y \
+            build-essential checkinstall zlib1g-dev libssl-dev git libpthread-stubs0-dev \
+            libcudnn8=$cudnn_version-1+$cuda_version libcudnn8-dev=$cudnn_version-1+$cuda_version \
+            libblas{3,-dev} liblapack{3,-dev} libatlas-base-dev libopenblas-dev && \
+          wget -q https://developer.download.nvidia.com/compute/cuda/12.2.1/local_installers/cuda_12.2.1_535.86.10_linux.run && \
+          sh cuda_12.2.1_535.86.10_linux.run --silent --toolkit
+
+      - name: Setup Java
+        uses: actions/setup-java@v3
+        with:
+          distribution: 'temurin' # See 'Supported distributions' for available options
+          java-version: '11'
+          cache: 'gradle'
+
+      - name: Install CMake and Ninja
+        uses: lukka/get-cmake@latest
+        with:
+          useLocalCache: true
+          useCloudCache: false
+          cmakeVersion: "~3.27.0"
+          ninjaVersion: latest
+
+      - name: Execute Gradle build
+        run: |
+          sh ./gradlew build \
+            --stacktrace \
+            -Pmavenuser=${{ secrets.MAVENUSER }} \
+            -Pmavenpass=${{ secrets.MAVENPASS }} \
+            -PossrhUsername=${{ secrets.OSSRHUSERNAME }} \
+            -PossrhPassword=${{ secrets.OSSRHPASSWORD }} \
+            -PCAVIS_CHIP=cpu,cuda -Pskip-native=false \
+            >buildlog.log 2>&1
+
+      - name: Upload log
+        uses: actions/upload-artifact@v3
+        if: success() || failure() # run this step even if previous step failed
+        with:
+          name: my-artifact
+          path: buildlog.log
+
+      - run: echo "This job's status is ${{ job.status }}."
@@ -96,3 +96,4 @@ bruai4j-native-common/cmake*
 /cavis-dnn/cavis-dnn-core/build/resources/test/logback-test.xml
 /cavis-dnn/cavis-dnn-core/build/test-results/cudaTest/TEST-org.deeplearning4j.gradientcheck.AttentionLayerTest.xml
 /cavis-dnn/cavis-dnn-core/build/tmp/jar/MANIFEST.MF
+/.metadata/
@@ -26,7 +26,7 @@ pipeline {
         dir '.docker'
         label 'linux && docker && cuda'
         //additionalBuildArgs '--build-arg version=1.0.2'
-        args '--gpus all' //needed for test only, you can build without GPU
+        //args '--gpus all' //--needed for test only, you can build without GPU
       }
     }
 
@@ -57,19 +57,31 @@ pipeline {
       }
     }
     stage('test-linux-cuda') {
+      /* agent {
+        dockerfile {
+          filename 'Dockerfile'
+          dir '.docker'
+          label 'linux && docker && cuda && cudart'
+          //additionalBuildArgs '--build-arg version=1.0.2'
+          args '--gpus all' //--needed for test only, you can build without GPU
+        }
+      }
+      */
       environment {
         MAVEN = credentials('Internal_Archiva')
        OSSRH = credentials('OSSRH')
       }
 
-      steps {
+      steps {/*
         withGradle {
           sh 'sh ./gradlew test --stacktrace -PexcludeTests=\'long-running,performance\' -Pskip-native=true -PCAVIS_CHIP=cuda \
             -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
             -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
         }
         //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
+        */
       }
 
     }
   }
@@ -19,7 +19,12 @@
  *
  */
 
-package org.nd4j.nativeblas;
+package net.brutex.ai;
 
-public class Dummy {
+public class LoaderTest {
+
+  public static void main(String[] args){
+    System.load("C:\\Users\\brian\\_projects\\deeplearning4j\\cavis-native\\cavis-native-lib"
+        + "\\build\\generated\\sources\\javacpp\\cuda\\windows-x86_64-avx2\\jnind4jcuda.dll");
+  }
 }
@@ -22,7 +22,9 @@
 package net.brutex.ai.nd4j.tests;
 
 import lombok.extern.slf4j.Slf4j;
+import org.bytedeco.javacpp.Loader;
 import org.junit.jupiter.api.Test;
+import org.nd4j.linalg.api.environment.Nd4jEnvironment;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
@@ -36,13 +38,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
 public class LoadBackendTests {
 
   @Test
-  public void loadBackend() throws NoSuchFieldException, IllegalAccessException {
-    // check if Nd4j is there
-    Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path"));
-    final Field sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
-    sysPathsField.setAccessible(true);
-    sysPathsField.set(null, null);
-    //System.loadLibrary("jnind4jcpu");
+  public void loadBackend() throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
+    log.info("get number of GPUs {}", Nd4jEnvironment.getEnvironment().getNumGpus());
     log.info("Backend: {}", Nd4j.getBackend().buildInfo());
     double d1 = 2.0;
     double d2 = 5.0;
@@ -52,4 +49,10 @@ public class LoadBackendTests {
     Number n = res.sumNumber();
     assertEquals(n.doubleValue(), 7.0, String.format("Addition of two scalar values %g and %g", d1, d2));
   }
+
+  @Test
+  public void loadCudaDLL() {
+    System.load(
+        "C:\\Users\\brian\\_projects\\deeplearning4j\\cavis-native\\cavis-native-lib\\build\\generated\\sources\\javacpp\\cuda\\windows-x86_64-avx2\\jnind4jcuda.dll");
+  }
 }
@@ -37,8 +37,6 @@ import org.datavec.image.loader.NativeImageLoader;
 import org.datavec.image.recordreader.ImageRecordReader;
 import org.datavec.image.transform.*;
 import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
-import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator;
-import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -48,27 +46,24 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.optimize.listeners.PerformanceListener;
 import org.junit.jupiter.api.Tag;
 import org.junit.jupiter.api.Test;
-import org.nd4j.evaluation.classification.Evaluation;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
 
-import static net.brutex.gan.App2Config.BATCHSIZE;
-
 @Slf4j
 public class App2 {
 
   final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
+  static final float COLORSPACE = 255f;
   static final int DIMENSIONS = 28;
   static final int CHANNELS = 1;
   final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
+  final int OUTPUT_PER_PANEL = 10;
+
   final boolean BIAS = true;
 
+  static final int BATCHSIZE=128;
+
   private JFrame frame2, frame;
   static final String OUTPUT_DIR = "d:/out/";
@@ -81,7 +76,7 @@ public class App2 {
     Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
 
     MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
-    FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans3"), NativeImageLoader.getALLOWED_FORMATS());
+    FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
     ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
     ImageTransform transform2 = new ShowImageTransform("Tester", 30);
     ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
@@ -134,94 +129,12 @@ public class App2 {
 
     log.info("Generator Summary:\n{}", gen.summary());
     log.info("GAN Summary:\n{}", gan.summary());
-    dis.addTrainingListeners(new PerformanceListener(3, true, "DIS"));
-    //gen.addTrainingListeners(new PerformanceListener(3, true, "GEN")); //is never trained separately from GAN
-    gan.addTrainingListeners(new PerformanceListener(3, true, "GAN"));
+    dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
+    gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
+    gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
-    /*
-    Thread vt =
-        new Thread(
-            new Runnable() {
-              @Override
-              public void run() {
-                while (true) {
-                  visualize(0, 0, gen);
-                  try {
-                    Thread.sleep(10000);
-                  } catch (InterruptedException e) {
-                    throw new RuntimeException(e);
-                  }
-                }
-              }
-            });
-    vt.start();
-    */
-
-    App2Display display = new App2Display();
-    //Repack training data with new fake/real label. Original MNist has 10 labels, one for each digit
-    DataSet data = null;
-    int j =0;
-    for(int i=0;i<App2Config.EPOCHS;i++) {
-      log.info("Epoch {}", i);
-      data = new DataSet(Nd4j.rand(BATCHSIZE, 784), label_fake);
-      while (trainData.hasNext()) {
-        j++;
-        INDArray real = trainData.next().getFeatures();
-        INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
-
-        INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1,
-            Nd4j.rand(BATCHSIZE, App2Config.INPUT));
-        //sigmoid output is -1 to 1
-        fake.addi(1f).divi(2f);
-
-        if (j % 50 == 1) {
-          display.visualize(new INDArray[] {fake}, App2Config.OUTPUT_PER_PANEL, false);
-          display.visualize(new INDArray[] {real}, App2Config.OUTPUT_PER_PANEL, true);
-        }
-
-        DataSet realSet = new DataSet(real, label_real);
-        DataSet fakeSet = new DataSet(fake, label_fake);
-
-        //start next round if there are not enough images left to have a full batchsize dataset
-        if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
-          log.warn("Your total number of input images is not a multiple of {}, "
-              + "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
-          break;
-        }
-        //if(real.length()/BATCHSIZE!=784) break;
-        data = DataSet.merge(Arrays.asList(data, realSet, fakeSet));
-
-      }
-      //fit the discriminator
-      dis.fit(data);
-      dis.fit(data);
-      // Update the discriminator in the GAN network
-      updateGan(gen, dis, gan);
-
-      //reset the training data and fit the complete GAN
-      if (trainData.resetSupported()) {
-        trainData.reset();
-      } else {
-        log.error("Trainingdata {} does not support reset.", trainData.toString());
-      }
-      gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_real));
-
-      if (trainData.resetSupported()) {
-        trainData.reset();
-      } else {
-        log.error("Trainingdata {} does not support reset.", trainData.toString());
-      }
-
-      log.info("Updated GAN's generator from gen.");
-      updateGen(gen, gan);
-      gen.save(new File("mnist-mlp-generator.dlj"));
-    }
-    //vt.stop();
-
+    int j = 0;
+    for (int i = 0; i < 51; i++) { //epoch
+    /*
+    int j;
+    for (int i = 0; i < App2Config.EPOCHS; i++) { //epoch
+      j=0;
       while (trainData.hasNext()) {
         j++;
         DataSet next = trainData.next();
@@ -299,8 +212,6 @@ public class App2 {
       log.info("Updated GAN's generator from gen.");
       gen.save(new File("mnist-mlp-generator.dlj"));
     }
-    */
-
   }
 
 
@@ -313,11 +224,110 @@
 
 
 
+  private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
+    if (isOrig) {
+      frame.setTitle("Viz Original");
+    } else {
+      frame.setTitle("Generated");
+    }
+
+    frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
+    frame.setLayout(new BorderLayout());
+
+    JPanel panelx = new JPanel();
+
+    panelx.setLayout(new GridLayout(4, 4, 8, 8));
+    for (INDArray sample : samples) {
+      for(int i = 0; i<batchElements; i++) {
+        panelx.add(getImage(sample, i, isOrig));
+      }
+    }
+    frame.add(panelx, BorderLayout.CENTER);
+    frame.setVisible(true);
+
+    frame.revalidate();
+    frame.setMinimumSize(new Dimension(300, 20));
+    frame.pack();
+    return frame;
+  }
+
+
+  private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
+    final BufferedImage bi;
+    if(CHANNELS >1) {
+      bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
+    } else {
+      bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
+    }
+    final int imageSize = DIMENSIONS * DIMENSIONS;
+    final int offset = batchElement * imageSize;
+    int pxl = offset * CHANNELS; //where to start in the INDArray
+
+    //Image in NCHW - channels first format
+    for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
+      for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x
+        for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y
+          float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
+          if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl);
+          bi.getRaster().setSample(x, y, c, f_pxl);
+          pxl++; //next item in INDArray
+        }
+      }
+    }
+    ImageIcon orig = new ImageIcon(bi);
+    Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT);
+    ImageIcon scaled = new ImageIcon(imageScaled);
+    if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
+    return new JLabel(scaled);
+
+  }
+
+  private static void saveImage(Image image, int batchElement, boolean isOrig) {
+    String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
+
+    try {
+      // Save the images to disk
+      saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
+
+      log.debug("Images saved successfully.");
+    } catch (IOException e) {
+      log.error("Error saving the images: {}", e.getMessage());
+    }
+  }
+  private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
+    File directory = new File(outputDirectory);
+    if (!directory.exists()) {
+      directory.mkdir();
+    }
+
+    File outputFile = new File(directory, fileName);
+    ImageIO.write(imageToBufferedImage(image), "png", outputFile);
+  }
+
+  public static BufferedImage imageToBufferedImage(Image image) {
+    if (image instanceof BufferedImage) {
+      return (BufferedImage) image;
+    }
+
+    // Create a buffered image with the same dimensions and transparency as the original image
+    BufferedImage bufferedImage;
+    if (CHANNELS > 1) {
+      bufferedImage =
+          new BufferedImage(
+              image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
+    } else {
+      bufferedImage =
+          new BufferedImage(
+              image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
+    }
+
+    // Draw the original image onto the buffered image
+    Graphics2D g2d = bufferedImage.createGraphics();
+    g2d.drawImage(image, 0, 0, null);
+    g2d.dispose();
+
+    return bufferedImage;
+  }
   private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
     for (int i = 0; i < gen.getLayers().length; i++) {
       gen.getLayer(i).setParams(gan.getLayer(i).getParams());
@@ -331,41 +341,4 @@ public class App2 {
     }
   }
 
-
-  @Test
-  void testDiskriminator() throws IOException {
-    MultiLayerNetwork net = new MultiLayerNetwork(App2Config.discriminator());
-    net.init();
-    net.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
-    DataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
-
-    DataSet data = null;
-    for(int i=0;i<App2Config.EPOCHS;i++) {
-      log.info("Epoch {}", i);
-      data = new DataSet(Nd4j.rand(BATCHSIZE, 784), label_fake);
-      while (trainData.hasNext()) {
-        INDArray real = trainData.next().getFeatures();
-        long[] l = new long[]{BATCHSIZE, real.length() / BATCHSIZE};
-        INDArray fake = Nd4j.rand(l );
-
-        DataSet realSet = new DataSet(real, label_real);
-        DataSet fakeSet = new DataSet(fake, label_fake);
-        if(real.length()/BATCHSIZE!=784) break;
-        data = DataSet.merge(Arrays.asList(data, realSet, fakeSet));
-
-      }
-      net.fit(data);
-      trainData.reset();
-    }
-
-    long[] l = new long[]{BATCHSIZE, 784};
-    INDArray fake = Nd4j.rand(l );
-    DataSet fakeSet = new DataSet(fake, label_fake);
-    data = DataSet.merge(Arrays.asList(data, fakeSet));
-    ExistingDataSetIterator iter = new ExistingDataSetIterator(data);
-    Evaluation eval = net.evaluate(iter);
-    log.info( "\n" + eval.confusionMatrix());
-  }
-
-
 }
@@ -36,17 +36,10 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
 public class App2Config {
 
   public static final int INPUT = 100;
-  public static final int BATCHSIZE=150;
   public static final int X_DIM = 28;
-  public static final int Y_DIM = 28;
+  public static final int y_DIM = 28;
   public static final int CHANNELS = 1;
-  public static final int EPOCHS = 50;
   public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
-  public static final IUpdater UPDATER_DIS = Adam.builder().learningRate(0.02).beta1(0.5).build();
-  public static final boolean SHOW_GENERATED = true;
-  public static final float COLORSPACE = 255f;
-
-  final static int OUTPUT_PER_PANEL = 10;
 
   static LayerConfiguration[] genLayerConfig() {
     return new LayerConfiguration[] {
@@ -165,7 +158,7 @@ public class App2Config {
         .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
         .gradientNormalizationThreshold(100)
         .seed(42)
-        .updater(UPDATER_DIS)
+        .updater(UPDATER)
         .weightInit(WeightInit.XAVIER)
         // .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
         .weightNoise(null)
@@ -1,160 +0,0 @@
-/*
- *
- * ******************************************************************************
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- *
- */
-
-package net.brutex.gan;
-
-import com.google.inject.Singleton;
-import lombok.extern.slf4j.Slf4j;
-import org.nd4j.linalg.api.ndarray.INDArray;
-
-import javax.imageio.ImageIO;
-import javax.swing.*;
-import java.awt.*;
-import java.awt.color.ColorSpace;
-import java.awt.image.BufferedImage;
-import java.io.File;
-import java.io.IOException;
-import java.util.UUID;
-
-import static net.brutex.gan.App2.OUTPUT_DIR;
-import static net.brutex.gan.App2Config.*;
-@Slf4j
-@Singleton
-public class App2Display {
-
-  private final JFrame frame = new JFrame();
-  private final App2GUI display = new App2GUI();
-
-  private final JPanel real_panel;
-  private final JPanel fake_panel;
-
-
-  public App2Display() {
-    frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
-    frame.setContentPane(display.getOverall_panel());
-    frame.setMinimumSize(new Dimension(300, 20));
-    frame.pack();
-    frame.setVisible(true);
-    real_panel = display.getReal_panel();
-    fake_panel = display.getGen_panel();
-    real_panel.setLayout(new GridLayout(4, 4, 8, 8));
-    fake_panel.setLayout(new GridLayout(4, 4, 8, 8));
-  }
-
-  public void visualize(INDArray[] samples, int batchElements, boolean isOrig) {
-    for (INDArray sample : samples) {
-      for(int i = 0; i<batchElements; i++) {
-        final Image img = this.getImage(sample, i, isOrig);
-        final ImageIcon icon = new ImageIcon(img);
-        if(isOrig) {
-          if(real_panel.getComponents().length>=OUTPUT_PER_PANEL) {
-            real_panel.remove(0);
-          }
-          real_panel.add(new JLabel(icon));
-        } else {
-          if(fake_panel.getComponents().length>=OUTPUT_PER_PANEL) {
-            fake_panel.remove(0);
-          }
-          fake_panel.add(new JLabel(icon));
-        }
-      }
-    }
-    frame.pack();
-    frame.repaint();
-  }
-
-  public Image getImage(INDArray tensor, int batchElement, boolean isOrig) {
-    final BufferedImage bi;
-    if(CHANNELS >1) {
-      bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
-    } else {
-      bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
-    }
-    final int imageSize = X_DIM * Y_DIM;
-    final int offset = batchElement * imageSize;
-    int pxl = offset * CHANNELS; //where to start in the INDArray
-
-    //Image in NCHW - channels first format
-    for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
-      for (int y = 0; y < X_DIM; y++) { // step through the columns x
-        for (int x = 0; x < Y_DIM; x++) { //step through the rows y
-          float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
-          if(isOrig) log.trace("'{}.'{} Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, isOrig ? "Real" : "Fake", x, y, c, pxl, f_pxl);
-          bi.getRaster().setSample(x, y, c, f_pxl);
-          pxl++; //next item in INDArray
-        }
-      }
-    }
-    ImageIcon orig = new ImageIcon(bi);
-    Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
-    ImageIcon scaled = new ImageIcon(imageScaled);
-    //if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
-    return imageScaled;
-
-  }
-
-  private static void saveImage(Image image, int batchElement, boolean isOrig) {
-    String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
-
-    try {
-      // Save the images to disk
-      saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
-
-      log.debug("Images saved successfully.");
-    } catch (IOException e) {
-      log.error("Error saving the images: {}", e.getMessage());
-    }
-  }
-  private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
-    File directory = new File(outputDirectory);
-    if (!directory.exists()) {
-      directory.mkdir();
-    }
-
-    File outputFile = new File(directory, fileName);
-    ImageIO.write(imageToBufferedImage(image), "png", outputFile);
-  }
-
-  public static BufferedImage imageToBufferedImage(Image image) {
-    if (image instanceof BufferedImage) {
-      return (BufferedImage) image;
-    }
-
-    // Create a buffered image with the same dimensions and transparency as the original image
-    BufferedImage bufferedImage;
-    if (CHANNELS > 1) {
-      bufferedImage =
-          new BufferedImage(
-              image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
-    } else {
-      bufferedImage =
-          new BufferedImage(
-              image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
-    }
-
-    // Draw the original image onto the buffered image
-    Graphics2D g2d = bufferedImage.createGraphics();
-    g2d.drawImage(image, 0, 0, null);
-    g2d.dispose();
-
-    return bufferedImage;
-  }
-}
@@ -1,61 +0,0 @@
-package net.brutex.gan;
-
-import javax.swing.JPanel;
-import javax.swing.JSplitPane;
-import javax.swing.JLabel;
-import java.awt.BorderLayout;
-
-public class App2GUI extends JPanel {
-
-  /**
-   *
-   */
-  private static final long serialVersionUID = 1L;
-  private JPanel overall_panel;
-  private JPanel real_panel;
-  private JPanel gen_panel;
-
-  /**
-   * Create the panel.
-   */
-  public App2GUI() {
-
-    overall_panel = new JPanel();
-    add(overall_panel);
-
-    JSplitPane splitPane = new JSplitPane();
-    overall_panel.add(splitPane);
-
-    JPanel p1 = new JPanel();
-    splitPane.setLeftComponent(p1);
-    p1.setLayout(new BorderLayout(0, 0));
-
-    JLabel lblNewLabel = new JLabel("Generator");
-    p1.add(lblNewLabel, BorderLayout.NORTH);
-
-    gen_panel = new JPanel();
-    p1.add(gen_panel, BorderLayout.SOUTH);
-
-    JPanel p2 = new JPanel();
-    splitPane.setRightComponent(p2);
-    p2.setLayout(new BorderLayout(0, 0));
-
-    JLabel lblNewLabel_1 = new JLabel("Real");
-    p2.add(lblNewLabel_1, BorderLayout.NORTH);
-
-    real_panel = new JPanel();
-    p2.add(real_panel, BorderLayout.SOUTH);
-
-  }
-
-
-  public JPanel getOverall_panel() {
-    return overall_panel;
-  }
-  public JPanel getReal_panel() {
-    return real_panel;
-  }
-  public JPanel getGen_panel() {
-    return gen_panel;
-  }
-}
@@ -23,7 +23,7 @@ buildscript {
     mavenCentral()
   }
   dependencies {
-    classpath "com.vanniktech:gradle-dependency-graph-generator-plugin:0.6.0"
+    classpath "com.vanniktech:gradle-dependency-graph-generator-plugin:0.8.0"
     classpath 'com.google.gradle:osdetector-gradle-plugin:1.7.0'
   }
 }
@@ -8,19 +8,22 @@ ext {
   javacppPlatform = osdetector.classifier
 }
 
-def javacpp = [version: "1.5.7", presetsVersion: "1.5.7"]
-def hdf5 = [version: "1.12.1"]
+def javacpp = [version: "1.5.9", presetsVersion: "1.5.9"]
+def hdf5 = [version: "1.14.1"]
 def jackson = [version: "2.13.4"]
-def cuda = [version: "11.6"]
-def cudnn = [version: "8.3"]
-def openblas = [version: "0.3.19"]
-def numpy = [version: "1.22.2"]
+def cuda = [version: "12.1"]
+def cudnn = [version: "8.9"]
+def openblas = [version: "0.3.23"]
+def numpy = [version: "1.24.3"]
+def tensorflow_lite = [version: "2.12.0"]
 def tensorflow = [version: "1.15.5"]
-def cpython = [version: "3.10.2"]
+def tensorrt = [version: "8.6.1.6"]
+def cpython = [version: "3.11.3"]
+def mkl = [version:"2023.1"]
 
-def javacv = [version:"1.5.7"]
-def opencv = [version: "4.5.5"]
-def leptonica = [version: "1.83.0"] //fix, only in javacpp 1.5.9
+def javacv = [version:"1.5.9"]
+def opencv = [version: "4.7.0"]
+def leptonica = [version: "1.83.0"]
 def junit = [version: "5.9.1"]
 
 def flatbuffers = [version: "1.10.0"]
@@ -41,17 +44,13 @@ dependencies {
 
   api enforcedPlatform("io.netty:netty-bom:${netty.version}")
   api enforcedPlatform("com.fasterxml.jackson:jackson-bom:${jackson.version}")
-  //api enforcedPlatform("com.fasterxml.jackson.core:jackson-annotations:${jackson.version}")
   api enforcedPlatform("com.squareup.okhttp3:okhttp-bom:${okhttp3.version}")
 
 
   constraints {
-    api enforcedPlatform("io.netty:netty-bom:${netty.version}")
-    api enforcedPlatform("com.fasterxml.jackson:jackson-bom:${jackson.version}")
-    api enforcedPlatform("com.squareup.okhttp3:okhttp-bom:${okhttp3.version}")
-    //api enforcedPlatform("com.fasterxml.jackson.core:jackson-annotations:${jackson.version}")
-    //api "com.squareup.okhttp3:okhttp:${okhttp3}.version"
-    //api "com.squareup.okhttp3:logging-interceptor:${okhttp3}.version"
+    api ("io.netty:netty-bom:${netty.version}")
+    api ("com.fasterxml.jackson:jackson-bom:${jackson.version}")
+    api ("com.squareup.okhttp3:okhttp-bom:${okhttp3.version}")
 
     api 'com.google.guava:guava:30.1-jre'
     api "com.google.protobuf:protobuf-java:3.15.6"
@@ -59,18 +58,6 @@ dependencies {
     api "com.google.protobuf:protobuf-java-util:3.15.6"
     api "com.google.flatbuffers:flatbuffers-java:${flatbuffers.version}"
 
-    /*
-    api "com.fasterxml.jackson.core:jackson-core:${jackson.version}"
-    api "com.fasterxml.jackson.core:jackson-databind:${jackson.version}"
-    api "com.fasterxml.jackson.core:jackson-annotations:${jackson.version}"
-
-    api "com.fasterxml.jackson.dataformat:jackson-dataformat-xml:${jackson.version}"
-    */
-    // api "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:${jackson.version}"
-    // api "com.fasterxml.jackson.datatype:jackson-datatype-joda:${jackson.version}"
-    // api "com.fasterxml.jackson.module:jackson-module-scala_${scalaVersion}"
-
-
     api "org.projectlombok:lombok:1.18.28"
 
     /*Logging*/
@@ -81,7 +68,7 @@ dependencies {
     api "ch.qos.logback:logback-classic:1.2.3"
     api 'ch.qos.logback:logback-core:1.2.3'
 
+    /* commons */
     api 'commons-io:commons-io:2.5'
     api 'commons-codec:commons-codec:1.11'
     api 'commons-net:commons-net:3.6'
@@ -118,24 +105,23 @@ dependencies {
     api "org.bytedeco:javacv:${javacv.version}"
     api "org.bytedeco:opencv:${opencv.version}-${javacpp.presetsVersion}"
     api "org.bytedeco:openblas:${openblas.version}-${javacpp.presetsVersion}"
-    api "org.bytedeco:leptonica-platform:${leptonica.version}-1.5.9"
-    api "org.bytedeco:leptonica:${leptonica.version}-1.5.9"
+    api "org.bytedeco:openblas-platform:${openblas.version}-${javacpp.presetsVersion}"
+    api "org.bytedeco:leptonica-platform:${leptonica.version}-${javacpp.presetsVersion}"
     api "org.bytedeco:hdf5-platform:${hdf5.version}-${javacpp.presetsVersion}"
     api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}"
     api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:${javacppPlatform}"
-    //api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:linux-x86_64"
-
-
-
     api "org.bytedeco:cuda:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
-    api "org.bytedeco:cuda-platform-redist:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
-    api "org.bytedeco:mkl-dnn:0.21.5-${javacpp.presetsVersion}"
-    api "org.bytedeco:mkl:2022.0-${javacpp.presetsVersion}"
-    api "org.bytedeco:tensorflow:${tensorflow.version}-${javacpp.presetsVersion}"
+    //api "org.bytedeco:cuda-platform-redist:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
+    api "org.bytedeco:mkl:${mkl.version}-${javacpp.presetsVersion}"
+    //api "org.bytedeco:tensorflow:${tensorflow.version}-1.5.8" //not available for javacpp 1.5.9 ?
+    //api "org.bytedeco:tensorflow-platform:${tensorflow.version}-1.5.8"
+    //api "org.bytedeco:tensorflow-lite:${tensorflow_lite.version}-${javacpp.presetsVersion}"
+    //api "org.bytedeco:tensorflow-lite-platform:${tensorflow_lite.version}-${javacpp.presetsVersion}"
+    api "org.bytedeco:tensorrt:${tensorrt.version}-${javacpp.presetsVersion}"
+    api "org.bytedeco:tensorrt-platform:${tensorrt.version}-${javacpp.presetsVersion}"
     api "org.bytedeco:cpython:${cpython.version}-${javacpp.presetsVersion}:${javacppPlatform}"
     api "org.bytedeco:numpy:${numpy.version}-${javacpp.presetsVersion}:${javacppPlatform}"
-    //implementation "org.bytedeco:cpython-platform:3.9.6-1.5.6"
-    //implementation "org.bytedeco:numpy-platform:1.21.1-1.5.6"
 
     /* Apache Spark */
     api "org.apache.spark:spark-core_${scalaVersion}:${spark.version}"
@@ -169,16 +155,6 @@ dependencies {
   }
 }
 
-/*
-publishing {
-  publications {
-    myPlatform(MavenPublication) {
-      from components.javaPlatform
-    }
-  }
-}
-*/
-
 tasks.withType(GenerateModuleMetadata).configureEach {
   // The value 'enforced-platform' is provided in the validation
   // error message you got
@@ -0,0 +1,10 @@
+plugins {
+  id 'java-library'
+}
+
+dependencies {
+  implementation platform(projects.cavisCommonPlatform)
+  implementation projects.cavisNative.cavisNativeBlas
+  implementation "org.bytedeco:javacpp"
+  implementation group: "org.bytedeco", name: 'openblas-platform'
+}
@@ -19,5 +19,4 @@
 #
 #
 
-#org.nd4j.linalg.jcublas.JCublasBackend
 org.nd4j.linalg.cpu.nativecpu.CpuBackend
@@ -13,7 +13,7 @@ dependencies {
   implementation projects.cavisNative.cavisNativeCommon
   implementation projects.cavisDnn.cavisDnnApi
   implementation projects.cavisDnn.cavisDnnCommon
+  implementation projects.cavisNative.cavisNativeCpuPresets
 
   implementation (projects.cavisNative.cavisNativeLib) {
     capabilities {
@@ -0,0 +1,9 @@
+plugins {
+  id 'java-library'
+}
+
+dependencies {
+  implementation platform(projects.cavisCommonPlatform)
+  implementation projects.cavisNative.cavisNativeBlas
+  implementation "org.bytedeco:javacpp"
+}
@@ -37,7 +37,10 @@ import java.util.List;
  * @author saudet
  */
 @Properties(target = "org.nd4j.nativeblas.Nd4jCuda", helper = "org.nd4j.nativeblas.cuda.Nd4jCudaHelper",
-    value = {@Platform(define = "LIBND4J_ALL_OPS", include = {
+    value = {
+        @Platform(
+            define = "LIBND4J_ALL_OPS",
+            include = {
       "array/DataType.h",
       "array/DataBuffer.h",
       "array/PointerDeallocator.h",
@@ -105,7 +108,7 @@ import java.util.List;
       "ops/declarable/CustomOperations.h",
       "build_info.h",
     },
     exclude = {"ops/declarable/headers/activations.h",
       "ops/declarable/headers/boolean.h",
       "ops/declarable/headers/broadcastable.h",
       "ops/declarable/headers/convo.h",
@@ -125,12 +128,16 @@ import java.util.List;
             "cnpy/cnpy.h"
             },
         compiler = {"cpp11", "nowarnings"},
-        library = "jnind4jcuda", link = "nd4jcuda", preload = "nd4jcuda"),
-    @Platform(value = "linux", preload = "gomp@.1", preloadpath = {"/lib64/", "/lib/", "/usr/lib64/", "/usr/lib/"}),
+        library = "jnind4jcuda",
+        link = {"nd4jcuda"}),
+        //preload = "nd4jcuda"),
+
+    @Platform(value = "linux", preload = "gomp@.1", preloadpath = {"/lib64/", "/lib/", "/usr/lib64/", "/usr/lib/", "/usr/local/cuda/lib64"}),
     @Platform(value = "linux-armhf", preloadpath = {"/usr/arm-linux-gnueabihf/lib/", "/usr/lib/arm-linux-gnueabihf/"}),
     @Platform(value = "linux-arm64", preloadpath = {"/usr/aarch64-linux-gnu/lib/", "/usr/lib/aarch64-linux-gnu/"}),
     @Platform(value = "linux-ppc64", preloadpath = {"/usr/powerpc64-linux-gnu/lib/", "/usr/powerpc64le-linux-gnu/lib/", "/usr/lib/powerpc64-linux-gnu/", "/usr/lib/powerpc64le-linux-gnu/"}),
-    @Platform(value = "windows", preload = {"libwinpthread-1", "libgcc_s_seh-1", "libgomp-1", "libstdc++-6", "nd4jcuda"}) })
+    @Platform(value = "windows", preload = {"libwinpthread-1", "libgcc_s_seh-1", "libgomp-1", "libstdc++-6"})
+})
 public class Nd4jCudaPresets implements LoadEnabled, InfoMapper {
 
     @Override public void init(ClassProperties properties) {
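For readers unfamiliar with JavaCPP: the `@Properties`/`@Platform` annotations above drive both the parser (which headers to read) and the native build (which libraries to link and preload). A minimal, self-contained preset looks like the hedged sketch below; `mylib.h`, `mylib`, and all `com.example` names are hypothetical and only illustrate the structure, not this project's actual presets.

```java
import org.bytedeco.javacpp.annotation.Platform;
import org.bytedeco.javacpp.annotation.Properties;
import org.bytedeco.javacpp.tools.Info;
import org.bytedeco.javacpp.tools.InfoMap;
import org.bytedeco.javacpp.tools.InfoMapper;

// Hypothetical header "mylib.h" and library "mylib"; names are illustrative only.
@Properties(
    target = "com.example.MyLib",              // generated JNI wrapper class
    value = @Platform(include = "mylib.h",     // headers the parser reads
                      link = "mylib"))         // native library to link against
public class MyLibPresets implements InfoMapper {
    @Override
    public void map(InfoMap infoMap) {
        // Example of InfoMap usage: skip a symbol the parser cannot handle.
        infoMap.put(new Info("MYLIB_INTERNAL_MACRO").skip());
    }
}
```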
@@ -143,14 +150,19 @@ public class Nd4jCudaPresets implements LoadEnabled, InfoMapper {
             return;
         }
         int i = 0;
+        /*
         String[] libs = {"cudart", "cublasLt", "cublas", "curand", "cusolver", "cusparse", "cudnn",
                          "cudnn_ops_infer", "cudnn_ops_train", "cudnn_adv_infer",
                          "cudnn_adv_train", "cudnn_cnn_infer", "cudnn_cnn_train"};
 
+        */
+        // test no preload
+        String[] libs = {};
         for (String lib : libs) {
             if (platform.startsWith("linux")) {
-                lib += lib.startsWith("cudnn") ? "@.8" : lib.equals("curand") ? "@.10" : lib.equals("cudart") ? "@.11.0" : "@.11";
+                lib += lib.startsWith("cudnn") ? "@.8" : lib.equals("curand") ? "@.10" : lib.equals("cufft") ? "@.11" : "@.12";
             } else if (platform.startsWith("windows")) {
-                lib += lib.startsWith("cudnn") ? "64_8" : lib.equals("curand") ? "64_10" : lib.equals("cudart") ? "64_110" : "64_11";
+                lib += lib.startsWith("cudnn") ? "64_8" : lib.equals("cufft") ? "64_11" : lib.equals("cusolver") ? "64_11" : lib.equals("curand") ? "64_10" : "64_12";
             } else {
                 continue; // no CUDA
             }
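A minimal sketch of the suffix mapping in the hunk above, for context only: JavaCPP preload names carry the SONAME/DLL version, so "cudart" becomes "cudart@.12" (libcudart.so.12) on Linux and "cudart64_12" (cudart64_12.dll) on Windows. The versions mirror the new right-hand side of the diff, i.e. a move to CUDA 12-era library names.

```java
public class CudaPreloadNames {
    // Mirrors the ternary chains in the changed hunk above.
    static String preloadName(String platform, String lib) {
        if (platform.startsWith("linux")) {
            return lib + (lib.startsWith("cudnn") ? "@.8"
                    : lib.equals("curand") ? "@.10"
                    : lib.equals("cufft") ? "@.11" : "@.12");
        } else if (platform.startsWith("windows")) {
            return lib + (lib.startsWith("cudnn") ? "64_8"
                    : lib.equals("cufft") ? "64_11"
                    : lib.equals("cusolver") ? "64_11"
                    : lib.equals("curand") ? "64_10" : "64_12");
        }
        return null; // other platforms: no CUDA preloads
    }

    public static void main(String[] args) {
        System.out.println(preloadName("linux-x86_64", "cudart"));   // cudart@.12
        System.out.println(preloadName("windows-x86_64", "curand")); // curand64_10
    }
}
```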
@@ -158,9 +170,9 @@ public class Nd4jCudaPresets implements LoadEnabled, InfoMapper {
                 preloads.add(i++, lib);
             }
         }
-        if (i > 0) {
+        //if (i > 0) {
             resources.add("/org/bytedeco/cuda/");
-        }
+        //}
     }
 
     @Override
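Hedged usage sketch, not part of the changeset: at runtime `Loader.load()` walks the `@Platform` properties plus whatever `init()` appended, extracting and linking the preloads before the JNI library itself. The preset class name below follows the `classOrPackageNames` pattern used elsewhere in this changeset and is an assumption about the final package layout.

```java
import org.bytedeco.javacpp.Loader;

public class LoadCudaBackend {
    public static void main(String[] args) {
        // Loads jnind4jcuda plus any preloads registered by Nd4jCudaPresets.init().
        Loader.load(org.nd4j.nativeblas.cuda.Nd4jCudaPresets.class);
    }
}
```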
@@ -0,0 +1,23 @@
+#
+#
+# ******************************************************************************
+# *
+# * This program and the accompanying materials are made available under the
+# * terms of the Apache License, Version 2.0 which is available at
+# * https://www.apache.org/licenses/LICENSE-2.0.
+# *
+# * See the NOTICE file distributed with this work for additional
+# * information regarding copyright ownership.
+# * Unless required by applicable law or agreed to in writing, software
+# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# * License for the specific language governing permissions and limitations
+# * under the License.
+# *
+# * SPDX-License-Identifier: Apache-2.0
+# *****************************************************************************
+#
+#
+
+
+org.nd4j.linalg.cpu.nativecpu.compression.CpuThreshold
@@ -0,0 +1,23 @@
+#
+#
+# ******************************************************************************
+# *
+# * This program and the accompanying materials are made available under the
+# * terms of the Apache License, Version 2.0 which is available at
+# * https://www.apache.org/licenses/LICENSE-2.0.
+# *
+# * See the NOTICE file distributed with this work for additional
+# * information regarding copyright ownership.
+# * Unless required by applicable law or agreed to in writing, software
+# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# * License for the specific language governing permissions and limitations
+# * under the License.
+# *
+# * SPDX-License-Identifier: Apache-2.0
+# *****************************************************************************
+#
+#
+
+org.nd4j.linalg.jcublas.JCublasBackend
+#org.nd4j.linalg.cpu.nativecpu.CpuBackend
@@ -0,0 +1,21 @@
+#
+# /* ******************************************************************************
+# *
+# *
+# * This program and the accompanying materials are made available under the
+# * terms of the Apache License, Version 2.0 which is available at
+# * https://www.apache.org/licenses/LICENSE-2.0.
+# *
+# * See the NOTICE file distributed with this work for additional
+# * information regarding copyright ownership.
+# * Unless required by applicable law or agreed to in writing, software
+# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# * License for the specific language governing permissions and limitations
+# * under the License.
+# *
+# * SPDX-License-Identifier: Apache-2.0
+# ******************************************************************************/
+#
+
+iamax_strided = 1
@@ -0,0 +1,22 @@
+#
+# /* ******************************************************************************
+# *
+# *
+# * This program and the accompanying materials are made available under the
+# * terms of the Apache License, Version 2.0 which is available at
+# * https://www.apache.org/licenses/LICENSE-2.0.
+# *
+# * See the NOTICE file distributed with this work for additional
+# * information regarding copyright ownership.
+# * Unless required by applicable law or agreed to in writing, software
+# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# * License for the specific language governing permissions and limitations
+# * under the License.
+# *
+# * SPDX-License-Identifier: Apache-2.0
+# ******************************************************************************/
+#
+
+org.nd4j.linalg.api.resources.maxallocated= 2000000000
+org.nd4j.linalg.api.resources.memoryratio=0.5
@@ -10,22 +10,22 @@ ext {
 
 dependencies {
     implementation platform(projects.cavisCommonPlatform)
 
-    //implementation project(":cavis-native:cavis-native-blas")
     implementation projects.cavisNative.cavisNativeBlas
 
     implementation group: "org.bytedeco", name: "cuda"
     implementation group: "org.bytedeco", name: "cuda", classifier: buildTarget
-    implementation group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
+    //implementation group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist"
 
     implementation group: "org.bytedeco", name: "javacpp"
     implementation group: "org.bytedeco", name: "javacpp", classifier: buildTarget
 
-    implementation(project(path: ":cavis-native:cavis-native-lib")) {
+    implementation projects.cavisNative.cavisNativeCudaPresets
+    implementation(project(":cavis-native:cavis-native-lib")) {
         capabilities {
-            it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version:project.version)
+            requireCapability("${project.group}:cavis-native-lib-cuda-support:${project.version}")
         }
     }
 
     implementation project(":cavis-native:cavis-native-common")
     implementation project(":cavis-dnn:cavis-dnn-api")
     implementation project(":cavis-dnn:cavis-dnn-common")
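For context, a hedged sketch of what the `requireCapability(...)` line above does, written against Gradle's Java API rather than the Groovy DSL: the dependency on `cavis-native-lib` is narrowed to the variant that declares the `cavis-native-lib-cuda-support` capability. The plugin class name is hypothetical; only the API calls are real Gradle API.

```java
import java.util.Collections;

import org.gradle.api.Plugin;
import org.gradle.api.Project;
import org.gradle.api.artifacts.ProjectDependency;

public class CudaSupportDependencyPlugin implements Plugin<Project> {
    @Override
    public void apply(Project project) {
        // Create the project dependency, then narrow it to the variant that
        // declares the cuda-support capability (same effect as the DSL above).
        ProjectDependency dep = (ProjectDependency) project.getDependencies()
                .project(Collections.singletonMap("path", ":cavis-native:cavis-native-lib"));
        dep.capabilities(handler -> handler.requireCapability(
                project.getGroup() + ":cavis-native-lib-cuda-support:" + project.getVersion()));
        project.getDependencies().add("implementation", dep);
    }
}
```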
@@ -36,3 +36,9 @@ dependencies {
     implementation "org.apache.commons:commons-lang3"
 }
+
+tasks.named("compileJava").configure {
+    dependsOn ":cavis-native:cavis-native-lib:javacppCudaSupportBuildParser",
+            ":cavis-native:cavis-native-lib:cudaJar"
+
+}
@@ -20,7 +20,6 @@
 
 package org.nd4j.jita.constant;
 
-import lombok.extern.log4j.Log4j2;
 import lombok.extern.slf4j.Slf4j;
 import org.nd4j.linalg.api.buffer.DataBuffer;
 import org.nd4j.linalg.api.buffer.DataType;
@@ -3,13 +3,14 @@ cmake_minimum_required(VERSION 3.20)
 
 project(libnd4j)
 set(CMAKE_VERBOSE_MAKEFILE ON)
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
 set (CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}")
 message("CMAKE MODULE PATH IS ${CMAKE_MODULE_PATH}")
 
 #ensure we create lib files
-set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS OFF)
+#set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS OFF)
 
 
 option(SD_NATIVE "Optimize for build machine (might not work on others)" OFF)
@@ -25,6 +26,12 @@ set(FLATBUFFERS_BUILD_FLATC "OFF" CACHE STRING "Hack to disable flatc build" FOR
 
 set(CMAKE_CXX_STANDARD 14)
 
+set(CMAKE_THREAD_PREFER_PTHREAD TRUE)
+find_package(Threads REQUIRED)
+
+# MSVC runtime lib can be either "MultiThreaded" or "MultiThreadedDLL", /MT and /MD respectively
+set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded")
+
 #///////////////////////////////////////////////////////////////////////////////
 # genCompilation: Generates cpp, cu files
 # INPUT:
@@ -120,8 +127,8 @@ endfunction()
 
 
 if (SD_CUDA)
-    #enable_language(CUDA)
-    find_package(CUDAToolkit 11.4 REQUIRED)
+    find_package(CUDAToolkit 12.2 REQUIRED)
+    enable_language(CUDA)
     message(STATUS "CUDAToolkit_VERSION: ${CUDAToolkit_VERSION}")
     message(STATUS "CUDAToolkit_VERSION_MAJOR: ${CUDAToolkit_VERSION_MAJOR}")
     message(STATUS "CUDAToolkit_VERSION_MINOR: ${CUDAToolkit_VERSION_MINOR}")
@@ -136,8 +143,7 @@ else()
     set(DEFAULT_ENGINE "samediff::ENGINE_CPU")
 endif()
 
-# MSVC runtime lib can be either "MultiThreaded" or "MultiThreadedDLL", /MT and /MD respectively
-#set(MSVC_RT_LIB "MultiThreadedDLL")
 
 set(SD_X86_BUILD false)
 
@@ -155,10 +161,10 @@ elseif (APPLE)
     set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true")
     set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true")
 elseif(WIN32)
-    set(SD_X86_BUILD true)
+    set(SD_X86_BUILD false)
     if (SD_CUDA)
         set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true")
-        set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc")
+        #set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc")
     else()
         set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -D_RELEASE=true")
         set(CMAKE_CXX_FLAGS_DEBUG " -g -O2 -fPIC")
@@ -362,7 +368,7 @@ if(SD_BUILD_TESTS)
     # tests are always compiled with all ops included
     set(SD_ALL_OPS true)
     set(SD_BUILD_MINIFIER true)
-    add_subdirectory(tests_cpu)
+    add_subdirectory(src/test/cpp/tests_cpu)
 endif()
 
 
@@ -370,7 +376,6 @@ if (MSVC_DEV)
     set(SD_BUILD_MINIFIER false)
 endif ()
 
-set (CMAKE_INSTALL_PREFIX $ENV{ND4J_HOME}/bruai4j-native/bruai4j-native-common/src/main/resources)
 
 # Set package information
 set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "Native operations for nd4j.")
@@ -1,3 +1,5 @@
+import org.gradle.api.internal.java.DefaultJavaPlatformExtension
+import org.gradle.api.plugins.internal.DefaultJavaPluginExtension
 import org.gradle.api.publish.maven.internal.publisher.MavenRemotePublisher
 import org.gradle.language.nativeplatform.internal.Dimensions
 
@@ -44,8 +46,7 @@ buildscript {
         logger.info("Setting properties for task '{}' to '{}'", tsk.getName(), pf)
         return pf
     }
-    }
+    } // End of ext block
 
 
 dependencies {
@@ -64,104 +65,106 @@ buildscript {
 
 plugins {
     id 'java-library'
-    id 'org.bytedeco.gradle-javacpp-build' version "1.5.7"
+    id 'org.bytedeco.gradle-javacpp-build' version "1.5.9" //version "1.5.10-SNAPSHOT"
     id 'maven-publish'
     id 'signing'
 }
 
-chipList.each {thisChip ->
-    sourceSets.register("${thisChip}Support") {
+chipList.each {String thisChip ->
+    /*sourceSets.register(thisChip) {
         java {
-            srcDirs = ['src/main/java', "${buildDir}/generated/sources/javacpp/${thisChip}//${javacppPlatform}${javacppPlatformExtension}/"]
+            srcDirs = ["${projectDir}/src/main/java/"]
             include "org/nd4j/nativeblas/${thisChip}/Nd4j${thisChip.capitalize()}Helper.java"
             include "org/nd4j/nativeblas/${thisChip}/Nd4j${thisChip.capitalize()}Presets.java"
+        }
+    }*/
+    sourceSets.register("${thisChip}").configure {
+        java {
+            srcDirs = ["${buildDir}/generated/sources/javacpp/${thisChip}/${javacppPlatform}${javacppPlatformExtension}/"]
             include "org/nd4j/nativeblas/Nd4j${thisChip.capitalize()}.java"
         }
-        it.compiledBy("javacpp${thisChip.capitalize()}SupportBuildCommand",
-                "javacpp${thisChip.capitalize()}SupportBuildCompiler")
+        compiledBy "javacpp${thisChip.capitalize()}SupportBuildCompiler"
     }
 }
 
-//if(osdetector.os.startsWith("windows")) {
+/*
 sourceSets {
     main {
         java {
-            srcDirs = ['src/main/java']
+            srcDirs = new HashSet<>();
             include 'org/nd4j/nativeblas/Dummy.java'
         }
     }
 }
-//}
+*/
 
+// This block registers the cpu and cuda features and creates
+// i. e. the {chip}Implementation
 java {
     chipList.each {thisChip ->
         registerFeature("${thisChip}Support") {
-            usingSourceSet(sourceSets.findByName("${thisChip}Support"))
+            usingSourceSet(sourceSets.findByName("${thisChip}"))
             capability(project.group, "cavis-native-lib-${thisChip}-support", project.version)
             //withJavadocJar()
             //withSourcesJar()
-        }
-    }
-}
+        }}}
 
 dependencies {
-    api platform(project(':cavis-common-platform'))
-    implementation "org.bytedeco:javacpp"
-    implementation group: "org.bytedeco", name: "javacpp", classifier: "${javacppPlatform}"
 
     if(withCuda()) {
-        cudaSupportImplementation platform(project(':cavis-common-platform'))
-        cudaSupportImplementation project(":cavis-dnn:cavis-dnn-api")
-        cudaSupportImplementation project(":cavis-dnn:cavis-dnn-common")
-        cudaSupportImplementation project(":cavis-native:cavis-native-blas")
-        cudaSupportImplementation project(":cavis-native:cavis-native-common")
-        cudaSupportImplementation "commons-io:commons-io"
-        cudaSupportImplementation group: "org.bytedeco", name: "openblas"
-        cudaSupportImplementation group: "org.bytedeco", name: "openblas", classifier: "${javacppPlatform}"
-        cudaSupportImplementation group: "org.bytedeco", name: "cuda"
-        cudaSupportImplementation group: "org.bytedeco", name: "cuda", classifier: "${javacppPlatform}"
-        cudaSupportImplementation "org.apache.logging.log4j:log4j-core:2.17.0"
-        cudaSupportImplementation "com.google.guava:guava:14.0.1"
-        cudaSupportImplementation "org.apache.commons:commons-lang3"
-        cudaSupportImplementation "org.apache.commons:commons-math3"
-        cudaSupportImplementation "com.google.flatbuffers:flatbuffers-java"
-        cudaSupportImplementation 'javax.mail:javax.mail-api:1.6.2'
+        cudaImplementation platform(project(':cavis-common-platform'))
+        //cudaImplementation project(":cavis-dnn:cavis-dnn-api")
+        //cudaImplementation project(":cavis-dnn:cavis-dnn-common")
+        cudaImplementation project(":cavis-native:cavis-native-blas")
+        //cudaImplementation project(":cavis-native:cavis-native-common")
+        //cudaImplementation "commons-io:commons-io"
+        //cudaImplementation "org.bytedeco:openblas"
+        //cudaImplementation "org.bytedeco:openblas::${javacppPlatform}"
+        //cudaImplementation "org.bytedeco:cuda"
+        //cudaImplementation "org.bytedeco:cuda::${javacppPlatform}"
+        //cudaImplementation "org.apache.logging.log4j:log4j-core:2.17.0"
+        //cudaImplementation "com.google.guava:guava:14.0.1"
+        //cudaImplementation "org.apache.commons:commons-lang3"
+        //cudaImplementation "org.apache.commons:commons-math3"
+        //cudaImplementation "com.google.flatbuffers:flatbuffers-java"
+        //cudaImplementation 'javax.mail:javax.mail-api:1.6.2'
+        cudaImplementation "org.bytedeco:javacpp"
+        cudaImplementation "org.bytedeco:javacpp::${javacppPlatform}"
+        cudaImplementation project(":cavis-native:cavis-native-cuda-presets")
+
+        //cudaGeneratedImplementation platform(project(':cavis-common-platform'))
+        //cudaGeneratedImplementation project(":cavis-native:cavis-native-blas")
+        //cudaGeneratedImplementation "org.bytedeco:javacpp"
+        //cudaGeneratedImplementation "org.bytedeco:javacpp::${javacppPlatform}"
+        //cudaGeneratedImplementation project(":cavis-native:cavis-native-cuda-presets")
     }
 
     if(withCpu()) {
-        cpuSupportImplementation platform(project(':cavis-common-platform'))
-        cpuSupportImplementation project(":cavis-dnn:cavis-dnn-api")
-        cpuSupportImplementation project(":cavis-dnn:cavis-dnn-common")
-        cpuSupportImplementation project(":cavis-native:cavis-native-blas")
-        cpuSupportImplementation project(":cavis-native:cavis-native-common")
-        cpuSupportImplementation "commons-io:commons-io"
-        cpuSupportImplementation group: "org.bytedeco", name: "openblas"
-        cpuSupportImplementation group: "org.bytedeco", name: "openblas", classifier: "${javacppPlatform}"
-        cpuSupportImplementation group: "org.bytedeco", name: "opencv"
-        cpuSupportImplementation group: "org.bytedeco", name: "opencv", classifier: "${javacppPlatform}"
-        cpuSupportImplementation "org.apache.logging.log4j:log4j-core:2.17.0"
-        cpuSupportImplementation "com.google.guava:guava:14.0.1"
-        cpuSupportImplementation "org.apache.commons:commons-lang3"
-        cpuSupportImplementation "org.apache.commons:commons-math3"
-        cpuSupportImplementation "com.google.flatbuffers:flatbuffers-java"
-        cpuSupportImplementation 'javax.mail:javax.mail-api:1.6.2'
-    }
-
-    implementation projects.cavisDnn.cavisDnnApi
-    implementation projects.cavisDnn.cavisDnnCommon
-    implementation project(":cavis-native:cavis-native-blas")
-    implementation project(":cavis-native:cavis-native-common")
-    implementation "commons-io:commons-io"
-    implementation "org.bytedeco:openblas"
-    implementation group: "org.bytedeco", name: "openblas", classifier: "${javacppPlatform}"
-    implementation "org.apache.logging.log4j:log4j-core"
-    implementation "com.google.guava:guava:14.0.1"
-    implementation "org.apache.commons:commons-lang3"
-    implementation "org.apache.commons:commons-math3"
-    implementation "com.google.flatbuffers:flatbuffers-java"
+        cpuImplementation platform(project(':cavis-common-platform'))
+        //cpuImplementation project(":cavis-dnn:cavis-dnn-api")
+        //cpuImplementation project(":cavis-dnn:cavis-dnn-common")
+        cpuImplementation project(":cavis-native:cavis-native-blas")
+        //cpuImplementation project(":cavis-native:cavis-native-common")
+        //cpuImplementation "commons-io:commons-io"
+        //cpuImplementation "org.bytedeco:opencv"
+        //cpuImplementation "org.bytedeco:opencv::${javacppPlatform}"
+        //cpuImplementation "org.apache.logging.log4j:log4j-core:2.17.0"
+        //cpuImplementation "com.google.guava:guava:14.0.1"
+        //cpuImplementation "org.apache.commons:commons-lang3"
+        //cpuImplementation "org.apache.commons:commons-math3"
+        //cpuImplementation "com.google.flatbuffers:flatbuffers-java"
+        //cpuImplementation 'javax.mail:javax.mail-api:1.6.2'
+        cpuImplementation "org.bytedeco:javacpp"
+        cpuImplementation "org.bytedeco:javacpp::${javacppPlatform}"
+        // https://mvnrepository.com/artifact/org.bytedeco/openblas
+        cpuImplementation 'org.bytedeco:openblas:0.3.23-1.5.9'
+
+        cpuImplementation project(":cavis-native:cavis-native-cpu-presets")
+    }
 }
 
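For context, a hedged sketch of the feature-variant registration above, using Gradle's Java API instead of the Groovy DSL: each chip gets its own source set plus a capability such as `<group>:cavis-native-lib-cuda-support:<version>` that consumers can require. The plugin class and the hard-coded chip list are illustrative assumptions; the API calls themselves are real Gradle API.

```java
import org.gradle.api.Plugin;
import org.gradle.api.Project;
import org.gradle.api.plugins.JavaPluginExtension;
import org.gradle.api.tasks.SourceSetContainer;

public class ChipFeaturesPlugin implements Plugin<Project> {
    @Override
    public void apply(Project project) {
        JavaPluginExtension java = project.getExtensions().getByType(JavaPluginExtension.class);
        SourceSetContainer sourceSets = java.getSourceSets();
        for (String chip : new String[] {"cpu", "cuda"}) {
            // One feature variant per chip, backed by that chip's source set.
            java.registerFeature(chip + "Support", spec -> {
                spec.usingSourceSet(sourceSets.getByName(chip));
                spec.capability(project.getGroup().toString(),
                        "cavis-native-lib-" + chip + "-support",
                        project.getVersion().toString());
            });
        }
    }
}
```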
@@ -183,40 +186,34 @@ task deepClean(type: Delete) {
 }
 
 
-tasks.withType(org.bytedeco.gradle.javacpp.BuildTask) {
-    buildResource = [ "/org/bytedeco/openblas/${javacppPlatform}/",
-                      "/org/bytedeco/mkldnn/${javacppPlatform}/"]
-    includeResource = ["/org/bytedeco/openblas/${javacppPlatform}/include/"]
-    linkResource = ["/org/bytedeco/openblas/${javacppPlatform}/",
-                    "/org/bytedeco/openblas/${javacppPlatform}/lib/"]
-    //buildPath = [ org.bytedeco.javacpp.Loader.getCacheDir() ]
+tasks.withType(org.bytedeco.gradle.javacpp.BuildTask).configureEach { org.bytedeco.gradle.javacpp.BuildTask it ->
+    /*
+    it.buildResource = ["/org/bytedeco/openblas/${javacppPlatform}/",
+                        "/org/bytedeco/mkldnn/${javacppPlatform}/"]
+
+    it.includeResource = ["/org/bytedeco/openblas/${javacppPlatform}/include/"]
+
+    it.linkResource = ["/org/bytedeco/openblas/${javacppPlatform}/",
+                       "/org/bytedeco/openblas/${javacppPlatform}/lib/"]
+
+    */
 }
 
 
 // Disable the standard javacpp generated tasks and use own
 // versions below. This allows to build for each variant
 
 [javacppBuildParser, javacppBuildCommand, javacppCompileJava, javacppBuildCompiler].each {
     it.enabled false
 }
 
-chipList.each { thisChip ->
+chipList.each { String thisChip ->
 
     // 1)
     //Run the C++ compile first
-    tasks.register("javacpp${thisChip.capitalize()}SupportBuildCommand", org.bytedeco.gradle.javacpp.BuildTask) {
-        if (project.hasProperty("skip-native") && project.getProperty("skip-native").equals("true")) {
-            enabled = false
-        }
-        dependsOn "processResources"
+    tasks.register("javacpp${thisChip.capitalize()}SupportBuildCommand", org.bytedeco.gradle.javacpp.BuildTask) {org.bytedeco.gradle.javacpp.BuildTask it ->
         properties = getBuildPlatform( thisChip, it )
 
         includePath = ["${projectDir}/src/main/cpp/blas/",
                        "${projectDir}/blasbuild/${thisChip}/${avxExtension}/src/main/include/",
                        "${projectDir}/blasbuild/${thisChip}/${avxExtension}/flatbuffers-src/include",
@@ -226,19 +223,16 @@ chipList.each { thisChip ->
         //No idea why this is here, but it looks like even for the javacppBuildCommand task,
         //there is a javacpp Loader actively determining platform etc.
         classOrPackageNames = ["org.nd4j.nativeblas.${thisChip}.Nd4j${thisChip.capitalize()}Presets"]
-        workingDirectory = projectDir
+        //workingDirectory = projectDir
         //if the classpath is not set here, the javacpp classloader starts to look around
         //everywhere and causes java.io.IOExceptions: because files is being used by another process
-        classPath = [:]
-        classPath += ["${buildDir}/classes/java/${thisChip}Support/"]
-        //classPath += ["${buildDir}/classes/java/main/"]
+        //logger.quiet("Using compile classpath from configuration named '{}'", sourceSets.named(thisChip).get().getCompileClasspathConfigurationName())
+        classPath = sourceSets.named(thisChip).get().compileClasspath.collect()
 
         /* Get VCVARS in case we want to build CUDA
          * MinGW64 g++ on MSYS is used otherwise */
-        if (thisChip.equals('cuda') && osdetector.os.startsWith("win")
-                && project.hasProperty("skip-native")
-                && !project.getProperty("skip-native").equals("true")
-                && !VISUAL_STUDIO_INSTALL_DIR.isEmpty()) {
+        if (thisChip.equals('cuda') && osdetector.os.startsWith("win") && !VISUAL_STUDIO_INSTALL_DIR.isEmpty()) {
             def proc = ["cmd.exe", "/c", "${VISUAL_STUDIO_VCVARS_CMD} > nul && set"].execute()
             it.environmentVariables = it.environmentVariables ?: [:]
             def lines = proc.text.split("\\r?\\n")
@@ -246,14 +240,15 @@ chipList.each { thisChip ->
                 if (line.contains("=")) {
                     def parts = line.split("=")
                     it.environmentVariables.put(parts[0], parts[1])
+                    logger.debug("Added variable to environment: {} = {}", parts[0], parts[1])
                 }
             }
         }
+        workingDirectory projectDir
         if (thisChip.equals('cuda') && osdetector.os.startsWith("windows")) { //cuDNN requires CUDA
             it.buildCommand = ['sh', 'buildnativeoperations.sh',
                                '-V',
-                               '--build-type', 'release',
+                               '--build-type', 'debug',
                                '--chip', thisChip,
                                '--plattform', 'x86_64',
                                '--chip-extension', avxExtension,
@@ -280,24 +275,13 @@ chipList.each { thisChip ->
                                '-j', "${host_cores}",
                                '--helper', 'mkldnn']
         }
-    }
-
-    //Create a task to (pre)compile the java presets (required for javacppBuildParser)
-    tasks.register("compile${thisChip.capitalize()}Support", JavaCompile) {
-        def thisSS = sourceSets.findByName("${thisChip}Support")
-        it.source = thisSS.allSource
-        it.classpath = thisSS.compileClasspath
-        it.destinationDirectory = file("${buildDir}/classes/java/${thisChip}Support/")
+        if(project.hasProperty("nativeTests")) it.buildCommand += "--tests"
     }
 
     //Run the parser on the InfoMap in Nd4j$ChipPresets and listed header files in @Platform
     //Generates Nd4jCpu.java and/ or Nd4jCuda.java Java JNI code
     tasks.register("javacpp${thisChip.capitalize()}SupportBuildParser", org.bytedeco.gradle.javacpp.BuildTask) {
-        if (project.hasProperty("skip-native") && project.getProperty("skip-native").equals("true")) {
-            enabled = false
-        }
-        dependsOn "compile${thisChip.capitalize()}Support"
 
         includePath = ["${projectDir}/src/main/cpp/blas/",
                        "${projectDir}/blasbuild/${thisChip}/${avxExtension}/src/main/include/",
@@ -305,35 +289,25 @@ chipList.each { thisChip ->
                        "${projectDir}/blasbuild/${thisChip}/${avxExtension}/cpu_features-src/include",
                        "${projectDir}/blasbuild/${thisChip}/${avxExtension}/mkldnn-src/include"]
 
         classOrPackageNames = ["org.nd4j.nativeblas.${thisChip}.Nd4j${thisChip.capitalize()}Presets"]
-        outputDirectory = file("${buildDir}/generated/sources/javacpp/${thisChip}/${javacppPlatform}${javacppPlatformExtension}/")
-        classPath = sourceSets.getByName("${thisChip}Support").getRuntimeClasspath()
-        classPath += ["${buildDir}/classes/java/${thisChip}Support/"]
+        classPath = sourceSets.named(thisChip).get().compileClasspath.collect()
+        outputDirectory file("${buildDir}/generated/sources/javacpp/${thisChip}/${javacppPlatform}${javacppPlatformExtension}/")
     }
 
     // Generates jnijavacpp.cpp and jniNativeLibrary.cpp, compiles and links it
-    tasks.register("javacpp${thisChip.capitalize()}SupportBuildCompiler", org.bytedeco.gradle.javacpp.BuildTask) {
-        if (project.hasProperty("skip-native") && project.getProperty("skip-native").equals("true")) {
-            enabled = false
-        }
-        def thisTask = (org.bytedeco.gradle.javacpp.BuildTask) it
-        thisTask.dependsOn = ["javacpp${thisChip.capitalize()}SupportBuildParser"]
-
-        thisTask.linkPath = ["${projectDir}/blasbuild/${thisChip}/${avxExtension}/output"]
-        thisTask.includePath = ["${projectDir}/src/main/cpp/blas/",
+    tasks.register("javacpp${thisChip.capitalize()}SupportBuildCompiler", org.bytedeco.gradle.javacpp.BuildTask) {org.bytedeco.gradle.javacpp.BuildTask it ->
+        linkPath = ["${projectDir}/blasbuild/${thisChip}/${avxExtension}/output"]
+        includePath = ["${projectDir}/src/main/cpp/blas/",
                        "${projectDir}/blasbuild/${thisChip}/${avxExtension}/src/main/include/",
                        "${projectDir}/blasbuild/${thisChip}/${avxExtension}/flatbuffers-src/include",
                        "${projectDir}/blasbuild/${thisChip}/${avxExtension}/cpu_features-src/include",
                        "${projectDir}/blasbuild/${thisChip}/${avxExtension}/mkldnn-src/include"]
 
-        thisTask.properties = getBuildPlatform( thisChip, thisTask )
+        properties = getBuildPlatform( thisChip, it )
 
-        if(thisChip.equals('cuda') && osdetector.os.startsWith("win") && project.hasProperty("skip-native")
-                && !project.getProperty("skip-native").equals("true") && !VISUAL_STUDIO_INSTALL_DIR.isEmpty()) {
+        if(thisChip.equals('cuda') && osdetector.os.startsWith("win") && !VISUAL_STUDIO_INSTALL_DIR.isEmpty()) {
             def proc = ["cmd.exe", "/c", "${VISUAL_STUDIO_VCVARS_CMD} > nul && where.exe cl.exe"].execute()
             def outp = proc.text
             def cl = "\"" + outp.replace("\\", "\\\\").trim() + "\""
@@ -342,7 +316,8 @@ chipList.each { thisChip ->
             currentCompiler = System.getProperty("org.bytedeco.javacpp.platform.compiler")
             System.setProperty("org.bytedeco.javacpp.platform.compiler", cl)
             System.setProperty("platform.compiler.cpp11", cl)
-            logger.quiet("Task ${thisTask.name} overrides compiler '${currentCompiler}' with '${cl}'.")
+            logger.quiet("Task ${name} overrides compiler '${currentCompiler}' with '${cl}'.")
 
         }
         doLast {
             //restore compiler
@@ -351,12 +326,12 @@ chipList.each { thisChip ->
             //System.setProperty("org.bytedeco.javacpp.platform.compiler.cpp11", cl)
 
             proc = ["cmd.exe", "/c", "${VISUAL_STUDIO_VCVARS_CMD} > nul && set"].execute()
-            thisTask.environmentVariables = thisTask.environmentVariables ?: [:]
+            environmentVariables = environmentVariables ?: [:]
             def lines = proc.text.split("\\r?\\n")
             for (def line in lines) {
                 if (line.contains("=")) {
                     def parts = line.split("=")
-                    thisTask.environmentVariables.put(parts[0], parts[1])
+                    environmentVariables.put(parts[0], parts[1])
                 }
             }
@@ -365,32 +340,27 @@ chipList.each { thisChip ->
         }
 
-        thisTask.buildPath = ["$buildDir/generated/sources/javacpp/${thisChip}/${javacppPlatform}${javacppPlatformExtension}/"]
-        thisTask.copyLibs = true
-        thisTask.deleteJniFiles(false)
-        outputName = "jnind4j${thisChip}"
-        thisTask.outputDirectory = file("$buildDir/generated/sources/javacpp/${thisChip}/${javacppPlatform}${javacppPlatformExtension}/")
-        thisTask.classOrPackageNames= ["org.nd4j.nativeblas.Nd4j${thisChip.capitalize()}"]
+        buildPath = ["${buildDir}/generated/sources/javacpp/${thisChip}/${javacppPlatform}${javacppPlatformExtension}/"]
+        copyLibs = true
+        deleteJniFiles(false)
+        //outputName = "jnind4j${thisChip}"
+        outputDirectory = file("${buildDir}/generated/sources/javacpp/${thisChip}/${javacppPlatform}${javacppPlatformExtension}/")
+        classOrPackageNames= ["org.nd4j.nativeblas.Nd4j${thisChip.capitalize()}"]
 
-        thisTask.configDirectory = file("${buildDir}/classes/java/${thisChip}Support/META-INF/native-image/${javacppPlatform}")
-        //Need to set the classpath, so that external jars from the dependency list are resolved by the ClassLoader as well
-        thisTask.classPath = [:]
-        thisTask.classPath = ["${buildDir}/classes/java/${thisChip}Support"]
-        thisTask.classPath += sourceSets.findByName("${thisChip}Support").runtimeClasspath
-        //sourceSets.findByName("${thisChip}Support").runtimeClasspath.each{ s ->
-        //    thisTask.classPath += s
-        //}
+        configDirectory = file("${buildDir}/classes/java/${thisChip}Support/META-INF/native-image/${javacppPlatform}")
+        classPath = sourceSets.named("${thisChip}").get().compileClasspath.collect()
+        classPath += "${buildDir}/classes/java/${thisChip}/"
     }
 
     // Create Jar with classifier
-    tasks.getByName("${thisChip}SupportJar") { Jar thisTask ->
+    tasks.named("${thisChip}Jar").configure { Jar thisTask ->
         dependsOn "javacpp${thisChip.capitalize()}SupportBuildCompiler"
         dependsOn "javacpp${thisChip.capitalize()}SupportBuildCommand"
 
-        //it.from sourceSets.getByName("${thisChip}Support").getOutput()
         def spec = copySpec {
-            from(tasks.getByName("javacpp${thisChip.capitalize()}SupportBuildCompiler")) {
+            from(tasks.named("javacpp${thisChip.capitalize()}SupportBuildCompiler").get()) {
                 exclude { f ->
                     def exclude = f.file.isDirectory()
                     if(exclude) {
@@ -402,8 +372,8 @@ chipList.each { thisChip ->
             }
             into "${javacppPlatform}/" //path within jar, we need it in a platform, that javacpp Loader understands
         }
-        from(sourceSets.getByName("${thisChip}Support").getOutput()) {
+        from(sourceSets.named(thisChip).get().getOutput()) {
+            into "${javacppPlatform}/" //path within jar, we need it in a platform, that javacpp Loader understands
         }
         duplicatesStrategy DuplicatesStrategy.EXCLUDE
     }
@@ -415,34 +385,43 @@ chipList.each { thisChip ->
 
 //Before we can compile the whole java part, we
 //need to generate the Nd4jXXX.java files first
-chipList.each { thisChip ->
-    tasks.findByName("compile${thisChip.capitalize()}SupportJava").each { t ->
-        t.dependsOn "javacpp${thisChip.capitalize()}SupportBuildParser"
+tasks.named("compileJava").configure {enabled false}
+
+chipList.each { String thisChip ->
+    //ensure full build process is running on "build"
+    tasks.named("build").configure {
+        dependsOn "javacpp${thisChip.capitalize()}SupportBuildCompiler"
+    }
+    //Compiles and links the generated jni code with the underlying native library
+    tasks.named("javacpp${thisChip.capitalize()}SupportBuildCompiler").configure {
+        dependsOn "javacpp${thisChip.capitalize()}SupportBuildParser"
+    }
+    //Generates the jni interface sources
+    tasks.named("javacpp${thisChip.capitalize()}SupportBuildParser").configure {
+        dependsOn "javacpp${thisChip.capitalize()}SupportBuildCommand"
+    }
+    //Compiles the c++ and cuda sources
+    tasks.named("javacpp${thisChip.capitalize()}SupportBuildCommand").configure {
+
+    }
+    //Compile the generates jni interface (java portion)
+    tasks.named("compile${thisChip.capitalize()}Java").configure {
+        dependsOn "javacpp${thisChip.capitalize()}SupportBuildParser"
+    }
+
+    tasks.named("${thisChip}Jar").configure {
+        dependsOn "javacpp${thisChip.capitalize()}SupportBuildCompiler"
     }
 }
 
-tasks.withType(JavaCompile) {
+tasks.withType(JavaCompile).configureEach {
     // options.setCompilerArgs(Arrays.asList("-Xlint:unchecked"))
 }
 
-tasks.withType(Javadoc) {
+tasks.withType(Javadoc).configureEach {
     options.addStringOption('Xdoclint:none', '-quiet')
 }
 
-/*
-jar {
-    manifest {
-        attributes 'Class-Path': configurations.runtimeClasspath.collect { it.getName() }.join(' '),
-                'Implementation-Title': 'Brutex AI - Native Components',
-                'Implementation-Vendor': 'Brutex Network',
-                'Implementation-Version': archiveVersion,
-                'Specification-Title': 'Brutex AI - Native Components',
-                'Specification-Vendor': 'Brutex Network',
-                'Specification-Version': archiveVersion
-    }
-    //archiveClassifier = "${javacppPlatform}${javacppPlatformExtension}-${chip}"
-}
-*/
 javadoc {
     dependsOn "javacppPomProperties"
     failOnError = false
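For context, a hedged sketch of the dependency chain the block above establishes, per chip: BuildCommand (C++/CUDA compile) feeds BuildParser (JNI source generation), which feeds BuildCompiler (JNI compile and link), which feeds the per-chip jar, while the stock `compileJava` stays disabled. The plugin class and hard-coded chip list are illustrative assumptions; the task-wiring calls are real Gradle API.

```java
import org.gradle.api.Plugin;
import org.gradle.api.Project;

public class ChipTaskWiringPlugin implements Plugin<Project> {
    @Override
    public void apply(Project project) {
        for (String chip : new String[] {"cpu", "cuda"}) {
            String cap = chip.substring(0, 1).toUpperCase() + chip.substring(1);
            // parser needs the native build output; compiler needs the parser output
            project.getTasks().named("javacpp" + cap + "SupportBuildParser")
                   .configure(t -> t.dependsOn("javacpp" + cap + "SupportBuildCommand"));
            project.getTasks().named("javacpp" + cap + "SupportBuildCompiler")
                   .configure(t -> t.dependsOn("javacpp" + cap + "SupportBuildParser"));
            // the per-chip jar packages the compiled and linked JNI library
            project.getTasks().named(chip + "Jar")
                   .configure(t -> t.dependsOn("javacpp" + cap + "SupportBuildCompiler"));
        }
    }
}
```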
@@ -452,10 +431,6 @@ javadoc {
 }
 
-
-
-
-
 tasks.getByName("generatePomFileForMavenJavaPublication") {
     enabled = true
 }
@@ -465,32 +440,14 @@ javadoc {
 
 artifacts {
     //implementation(jar)
-    chipList.each { thisChip ->
-        implementation(tasks.getByName("${thisChip}SupportJar"))
+    chipList.each { String thisChip ->
+        implementation tasks.getByName("${thisChip}Jar")
     }
 
 }
 
-/*
-artifacts {
-    archives jar
-    chipList.each { thisChip ->
-        archives tasks.getByName("${thisChip}SupportJar")
-    }
-}
-
-*/
-/*
-publishing {
-    publications {
-        mavenJava(MavenPublication) {
-            artifact jar
-            chipList.each { thisChip ->
-                artifact tasks.getByName("${thisChip}SupportJar")
-            }
-        }
-    }
-}
-*/
 /*
 
 if( osdetector.os.startsWith("windows")) {
@@ -516,48 +473,6 @@ if( osdetector.os.startsWith("windows")) {
 }
 */
 
-task printDeps {
-    doLast {
-        configurations.apiElements.dependencies.each { dep ->
-            println "${dep.group} - ${dep.name} - ${dep.version}"
-            dep.artifacts.each { art ->
-                println "  ${art.extension} - ${art.classifier}"
-            }
-        }
-    }
-}
-
-
-/*
-def pomClosure = {
-    name = 'Brutex AI - Native Components'
-    delegate.description = 'Underlying native components for the Brutex AI deeplearning framework for Java'
-    url = 'https://ai.brutex.net'
-    licenses {
-        license {
-            name = 'Apache License, Version 2.0'
-            url = 'http://www.apache.org/licenses/LICENSE-2.0'
-            distribution = 'repo'
-        }
-    }
-    developers {
-        developer {
-            id = 'irnbrux'
-            name = 'Brian Rosenberger'
-            email = 'bru@brutex.de'
-        }
-    }
-    scm {
-        url = 'https://brutex.net/svn/'
-        connection = 'scm:svn:https://brutex.net/svn/bruai4j/'
-    }
-}
-*/
-
-//tasks.getByName("publishMavenJavaPublicationToOSSRHRepository") { MavenRemotePublisher pub ->
-//    logger.quiet(pub.dump());
-//}
-
 signing {
     useGpgCmd()
     if (!version.endsWith('SNAPSHOT')) {
@@ -19,6 +19,7 @@
 # ******************************************************************************/
 #
 
+
 #env
 
 set -eu
@@ -127,7 +127,7 @@ elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel")
     SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -O3 -fp-model fast")
 elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
     # using Visual Studio C++
-    set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}")
+    set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} /Ox")
 elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
     # using GCC
     SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -fmax-errors=2 -fdiagnostics-show-caret ")
@@ -161,15 +161,10 @@ if(HAVE_ARMCOMPUTE)
     file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ops/declarable/platform/armcompute/*.cpp ops/declarable/platform/armcompute/*.h)
 endif()
 
 if(SD_CUDA)
-    message("Build cublas")
-    if(NOT DEFINED ${CMAKE_CUDA_ARCHITECTURES})
-        set(CMAKE_CUDA_ARCHITECTURES 75)
-    endif()
-    message(STATUS "CUDA architectures set to ${CMAKE_CUDA_ARCHITECTURES}")
-
-    find_package(CUDAToolkit)
-    enable_language(CUDA)
+    #find_package(CUDAToolkit)
+    #enable_language(CUDA)
 
     set(CMAKE_CUDA_STANDARD 17)
     set(CMAKE_CXX_STANDARD 14)
@@ -178,6 +173,9 @@ if(SD_CUDA)
     #Enable features prio C++17
     add_definitions(-D_HAS_AUTO_PTR_ETC=1)
 
+    set(CMAKE_CUDA_RUNTIME_LIBRARY "shared")
+    set(CMAKE_CUDA_ARCHITECTURES "61") #set(CMAKE_CUDA_ARCHITECTURES "62;75")
+
     #This basically kills instrinsic activated through SD_F16C=true
     #if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
     #    set (CMAKE_CXX_FLAGS "")
@@ -205,47 +203,29 @@ if(SD_CUDA)
         set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=-fPIC")
     endif()
 
-    if(WIN32)
-        message("In windows, setting cublas library and cusolver library")
-        if(NOT DEFINED CUDA_cublas_LIBRARY)
-            set(CUDA_cublas_LIBRARY ${CUDA_HOME}/lib/x64/cublas.lib)
-        endif()
-
-        if(NOT DEFINED CUDA_cusolver_LIBRARY)
-            set(CUDA_cusolver_LIBRARY ${CUDA_HOME}/lib/x64/cusolver.lib)
-        endif()
-    endif()
-
-    #
-    #string( TOLOWER "${COMPUTE}" COMPUTE_CMP )
-    # if ("${COMPUTE_CMP}" STREQUAL "all")
-    #     CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "Common")
-    # elseif("${COMPUTE_CMP}" STREQUAL "auto")
-    #     CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "Auto")
-    # elseif(COMPUTE_CMP MATCHES "^[0-9]+$")
-    #     #matches USER COMPUTE old way
-    #     #set(CUDA_ARCH_FLAGS "-gencode arch=compute_${COMPUTE},code=sm_${COMPUTE} ")
-    # else()
-    #     #matches numbers NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
-    #     #NAME: Fermi Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal
-    #     #NUM: 2.0 2.1 3.0 3.2 3.5 3.7 5.0 5.2 5.3 6.0 6.2 et cetera
-    #     CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "${COMPUTE}")
-    # endif()
-    # list to spaces
-    #string (REPLACE ";" " " CUDA_ARCH_FLAGS "${CUDA_ARCH_FLAGS}")
+    # if(WIN32)
+    #     message("In windows, setting cublas library and cusolver library")
+    #     if(NOT DEFINED CUDA_cublas_LIBRARY)
+    #         set(CUDA_cublas_LIBRARY ${CUDA_HOME}/lib/x64/cublas.lib)
+    #     endif()
+
+    #     if(NOT DEFINED CUDA_cusolver_LIBRARY)
+    #         set(CUDA_cusolver_LIBRARY ${CUDA_HOME}/lib/x64/cusolver.lib)
+    #     endif()
+    # endif()
 
     #set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_VERSION_MAJOR=${CUDA_VERSION_MAJOR} ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all ")
-    set(CMAKE_CUDA_ARCHITECTURES OFF)
     #set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --generate-code \"arch=compute_53,code=[compute_53,sm_53]\" " )
-    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --generate-code \"arch=compute_61,code=[compute_61,sm_61]\" " )
+    #set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --generate-code \"arch=compute_61,code=[compute_61,sm_61]\" " )
     #set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --generate-code \"arch=compute_75,code=[compute_75,sm_75]\" " )
     set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --extended-lambda ")
     set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr ")
-    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DCUDA_VERSION_MAJOR=11 -w --cudart=static -Xfatbin -compress-all")
+    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DCUDA_VERSION_MAJOR=12 -w -Xfatbin -compress-all")
+    set(CUDAHOSTCXX "${CMAKE_CXX_COMPILER}")
     if(WIN32)
-        set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/EHsc")
+        set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/EHsc -Xcompiler=/bigobj")
     endif()
-    #set(GPU_ARCH)
 
     message("CMAKE_CUDA_FLAGS = ${CMAKE_CUDA_FLAGS}")
     message("CMAKE_CXX_FLAGS = ${CMAKE_CXX_FLAGS}")
@@ -255,6 +235,9 @@ if(SD_CUDA)
     message("CUDA_NVCC_FLAGS = ${CUDA_NVCC_FLAGS}")
     message("CUDA_PROPAGATE_HOST_FLAGS = ${CUDA_PROPAGATE_HOST_FLAGS}")
     message("CUDA_ARCH_FLAGS = ${CUDA_ARCH_FLAGS}")
+    message("CUDAHOSTCXX = ${CUDAHOSTCXX}")
+    message("CMAKE_CUDA_ARCHITECTURES = ${CMAKE_CUDA_ARCHITECTURES}")
+    message("CMAKE_CUDA_RUNTIME_LIBRARY = ${CMAKE_CUDA_RUNTIME_LIBRARY}")
 
     file(GLOB_RECURSE PERF_SOURCES false performance/*.cpp performance/*.h)
     file(GLOB_RECURSE EXCEPTIONS_SOURCES false exceptions/*.cpp exceptions/*.h)
@@ -301,33 +284,36 @@ if(SD_CUDA)

 # build shared library by default or when it's explicitly requested
 if(NOT SD_STATIC_LIB OR SD_SHARED_LIB)
+message("Will build a shared library '${SD_LIBRARY_NAME}'.")
 add_library(${SD_LIBRARY_NAME} SHARED $<TARGET_OBJECTS:samediff_obj>)
 endif()

 if (SD_STATIC_LIB AND SD_SHARED_LIB)
 # if both static and shared library are going to be built - static library will have special suffix
+message("Will build a static library '${SD_LIBRARY_NAME}static'.")
 add_library(${SD_LIBRARY_NAME}static STATIC $<TARGET_OBJECTS:samediff_obj>)
-set_property(TARGET ${SD_LIBRARY_NAME}static PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
+#set_property(TARGET ${SD_LIBRARY_NAME}static PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RUNTIME_LIBRARY}$<$<CONFIG:Debug>:Debug>")
 install(TARGETS ${SD_LIBRARY_NAME}static DESTINATION .)
 elseif(SD_STATIC_LIB)
 # if we only build static library - use this name
 add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>)
-set_property(TARGET ${SD_LIBRARY_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
+#set_property(TARGET ${SD_LIBRARY_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RUNTIME_LIBRARY}$<$<CONFIG:Debug>:Debug>")
 install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
 endif()

 # on windows we want to make sure we use MT or MD, but since we use it in one lib, we must use it everywhere to avoid conflicts
-set_property(TARGET samediff_obj PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
+#set_property(TARGET samediff_obj PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RUNTIME_LIBRARY}$<$<CONFIG:Debug>:Debug>")
-set_property(TARGET ${SD_LIBRARY_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
+#set_property(TARGET ${SD_LIBRARY_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RUNTIME_LIBRARY}$<$<CONFIG:Debug>:Debug>")

 # Done by nvcc as default on windows
 if(WIN32)
-message("CUDA on Windows: enabling /EHsc")
+message("CUDA on Windows: enabling /EHsc and /bigobj")
 SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj")
 endif()

 #target_link_libraries(${SD_LIBRARY_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN})
-target_link_libraries(${SD_LIBRARY_NAME} CUDA::cudart CUDA::cublas CUDA::cusolver ${CUDNN} ${MKLDNN})
+target_link_libraries(${SD_LIBRARY_NAME} CUDA::cudart CUDA::cublas CUDA::cusolver CUDA::cublasLt Threads::Threads ${CUDNN} ${MKLDNN})
+#target_link_libraries(${SD_LIBRARY_NAME} ${CUDNN} ${MKLDNN})

 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda/${SD_EXTENSION})
 install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
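Aside: both branches above stamp out library flavors from the same samediff_obj object library, so the sources compile exactly once no matter which combination of shared/static is requested. A generic sketch of that pattern; core_obj, mylib, and the source files are assumptions for illustration, not names from this diff:

# Compile once into an object library, then reuse the objects for both flavors.
add_library(core_obj OBJECT a.cpp b.cpp)
# Objects must be position-independent if they will feed a shared library.
set_property(TARGET core_obj PROPERTY POSITION_INDEPENDENT_CODE ON)

add_library(mylib SHARED $<TARGET_OBJECTS:core_obj>)
add_library(mylibstatic STATIC $<TARGET_OBJECTS:core_obj>)
install(TARGETS mylib mylibstatic DESTINATION .)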
@@ -439,13 +425,13 @@ elseif(SD_CPU)
 # if both static and shared library are going to be built - static library will have special suffix
 message("Adding a static library for ${SD_LIBRARY_NAME} as ${SD_LIBRARY_NAME}static")
 add_library(${SD_LIBRARY_NAME}static STATIC $<TARGET_OBJECTS:samediff_obj>)
-set_property(TARGET ${SD_LIBRARY_NAME}static PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
+#set_property(TARGET ${SD_LIBRARY_NAME}static PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
 install(TARGETS ${SD_LIBRARY_NAME}static DESTINATION .)
 elseif(SD_STATIC_LIB)
 # if we only build static library - use this name
 message(Only building a static library for ${SD_LIBRARY_NAME})
 add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>)
-set_property(TARGET ${SD_LIBRARY_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
+#set_property(TARGET ${SD_LIBRARY_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
 install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
 endif()
 endif()
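Aside: the set_property calls being commented out here select the MSVC runtime per target via the MSVC_RUNTIME_LIBRARY property (CMake 3.15+, policy CMP0091). For reference, a self-contained sketch of that mechanism; the project name, target, source, and the /MT default are assumptions:

# Sketch only. Valid property values are MultiThreaded, MultiThreadedDLL,
# MultiThreadedDebug, MultiThreadedDebugDLL.
cmake_minimum_required(VERSION 3.15)  # makes CMP0091 behave as NEW
project(runtime_sketch LANGUAGES CXX)

add_library(mylib STATIC a.cpp)                 # assumed target/source
set(MSVC_RT_LIB "MultiThreaded")                # assumed /MT baseline
# Debug configurations get the Debug runtime (/MTd) via the genex suffix.
set_property(TARGET mylib PROPERTY
    MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")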
@@ -462,13 +448,13 @@ elseif(SD_CPU)

 #This breaks the build. Normally you want to run tests anyways.
 if(NOT "$ENV{CLION_IDE}")
-target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
+target_link_libraries(${SD_LIBRARY_NAME} Threads::Threads ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
 endif()

 if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
 message(STATUS "Building minifier...")
 add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
-target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
+target_link_libraries(minifier samediff_obj Threads::Threads ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
 endif()

 if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
@@ -1,5 +1,5 @@
 include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
+include_directories(../../../../../src/main/cpp/blas)
 if(LINUX)
 link_directories(/usr/local/lib)
 link_directories(/usr/lib)
@@ -21,10 +21,18 @@ if(WIN32)
 endforeach()
 endif()

+set(CMAKE_THREAD_PREFER_PTHREAD TRUE)
+set(THREADS_PREFER_PTHREAD_FLAG TRUE)
+find_package(Threads REQUIRED)

 if (SD_CUDA)
-find_package(CUDA)
-message("Tests CUDA include directory: ${CUDA_INCLUDE_DIRS}")
-include_directories(${CUDA_INCLUDE_DIRS})
+find_package(CUDAToolkit 12.2 REQUIRED)
+enable_language(CUDA)
+set(CMAKE_CUDA_STANDARD 17)
+set(CMAKE_CXX_STANDARD 14)
+message("Tests CUDA include directory: ${CUDAToolkit_INCLUDE_DIRS}")
+include_directories(${CUDAToolkit_INCLUDE_DIRS})
 add_definitions(-D__CUDABLAS__=true)

 if(WIN32)
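Aside: this hunk swaps the deprecated FindCUDA module for FindCUDAToolkit (CMake 3.17+), whose CUDA::* imported targets carry their own include paths and link flags, and wires thread support through the Threads::Threads imported target. A condensed, self-contained sketch of the same wiring; the project name, demo target, and main.cu are assumptions:

# Sketch only. Pthread preferences must be set before find_package(Threads).
cmake_minimum_required(VERSION 3.18)
project(threads_cuda_sketch LANGUAGES CXX)

set(CMAKE_THREAD_PREFER_PTHREAD TRUE)
set(THREADS_PREFER_PTHREAD_FLAG TRUE)
find_package(Threads REQUIRED)

find_package(CUDAToolkit 12.2 REQUIRED)
enable_language(CUDA)

add_executable(demo main.cu)  # assumed target/source
target_link_libraries(demo PRIVATE Threads::Threads
    CUDA::cudart CUDA::cublas CUDA::cusolver CUDA::cublasLt)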
@@ -34,17 +42,14 @@ if (SD_CUDA)

 string( TOLOWER "${COMPUTE}" COMPUTE_CMP )
 if ("${COMPUTE_CMP}" STREQUAL "all")
-CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "Common")
+set(CMAKE_CUDA_ARCHITECTURES "all")
 elseif("${COMPUTE_CMP}" STREQUAL "auto")
-CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "Auto")
+set(CMAKE_CUDA_ARCHITECTURES "all-major")
-elseif(COMPUTE_CMP MATCHES "^[0-9]+$")
-#matches USER COMPUTE old way
-set(CUDA_ARCH_FLAGS "-gencode arch=compute_${COMPUTE},code=sm_${COMPUTE} ")
 else()
 #matches numbers NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
 #NAME: Fermi Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal
 #NUM: 2.0 2.1 3.0 3.2 3.5 3.7 5.0 5.2 5.3 6.0 6.2 et cetera
-CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "${COMPUTE}")
+set(CMAKE_CUDA_ARCHITECTURES "all")
 endif()
 # list to spaces
 string (REPLACE ";" " " CUDA_ARCH_FLAGS "${CUDA_ARCH_FLAGS}")
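Aside: with CUDA_SELECT_NVCC_ARCH_FLAGS gone, the user-facing COMPUTE setting now maps onto CMAKE_CUDA_ARCHITECTURES, whose special values "all" and "all-major" require CMake 3.23+. Note the new side drops the old numeric elseif branch entirely. The sketch below restates the new mapping in isolation and adds one assumed pass-through for bare numbers, which is not part of this diff:

# Sketch of the new COMPUTE -> CMAKE_CUDA_ARCHITECTURES mapping (CMake >= 3.23).
set(COMPUTE "auto" CACHE STRING "Requested CUDA architectures")
string(TOLOWER "${COMPUTE}" COMPUTE_CMP)
if("${COMPUTE_CMP}" STREQUAL "all")
  set(CMAKE_CUDA_ARCHITECTURES "all")
elseif("${COMPUTE_CMP}" STREQUAL "auto")
  set(CMAKE_CUDA_ARCHITECTURES "all-major")
elseif(COMPUTE_CMP MATCHES "^[0-9]+$")
  # Assumed extension: CMAKE_CUDA_ARCHITECTURES accepts numeric values natively.
  set(CMAKE_CUDA_ARCHITECTURES "${COMPUTE_CMP}")
else()
  set(CMAKE_CUDA_ARCHITECTURES "all")
endif()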
@@ -149,7 +154,7 @@ if (SD_CPU)
 endif()

 add_executable(runtests ${TEST_SOURCES})
-target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} ${ARMCOMPUTE_LIBRARIES} gtest gtest_main)
+target_link_libraries(runtests samediff_obj Threads::Threads ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} ${ARMCOMPUTE_LIBRARIES} gtest gtest_main)
 elseif(SD_CUDA)

 add_executable(runtests ${TEST_SOURCES})
@@ -167,5 +172,5 @@ elseif(SD_CUDA)
 message("CUDNN library: ${CUDNN}")
 endif()

-target_link_libraries(runtests samediff_obj ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN} gtest gtest_main)
+target_link_libraries(runtests samediff_obj Threads::Threads ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN} gtest gtest_main)
 endif()