Merge branch 'master' into sa_tvm

master
Adam Gibson 2021-03-09 07:53:01 +09:00 committed by GitHub
commit ad12d2148d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
420 changed files with 8188 additions and 30872 deletions

View File

@ -0,0 +1,11 @@
name: Download dl4j test resources
runs:
using: composite
steps:
- name: Initial install
shell: bash
run: |
wget https://github.com/KonduitAI/dl4j-test-resources/archive/master.zip && unzip master.zip
cd dl4j-test-resources-master
mvn clean install -DskipTests
echo "Extracted test resources"

View File

@ -0,0 +1,12 @@
name: Download dl4j test resources
runs:
using: composite
steps:
- name: Initial install
shell: cmd
run: |
set "PATH=C:\msys64\usr\bin;%PATH%"
wget https://github.com/KonduitAI/dl4j-test-resources/archive/master.zip && unzip master.zip
cd dl4j-test-resources-master
mvn clean install -DskipTests
echo "Extracted test resources"

View File

@ -0,0 +1,12 @@
name: Download dl4j test resources
runs:
using: composite
steps:
- name: Initial install
shell: bash
run: |
sudo apt install git gcc-8-aarch64-linux-gnu g++-8-aarch64-linux-gnu libc6-armel-cross libc6-dev-armel-cross binutils-arm-linux-gnueabi libncurses5-dev build-essential bison flex libssl-dev bc \
gcc-arm-linux-gnueabihf g++-arm-linux-gnueabihf crossbuild-essential-arm64
mkdir -p /opt/raspberrypi && \
cd /opt/raspberrypi && \
git clone git://github.com/raspberrypi/tools.git

View File

@ -0,0 +1,16 @@
name: Install protobuf linux
runs:
using: composite
steps:
- name: Install protobuf linux
shell: bash
run: |
curl -fsSL https://github.com/google/protobuf/releases/download/v3.5.1/protobuf-cpp-3.5.1.tar.gz \
| tar xz && \
cd protobuf-3.5.1 && \
./configure --prefix=/opt/protobuf && \
make -j2 && \
make install && \
cd .. && \
rm -rf protobuf-3.5.1
echo "/opt/protobuf/bin" >> $GITHUB_PATH

View File

@ -0,0 +1,10 @@
name: Setup for msys2
runs:
using: composite
steps:
- name: Initial install
shell: cmd
run: |
C:\msys64\usr\bin\bash -lc "pacman -S --needed --noconfirm base-devel git tar pkg-config unzip p7zip zip autoconf autoconf-archive automake make patch gnupg"
C:\msys64\usr\bin\bash -lc "pacman -S --needed --noconfirm mingw-w64-x86_64-nasm mingw-w64-x86_64-toolchain mingw-w64-x86_64-libtool mingw-w64-x86_64-gcc mingw-w64-i686-gcc mingw-w64-x86_64-gcc-fortran mingw-w64-i686-gcc-fortran mingw-w64-x86_64-libwinpthread-git mingw-w64-i686-libwinpthread-git mingw-w64-x86_64-SDL mingw-w64-i686-SDL mingw-w64-x86_64-ragel"
echo "C:\msys64\usr\bin" >> $GITHUB_PATH

View File

@ -0,0 +1,9 @@
name: Publish to github packages
runs:
using: composite
steps:
- name: Publish to GitHub Packages
run: mvn -Pgithub --batch-mode deploy
shell: bash
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -0,0 +1,48 @@
on:
schedule:
- cron: "0 */12 * * *"
jobs:
android-x86_64:
runs-on: ubuntu-18.04
steps:
- uses: AutoModality/action-clean@v1
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: nttld/setup-ndk@v1
id: setup-ndk
with:
ndk-version: r18b
- uses: actions/checkout@v2
- uses: ./.github/actions/install-protobuf-linux
- name: Set up Java for publishing to GitHub Packages
uses: actions/setup-java@v1
with:
java-version: 1.8
server-id: sonatype-nexus-snapshots
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Build on linux-x86_64
env:
ANDROID_NDK: ${{ steps.setup-ndk.outputs.ndk-path }}
LIBND4J_HOME: "${GITHUB_WORKSPACE}/libnd4j"
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
run: |
echo "Verifying programs on path. Path is $PATH"
echo "Path post update is $PATH. Maven is at `which mvn` cmake is at `which cmake` protoc is at `which protoc`"
mvn --version
cmake --version
protoc --version
clang --version
mvn -X -Dorg.bytedeco.javacpp.logger.debug=true -Possrh -pl ":nd4j-native,:libnd4j" --also-make \
-Djavacpp.platform=android-x86_64 \
-Dlibnd4j.platform=android-x86_64 -Dlibnd4j.chip=cpu \
--batch-mode clean deploy -DskipTests

View File

@ -0,0 +1,44 @@
on:
schedule:
- cron: "0 */12 * * *"
jobs:
#Note: no -pl here because we publish everything from this branch and use this as the basis for all uploads.
android-arm32:
runs-on: ubuntu-18.04
steps:
- uses: AutoModality/action-clean@v1
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- uses: ./.github/actions/install-protobuf-linux
- name: Set up Java for publishing to GitHub Packages
uses: actions/setup-java@v1
with:
java-version: 1.8
server-id: sonatype-nexus-snapshots
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Build on android-arm32
shell: bash
env:
DEBIAN_FRONTEND: noninteractive
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DEPLOY: 1
BUILD_USING_MAVEN: 1
TARGET_OS: android
CURRENT_TARGET: arm32
PUBLISH_TO: ossrh
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
run: |
mvn --version
cmake --version
protoc --version
${GITHUB_WORKSPACE}/libnd4j/pi_build.sh

View File

@ -0,0 +1,44 @@
on:
schedule:
- cron: "0 */12 * * *"
jobs:
#Note: no -pl here because we publish everything from this branch and use this as the basis for all uploads.
android-arm64:
runs-on: ubuntu-18.04
steps:
- uses: AutoModality/action-clean@v1
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- uses: ./.github/actions/install-protobuf-linux
- name: Set up Java for publishing to GitHub Packages
uses: actions/setup-java@v1
with:
java-version: 1.8
server-id: sonatype-nexus-snapshots
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Build on android-arm64
shell: bash
env:
DEBIAN_FRONTEND: noninteractive
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DEPLOY: 1
BUILD_USING_MAVEN: 1
TARGET_OS: android
CURRENT_TARGET: arm64
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
run: |
mvn --version
cmake --version
protoc --version
${GITHUB_WORKSPACE}/libnd4j/pi_build.sh

View File

@ -0,0 +1,44 @@
on:
schedule:
- cron: "0 */12 * * *"
jobs:
#Note: no -pl here because we publish everything from this branch and use this as the basis for all uploads.
linux-arm32:
runs-on: ubuntu-18.04
steps:
- uses: AutoModality/action-clean@v1
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- uses: ./.github/actions/install-protobuf-linux
- name: Set up Java for publishing to GitHub Packages
uses: actions/setup-java@v1
with:
java-version: 1.8
server-id: sonatype-nexus-snapshots
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Build on linux-arm32
shell: bash
env:
DEBIAN_FRONTEND: noninteractive
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DEPLOY: 1
BUILD_USING_MAVEN: 1
TARGET_OS: linux
CURRENT_TARGET: arm32
PUBLISH_TO: ossrh
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
run: |
mvn --version
cmake --version
protoc --version
${GITHUB_WORKSPACE}/libnd4j/pi_build.sh

View File

@ -0,0 +1,41 @@
on:
schedule:
- cron: "0 */12 * * *"
jobs:
#Note: no -pl here because we publish everything from this branch and use this as the basis for all uploads.
linux-arm64:
runs-on: ubuntu-18.04
steps:
- uses: AutoModality/action-clean@v1
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- uses: ./.github/actions/install-protobuf-linux
- name: Set up Java for publishing to GitHub Packages
uses: actions/setup-java@v1
with:
java-version: 1.8
server-id: sonatype-nexus-snapshots
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Build on linux-arm64
shell: bash
env:
DEBIAN_FRONTEND: noninteractive
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DEPLOY: 1
BUILD_USING_MAVEN: 1
TARGET_OS: linux
CURRENT_TARGET: arm64
PUBLISH_TO: ossrh
run: |
mvn --version
cmake --version
protoc --version
${GITHUB_WORKSPACE}/libnd4j/pi_build.sh

View File

@ -0,0 +1,55 @@
on:
schedule:
- cron: "0 */12 * * *"
jobs:
linux-x86_64-cuda_11-0:
runs-on: ubuntu-18.04
steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- name: Maximize build space
uses: easimon/maximize-build-space@master
with:
root-reserve-mb: 512
swap-size-mb: 8192
remove-dotnet: 'true'
remove-haskell: 'true'
- uses: actions/checkout@v2
- uses: konduitai/cuda-install/.github/actions/install-cuda-ubuntu@master
env:
cuda: 11.0.167
GCC: 9
- uses: ./.github/actions/install-protobuf-linux
- name: Set up Java for publishing to GitHub Packages
uses: actions/setup-java@v1
with:
java-version: 1.8
server-id: sonatype-nexus-snapshots
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Build cuda
shell: bash
env:
DEBIAN_FRONTEND: noninteractive
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PUBLISH_TO: ossrh
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
run: |
export PATH="/usr/local/cuda-11.0/bin:$PATH"
mvn --version
cmake --version
protoc --version
nvcc --version
sudo apt-get autoremove
sudo apt-get clean
mvn -Possrh -Djavacpp.platform=linux-x86_64 -Dlibnd4j.compute="5.0 5.2 5.3 6.0 6.2 8.0" -Dlibnd4j.chip=cuda -pl ":nd4j-cuda-11.0,:deeplearning4j-cuda-11.0,:libnd4j" --also-make -Pcuda clean --batch-mode deploy -DskipTests

View File

@ -0,0 +1,52 @@
on:
schedule:
- cron: "0 */12 * * *"
jobs:
linux-x86_64-cuda-11-2:
runs-on: ubuntu-18.04
steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- name: Maximize build space
uses: easimon/maximize-build-space@master
with:
root-reserve-mb: 512
swap-size-mb: 8192
remove-dotnet: 'true'
remove-haskell: 'true'
- uses: actions/checkout@v2
- uses: konduitai/cuda-install/.github/actions/install-cuda-ubuntu@master
env:
cuda: 11.2.1_461
GCC: 9
- uses: ./.github/actions/install-protobuf-linux
- name: Set up Java for publishing to GitHub Packages
uses: actions/setup-java@v1
with:
java-version: 1.8
server-id: sonatype-nexus-snapshots
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Run cuda compilation on linux-x86_64
shell: bash
env:
DEBIAN_FRONTEND: noninteractive
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PUBLISH_TO: ossrh
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
run: |
export PATH="/usr/local/cuda-11.2/bin:$PATH"
nvcc --version
mvn --version
cmake --version
protoc --version
sudo apt-get autoremove
sudo apt-get clean
bash ./change-cuda-versions.sh 11.2
mvn -Possrh -Djavacpp.platform=linux-x86_64 -Dlibnd4j.compute="5.0 5.2 5.3 6.0 6.2 8.0" -pl ":nd4j-cuda-11.2,:deeplearning4j-cuda-11.2,:libnd4j" --also-make -Dlibnd4j.chip=cuda --batch-mode deploy -DskipTests

View File

@ -0,0 +1,41 @@
on:
schedule:
- cron: "0 */12 * * *"
jobs:
#Note: no -pl here because we publish everything from this branch and use this as the basis for all uploads.
linux-x86_64:
runs-on: ubuntu-18.04
steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- uses: ./.github/actions/install-protobuf-linux
- name: Set up Java for publishing to GitHub Packages
uses: actions/setup-java@v1
with:
java-version: 1.8
server-id: sonatype-nexus-snapshots
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Build on linux-x86_64
shell: bash
env:
DEBIAN_FRONTEND: noninteractive
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PUBLISH_TO: ossrh
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
run: |
mvn --version
cmake --version
protoc --version
sudo apt-get autoremove
sudo apt-get clean
mvn -X -Possrh -Djavacpp.platform=linux-x86_64 -Dlibnd4j.chip=cpu -Pcpu --batch-mode deploy -DskipTests

34
.github/workflows/build-deploy-mac.yml vendored Normal file
View File

@ -0,0 +1,34 @@
on:
schedule:
- cron: "0 */12 * * *"
jobs:
mac-x86_64:
runs-on: macos-10.15
steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Java for publishing to GitHub Packages
uses: actions/setup-java@v1
with:
java-version: 1.8
server-id: sonatype-nexus-snapshots
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Build and install
shell: bash
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PUBLISH_TO: ossrh
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
mvn -Possrh -Djavacpp.platform=macosx-x86_64 -Djavacpp.platform=macosx-x86_64 -pl ":nd4j-native,:libnd4j" --also-make -Dlibnd4j.platform=macosx-x86_64 -Dlibnd4j.chip=cpu clean --batch-mode deploy -DskipTests

View File

@ -0,0 +1,46 @@
on:
schedule:
- cron: "0 */12 * * *"
jobs:
windows-x86_64-cuda-11-0:
runs-on: windows-2019
steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Use existing msys2 to setup environment
uses: ./.github/actions/msys2-base-setup
- uses: konduitai/cuda-install/.github/actions/install-cuda-windows@master
env:
cuda: 11.0.167
- name: Set up Java for publishing to GitHub Packages
uses: actions/setup-java@v1
with:
java-version: 1.8
server-id: sonatype-nexus-snapshots
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Run windows build
shell: cmd
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PUBLISH_TO: ossrh
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
run: |
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
set MSYSTEM=MINGW64
set "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0"
which cmake
dir "%CUDA_PATH%"
dir "%CUDA_PATH%\lib"
set "PATH=C:\msys64\usr\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\lib\x64;%PATH%"
echo "Running cuda build"
mvn -Possrh -Djavacpp.platform=windows-x86_64 -Dlibnd4j.compute="5.0 5.2 5.3 6.0 6.2 8.0" -Djavacpp.platform=windows-x86_64 -pl ":nd4j-cuda-11.0,:deeplearning4j-cuda-11.0,:libnd4j" --also-make -Dlibnd4j.platform=windows-x86_64 -Pcuda -Dlibnd4j.chip=cuda -Pcuda clean --batch-mode deploy -DskipTests

View File

@ -0,0 +1,48 @@
on:
schedule:
- cron: "0 */12 * * *"
jobs:
windows-x86_64-cuda-11-2:
runs-on: windows-2019
steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Use existing msys2 to setup environment
uses: ./.github/actions/msys2-base-setup
- uses: konduitai/cuda-install/.github/actions/install-cuda-windows@master
env:
cuda: 11.2.1
- name: Set up Java for publishing to GitHub Packages
uses: actions/setup-java@v1
with:
java-version: 1.8
server-id: sonatype-nexus-snapshots
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Run cuda build
shell: cmd
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PUBLISH_TO: ossrh
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
run: |
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
set MSYSTEM=MINGW64
set "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2"
dir "%CUDA_PATH%"
dir "%CUDA_PATH%\lib"
which cmake
set "PATH=C:\msys64\usr\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\lib\x64;%PATH%"
echo "Running cuda build"
bash ./change-cuda-versions.sh 11.2
sudo apt-get autoremove
sudo apt-get clean
mvn -Possrh -Djavacpp.platform=linux-x86_64 -Dlibnd4j.compute="5.0 5.2 5.3 6.0 6.2 8.0" -Djavacpp.platform=windows-x86_64 -pl ":nd4j-cuda-11.2,:libnd4j,:deeplearning4j-cuda-11.2" --also-make -Dlibnd4j.platform=windows-x86_64 -Pcuda -Dlibnd4j.chip=cuda -Pcuda clean --batch-mode deploy -DskipTests

View File

@ -0,0 +1,35 @@
on:
schedule:
- cron: "0 */12 * * *"
jobs:
windows-x86_64:
runs-on: windows-2019
steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- uses: ./.github/actions/msys2-base-setup
- name: Set up Java for publishing to GitHub Packages
uses: actions/setup-java@v1
with:
java-version: 1.8
server-id: sonatype-nexus-snapshots
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Run windows cpu build
shell: cmd
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PUBLISH_TO: ossrh
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
run: |
set MSYSTEM=MINGW64
set "PATH=C:\msys64\usr\bin;%PATH%"
mvn -Possrh -Djavacpp.platform=windows-x86_64 -pl ":nd4j-native,:libnd4j" --also-make -Dlibnd4j.platform=windows-x86_64 -Dlibnd4j.chip=cpu deploy -DskipTests

View File

@ -0,0 +1,52 @@
on:
push:
jobs:
linux-x86_64:
runs-on: ubuntu-18.04
steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- uses: ./.github/actions/install-protobuf-linux
- uses: ./.github/actions/download-dl4j-test-resources-linux
- name: Run tests on linux-x86_64
shell: bash
run: |
mvn --version
cmake --version
protoc --version
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
export OMP_NUM_THREADS=1
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
windows-x86_64:
runs-on: windows-2019
steps:
- uses: actions/checkout@v2
- uses: ./.github/actions/msys2-base-setup
- uses: ./.github/actions/download-dl4j-test-resources-windows
- name: Run tests
shell: cmd
run: |
set "PATH=C:\msys64\usr\bin;%PATH%"
export OMP_NUM_THREADS=1
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mac-x86_64:
runs-on: macos-10.15
steps:
- uses: actions/checkout@v2
- uses: ./.github/actions/download-dl4j-test-resources-linux
- name: Install and run tests
shell: bash
env:
VERBOSE: 1
run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
export OMP_NUM_THREADS=1
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test

View File

@ -0,0 +1,52 @@
on:
push:
jobs:
linux-x86_64:
runs-on: ubuntu-18.04
steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- uses: ./.github/actions/install-protobuf-linux
- uses: ./.github/actions/download-dl4j-test-resources-linux
- name: Run tests on linux-x86_64
shell: bash
run: |
mvn --version
cmake --version
protoc --version
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
export OMP_NUM_THREADS=1
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
windows-x86_64:
runs-on: windows-2019
steps:
- uses: actions/checkout@v2
- uses: ./.github/actions/msys2-base-setup
- uses: ./.github/actions/download-dl4j-test-resources-windows
- name: Run tests
shell: cmd
run: |
set "PATH=C:\msys64\usr\bin;%PATH%"
export OMP_NUM_THREADS=1
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mac-x86_64:
runs-on: macos-10.15
steps:
- uses: actions/checkout@v2
- uses: ./.github/actions/download-dl4j-test-resources-linux
- name: Install and run tests
shell: bash
env:
VERBOSE: 1
run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
export OMP_NUM_THREADS=1
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test

2
.gitignore vendored
View File

@ -79,3 +79,5 @@ libnd4j/cmake*
#vim #vim
*.swp *.swp
*.dll

View File

@ -83,4 +83,8 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
} }
} }
@Override
public long getTimeoutMilliseconds() {
return Long.MAX_VALUE;
}
} }

View File

@ -28,6 +28,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.nio.Buffer;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -60,9 +61,10 @@ public class WritableTest extends BaseND4JTest {
public void testBytesWritableIndexing() { public void testBytesWritableIndexing() {
byte[] doubleWrite = new byte[16]; byte[] doubleWrite = new byte[16];
ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite); ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite);
Buffer buffer = (Buffer) wrapped;
wrapped.putDouble(1.0); wrapped.putDouble(1.0);
wrapped.putDouble(2.0); wrapped.putDouble(2.0);
wrapped.rewind(); buffer.rewind();
BytesWritable byteWritable = new BytesWritable(doubleWrite); BytesWritable byteWritable = new BytesWritable(doubleWrite);
assertEquals(2,byteWritable.getDouble(1),1e-1); assertEquals(2,byteWritable.getDouble(1),1e-1);
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2}); DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2});

View File

@ -1,77 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.datavec</groupId>
<artifactId>datavec-data</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>datavec-data-audio</artifactId>
<name>datavec-data-audio</name>
<dependencies>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>javacpp</artifactId>
<version>${javacpp.version}</version>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>javacv</artifactId>
<version>${javacv.version}</version>
</dependency>
<dependency>
<groupId>com.github.wendykierp</groupId>
<artifactId>JTransforms</artifactId>
<version>${jtransforms.version}</version>
<classifier>with-dependencies</classifier>
</dependency>
<!-- Do not depend on FFmpeg by default due to licensing concerns. -->
<!--
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>ffmpeg-platform</artifactId>
<version>${ffmpeg.version}-${javacpp-presets.version}</version>
</dependency>
-->
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

View File

@ -1,329 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio;
import org.datavec.audio.extension.NormalizedSampleAmplitudes;
import org.datavec.audio.extension.Spectrogram;
import org.datavec.audio.fingerprint.FingerprintManager;
import org.datavec.audio.fingerprint.FingerprintSimilarity;
import org.datavec.audio.fingerprint.FingerprintSimilarityComputer;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
public class Wave implements Serializable {
private static final long serialVersionUID = 1L;
private WaveHeader waveHeader;
private byte[] data; // little endian
private byte[] fingerprint;
/**
* Constructor
*
*/
public Wave() {
this.waveHeader = new WaveHeader();
this.data = new byte[0];
}
/**
* Constructor
*
* @param filename
* Wave file
*/
public Wave(String filename) {
try {
InputStream inputStream = new FileInputStream(filename);
initWaveWithInputStream(inputStream);
inputStream.close();
} catch (IOException e) {
System.out.println(e.toString());
}
}
/**
* Constructor
*
* @param inputStream
* Wave file input stream
*/
public Wave(InputStream inputStream) {
initWaveWithInputStream(inputStream);
}
/**
* Constructor
*
* @param waveHeader
* @param data
*/
public Wave(WaveHeader waveHeader, byte[] data) {
this.waveHeader = waveHeader;
this.data = data;
}
private void initWaveWithInputStream(InputStream inputStream) {
// reads the first 44 bytes for header
waveHeader = new WaveHeader(inputStream);
if (waveHeader.isValid()) {
// load data
try {
data = new byte[inputStream.available()];
inputStream.read(data);
} catch (IOException e) {
System.err.println(e.toString());
}
// end load data
} else {
System.err.println("Invalid Wave Header");
}
}
/**
* Trim the wave data
*
* @param leftTrimNumberOfSample
* Number of sample trimmed from beginning
* @param rightTrimNumberOfSample
* Number of sample trimmed from ending
*/
public void trim(int leftTrimNumberOfSample, int rightTrimNumberOfSample) {
long chunkSize = waveHeader.getChunkSize();
long subChunk2Size = waveHeader.getSubChunk2Size();
long totalTrimmed = leftTrimNumberOfSample + rightTrimNumberOfSample;
if (totalTrimmed > subChunk2Size) {
leftTrimNumberOfSample = (int) subChunk2Size;
}
// update wav info
chunkSize -= totalTrimmed;
subChunk2Size -= totalTrimmed;
if (chunkSize >= 0 && subChunk2Size >= 0) {
waveHeader.setChunkSize(chunkSize);
waveHeader.setSubChunk2Size(subChunk2Size);
byte[] trimmedData = new byte[(int) subChunk2Size];
System.arraycopy(data, (int) leftTrimNumberOfSample, trimmedData, 0, (int) subChunk2Size);
data = trimmedData;
} else {
System.err.println("Trim error: Negative length");
}
}
/**
* Trim the wave data from beginning
*
* @param numberOfSample
* numberOfSample trimmed from beginning
*/
public void leftTrim(int numberOfSample) {
trim(numberOfSample, 0);
}
/**
* Trim the wave data from ending
*
* @param numberOfSample
* numberOfSample trimmed from ending
*/
public void rightTrim(int numberOfSample) {
trim(0, numberOfSample);
}
/**
* Trim the wave data
*
* @param leftTrimSecond
* Seconds trimmed from beginning
* @param rightTrimSecond
* Seconds trimmed from ending
*/
public void trim(double leftTrimSecond, double rightTrimSecond) {
int sampleRate = waveHeader.getSampleRate();
int bitsPerSample = waveHeader.getBitsPerSample();
int channels = waveHeader.getChannels();
int leftTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * leftTrimSecond);
int rightTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * rightTrimSecond);
trim(leftTrimNumberOfSample, rightTrimNumberOfSample);
}
/**
* Trim the wave data from beginning
*
* @param second
* Seconds trimmed from beginning
*/
public void leftTrim(double second) {
trim(second, 0);
}
/**
* Trim the wave data from ending
*
* @param second
* Seconds trimmed from ending
*/
public void rightTrim(double second) {
trim(0, second);
}
/**
* Get the wave header
*
* @return waveHeader
*/
public WaveHeader getWaveHeader() {
return waveHeader;
}
/**
* Get the wave spectrogram
*
* @return spectrogram
*/
public Spectrogram getSpectrogram() {
return new Spectrogram(this);
}
/**
* Get the wave spectrogram
*
* @param fftSampleSize number of sample in fft, the value needed to be a number to power of 2
* @param overlapFactor 1/overlapFactor overlapping, e.g. 1/4=25% overlapping, 0 for no overlapping
*
* @return spectrogram
*/
public Spectrogram getSpectrogram(int fftSampleSize, int overlapFactor) {
return new Spectrogram(this, fftSampleSize, overlapFactor);
}
/**
* Get the wave data in bytes
*
* @return wave data
*/
public byte[] getBytes() {
return data;
}
/**
* Data byte size of the wave excluding header size
*
* @return byte size of the wave
*/
public int size() {
return data.length;
}
/**
* Length of the wave in second
*
* @return length in second
*/
public float length() {
return (float) waveHeader.getSubChunk2Size() / waveHeader.getByteRate();
}
/**
* Timestamp of the wave length
*
* @return timestamp
*/
public String timestamp() {
float totalSeconds = this.length();
float second = totalSeconds % 60;
int minute = (int) totalSeconds / 60 % 60;
int hour = (int) (totalSeconds / 3600);
StringBuilder sb = new StringBuilder();
if (hour > 0) {
sb.append(hour + ":");
}
if (minute > 0) {
sb.append(minute + ":");
}
sb.append(second);
return sb.toString();
}
/**
* Get the amplitudes of the wave samples (depends on the header)
*
* @return amplitudes array (signed 16-bit)
*/
public short[] getSampleAmplitudes() {
int bytePerSample = waveHeader.getBitsPerSample() / 8;
int numSamples = data.length / bytePerSample;
short[] amplitudes = new short[numSamples];
int pointer = 0;
for (int i = 0; i < numSamples; i++) {
short amplitude = 0;
for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) {
// little endian
amplitude |= (short) ((data[pointer++] & 0xFF) << (byteNumber * 8));
}
amplitudes[i] = amplitude;
}
return amplitudes;
}
public String toString() {
StringBuilder sb = new StringBuilder(waveHeader.toString());
sb.append("\n");
sb.append("length: " + timestamp());
return sb.toString();
}
public double[] getNormalizedAmplitudes() {
NormalizedSampleAmplitudes amplitudes = new NormalizedSampleAmplitudes(this);
return amplitudes.getNormalizedAmplitudes();
}
public byte[] getFingerprint() {
if (fingerprint == null) {
FingerprintManager fingerprintManager = new FingerprintManager();
fingerprint = fingerprintManager.extractFingerprint(this);
}
return fingerprint;
}
public FingerprintSimilarity getFingerprintSimilarity(Wave wave) {
FingerprintSimilarityComputer fingerprintSimilarityComputer =
new FingerprintSimilarityComputer(this.getFingerprint(), wave.getFingerprint());
return fingerprintSimilarityComputer.getFingerprintsSimilarity();
}
}

View File

@ -1,98 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.FileOutputStream;
import java.io.IOException;
@Slf4j
public class WaveFileManager {
private Wave wave;
public WaveFileManager() {
wave = new Wave();
}
public WaveFileManager(Wave wave) {
setWave(wave);
}
/**
* Save the wave file
*
* @param filename
* filename to be saved
*
* @see Wave file saved
*/
public void saveWaveAsFile(String filename) {
WaveHeader waveHeader = wave.getWaveHeader();
int byteRate = waveHeader.getByteRate();
int audioFormat = waveHeader.getAudioFormat();
int sampleRate = waveHeader.getSampleRate();
int bitsPerSample = waveHeader.getBitsPerSample();
int channels = waveHeader.getChannels();
long chunkSize = waveHeader.getChunkSize();
long subChunk1Size = waveHeader.getSubChunk1Size();
long subChunk2Size = waveHeader.getSubChunk2Size();
int blockAlign = waveHeader.getBlockAlign();
try {
FileOutputStream fos = new FileOutputStream(filename);
fos.write(WaveHeader.RIFF_HEADER.getBytes());
// little endian
fos.write(new byte[] {(byte) (chunkSize), (byte) (chunkSize >> 8), (byte) (chunkSize >> 16),
(byte) (chunkSize >> 24)});
fos.write(WaveHeader.WAVE_HEADER.getBytes());
fos.write(WaveHeader.FMT_HEADER.getBytes());
fos.write(new byte[] {(byte) (subChunk1Size), (byte) (subChunk1Size >> 8), (byte) (subChunk1Size >> 16),
(byte) (subChunk1Size >> 24)});
fos.write(new byte[] {(byte) (audioFormat), (byte) (audioFormat >> 8)});
fos.write(new byte[] {(byte) (channels), (byte) (channels >> 8)});
fos.write(new byte[] {(byte) (sampleRate), (byte) (sampleRate >> 8), (byte) (sampleRate >> 16),
(byte) (sampleRate >> 24)});
fos.write(new byte[] {(byte) (byteRate), (byte) (byteRate >> 8), (byte) (byteRate >> 16),
(byte) (byteRate >> 24)});
fos.write(new byte[] {(byte) (blockAlign), (byte) (blockAlign >> 8)});
fos.write(new byte[] {(byte) (bitsPerSample), (byte) (bitsPerSample >> 8)});
fos.write(WaveHeader.DATA_HEADER.getBytes());
fos.write(new byte[] {(byte) (subChunk2Size), (byte) (subChunk2Size >> 8), (byte) (subChunk2Size >> 16),
(byte) (subChunk2Size >> 24)});
fos.write(wave.getBytes());
fos.close();
} catch (IOException e) {
log.error("",e);
}
}
public Wave getWave() {
return wave;
}
public void setWave(Wave wave) {
this.wave = wave;
}
}

View File

@ -1,281 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.IOException;
import java.io.InputStream;
@Slf4j
public class WaveHeader {
public static final String RIFF_HEADER = "RIFF";
public static final String WAVE_HEADER = "WAVE";
public static final String FMT_HEADER = "fmt ";
public static final String DATA_HEADER = "data";
public static final int HEADER_BYTE_LENGTH = 44; // 44 bytes for header
private boolean valid;
private String chunkId; // 4 bytes
private long chunkSize; // unsigned 4 bytes, little endian
private String format; // 4 bytes
private String subChunk1Id; // 4 bytes
private long subChunk1Size; // unsigned 4 bytes, little endian
private int audioFormat; // unsigned 2 bytes, little endian
private int channels; // unsigned 2 bytes, little endian
private long sampleRate; // unsigned 4 bytes, little endian
private long byteRate; // unsigned 4 bytes, little endian
private int blockAlign; // unsigned 2 bytes, little endian
private int bitsPerSample; // unsigned 2 bytes, little endian
private String subChunk2Id; // 4 bytes
private long subChunk2Size; // unsigned 4 bytes, little endian
public WaveHeader() {
// init a 8k 16bit mono wav
chunkSize = 36;
subChunk1Size = 16;
audioFormat = 1;
channels = 1;
sampleRate = 8000;
byteRate = 16000;
blockAlign = 2;
bitsPerSample = 16;
subChunk2Size = 0;
valid = true;
}
public WaveHeader(InputStream inputStream) {
valid = loadHeader(inputStream);
}
private boolean loadHeader(InputStream inputStream) {
byte[] headerBuffer = new byte[HEADER_BYTE_LENGTH];
try {
inputStream.read(headerBuffer);
// read header
int pointer = 0;
chunkId = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++],
headerBuffer[pointer++]});
// little endian
chunkSize = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
| (long) (headerBuffer[pointer++] & 0xff) << 16
| (long) (headerBuffer[pointer++] & 0xff << 24);
format = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++],
headerBuffer[pointer++]});
subChunk1Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++],
headerBuffer[pointer++], headerBuffer[pointer++]});
subChunk1Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
| (long) (headerBuffer[pointer++] & 0xff) << 16
| (long) (headerBuffer[pointer++] & 0xff) << 24;
audioFormat = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
channels = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
sampleRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
| (long) (headerBuffer[pointer++] & 0xff) << 16
| (long) (headerBuffer[pointer++] & 0xff) << 24;
byteRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
| (long) (headerBuffer[pointer++] & 0xff) << 16
| (long) (headerBuffer[pointer++] & 0xff) << 24;
blockAlign = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
bitsPerSample = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
subChunk2Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++],
headerBuffer[pointer++], headerBuffer[pointer++]});
subChunk2Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
| (long) (headerBuffer[pointer++] & 0xff) << 16
| (long) (headerBuffer[pointer++] & 0xff) << 24;
// end read header
// the inputStream should be closed outside this method
// dis.close();
} catch (IOException e) {
log.error("",e);
return false;
}
if (bitsPerSample != 8 && bitsPerSample != 16) {
System.err.println("WaveHeader: only supports bitsPerSample 8 or 16");
return false;
}
// check the format is support
if (chunkId.toUpperCase().equals(RIFF_HEADER) && format.toUpperCase().equals(WAVE_HEADER) && audioFormat == 1) {
return true;
} else {
System.err.println("WaveHeader: Unsupported header format");
}
return false;
}
public boolean isValid() {
return valid;
}
public String getChunkId() {
return chunkId;
}
public long getChunkSize() {
return chunkSize;
}
public String getFormat() {
return format;
}
public String getSubChunk1Id() {
return subChunk1Id;
}
public long getSubChunk1Size() {
return subChunk1Size;
}
public int getAudioFormat() {
return audioFormat;
}
public int getChannels() {
return channels;
}
public int getSampleRate() {
return (int) sampleRate;
}
public int getByteRate() {
return (int) byteRate;
}
public int getBlockAlign() {
return blockAlign;
}
public int getBitsPerSample() {
return bitsPerSample;
}
public String getSubChunk2Id() {
return subChunk2Id;
}
public long getSubChunk2Size() {
return subChunk2Size;
}
public void setSampleRate(int sampleRate) {
int newSubChunk2Size = (int) (this.subChunk2Size * sampleRate / this.sampleRate);
// if num bytes for each sample is even, the size of newSubChunk2Size also needed to be in even number
if ((bitsPerSample / 8) % 2 == 0) {
if (newSubChunk2Size % 2 != 0) {
newSubChunk2Size++;
}
}
this.sampleRate = sampleRate;
this.byteRate = sampleRate * bitsPerSample / 8;
this.chunkSize = newSubChunk2Size + 36;
this.subChunk2Size = newSubChunk2Size;
}
public void setChunkId(String chunkId) {
this.chunkId = chunkId;
}
public void setChunkSize(long chunkSize) {
this.chunkSize = chunkSize;
}
public void setFormat(String format) {
this.format = format;
}
public void setSubChunk1Id(String subChunk1Id) {
this.subChunk1Id = subChunk1Id;
}
public void setSubChunk1Size(long subChunk1Size) {
this.subChunk1Size = subChunk1Size;
}
public void setAudioFormat(int audioFormat) {
this.audioFormat = audioFormat;
}
public void setChannels(int channels) {
this.channels = channels;
}
public void setByteRate(long byteRate) {
this.byteRate = byteRate;
}
public void setBlockAlign(int blockAlign) {
this.blockAlign = blockAlign;
}
public void setBitsPerSample(int bitsPerSample) {
this.bitsPerSample = bitsPerSample;
}
public void setSubChunk2Id(String subChunk2Id) {
this.subChunk2Id = subChunk2Id;
}
public void setSubChunk2Size(long subChunk2Size) {
this.subChunk2Size = subChunk2Size;
}
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("chunkId: " + chunkId);
sb.append("\n");
sb.append("chunkSize: " + chunkSize);
sb.append("\n");
sb.append("format: " + format);
sb.append("\n");
sb.append("subChunk1Id: " + subChunk1Id);
sb.append("\n");
sb.append("subChunk1Size: " + subChunk1Size);
sb.append("\n");
sb.append("audioFormat: " + audioFormat);
sb.append("\n");
sb.append("channels: " + channels);
sb.append("\n");
sb.append("sampleRate: " + sampleRate);
sb.append("\n");
sb.append("byteRate: " + byteRate);
sb.append("\n");
sb.append("blockAlign: " + blockAlign);
sb.append("\n");
sb.append("bitsPerSample: " + bitsPerSample);
sb.append("\n");
sb.append("subChunk2Id: " + subChunk2Id);
sb.append("\n");
sb.append("subChunk2Size: " + subChunk2Size);
return sb.toString();
}
}

View File

@ -1,81 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.dsp;
import org.jtransforms.fft.DoubleFFT_1D;
public class FastFourierTransform {
/**
* Get the frequency intensities
*
* @param amplitudes amplitudes of the signal. Format depends on value of complex
* @param complex if true, amplitudes is assumed to be complex interlaced (re = even, im = odd), if false amplitudes
* are assumed to be real valued.
* @return intensities of each frequency unit: mag[frequency_unit]=intensity
*/
public double[] getMagnitudes(double[] amplitudes, boolean complex) {
final int sampleSize = amplitudes.length;
final int nrofFrequencyBins = sampleSize / 2;
// call the fft and transform the complex numbers
if (complex) {
DoubleFFT_1D fft = new DoubleFFT_1D(nrofFrequencyBins);
fft.complexForward(amplitudes);
} else {
DoubleFFT_1D fft = new DoubleFFT_1D(sampleSize);
fft.realForward(amplitudes);
// amplitudes[1] contains re[sampleSize/2] or im[(sampleSize-1) / 2] (depending on whether sampleSize is odd or even)
// Discard it as it is useless without the other part
// im part dc bin is always 0 for real input
amplitudes[1] = 0;
}
// end call the fft and transform the complex numbers
// even indexes (0,2,4,6,...) are real parts
// odd indexes (1,3,5,7,...) are img parts
double[] mag = new double[nrofFrequencyBins];
for (int i = 0; i < nrofFrequencyBins; i++) {
final int f = 2 * i;
mag[i] = Math.sqrt(amplitudes[f] * amplitudes[f] + amplitudes[f + 1] * amplitudes[f + 1]);
}
return mag;
}
/**
* Get the frequency intensities. Backwards compatible with previous versions w.r.t to number of frequency bins.
* Use getMagnitudes(amplitudes, true) to get all bins.
*
* @param amplitudes complex-valued signal to transform. Even indexes are real and odd indexes are img
* @return intensities of each frequency unit: mag[frequency_unit]=intensity
*/
public double[] getMagnitudes(double[] amplitudes) {
double[] magnitudes = getMagnitudes(amplitudes, true);
double[] halfOfMagnitudes = new double[magnitudes.length/2];
System.arraycopy(magnitudes, 0,halfOfMagnitudes, 0, halfOfMagnitudes.length);
return halfOfMagnitudes;
}
}

View File

@ -1,66 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.dsp;
public class LinearInterpolation {
public LinearInterpolation() {
}
/**
* Do interpolation on the samples according to the original and destinated sample rates
*
* @param oldSampleRate sample rate of the original samples
* @param newSampleRate sample rate of the interpolated samples
* @param samples original samples
* @return interpolated samples
*/
public short[] interpolate(int oldSampleRate, int newSampleRate, short[] samples) {
if (oldSampleRate == newSampleRate) {
return samples;
}
int newLength = Math.round(((float) samples.length / oldSampleRate * newSampleRate));
float lengthMultiplier = (float) newLength / samples.length;
short[] interpolatedSamples = new short[newLength];
// interpolate the value by the linear equation y=mx+c
for (int i = 0; i < newLength; i++) {
// get the nearest positions for the interpolated point
float currentPosition = i / lengthMultiplier;
int nearestLeftPosition = (int) currentPosition;
int nearestRightPosition = nearestLeftPosition + 1;
if (nearestRightPosition >= samples.length) {
nearestRightPosition = samples.length - 1;
}
float slope = samples[nearestRightPosition] - samples[nearestLeftPosition]; // delta x is 1
float positionFromLeft = currentPosition - nearestLeftPosition;
interpolatedSamples[i] = (short) (slope * positionFromLeft + samples[nearestLeftPosition]); // y=mx+c
}
return interpolatedSamples;
}
}

View File

@ -1,84 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.dsp;
public class Resampler {
public Resampler() {}
/**
* Do resampling. Currently the amplitude is stored by short such that maximum bitsPerSample is 16 (bytePerSample is 2)
*
* @param sourceData The source data in bytes
* @param bitsPerSample How many bits represents one sample (currently supports max. bitsPerSample=16)
* @param sourceRate Sample rate of the source data
* @param targetRate Sample rate of the target data
* @return re-sampled data
*/
public byte[] reSample(byte[] sourceData, int bitsPerSample, int sourceRate, int targetRate) {
// make the bytes to amplitudes first
int bytePerSample = bitsPerSample / 8;
int numSamples = sourceData.length / bytePerSample;
short[] amplitudes = new short[numSamples]; // 16 bit, use a short to store
int pointer = 0;
for (int i = 0; i < numSamples; i++) {
short amplitude = 0;
for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) {
// little endian
amplitude |= (short) ((sourceData[pointer++] & 0xFF) << (byteNumber * 8));
}
amplitudes[i] = amplitude;
}
// end make the amplitudes
// do interpolation
LinearInterpolation reSample = new LinearInterpolation();
short[] targetSample = reSample.interpolate(sourceRate, targetRate, amplitudes);
int targetLength = targetSample.length;
// end do interpolation
// TODO: Remove the high frequency signals with a digital filter, leaving a signal containing only half-sample-rated frequency information, but still sampled at a rate of target sample rate. Usually FIR is used
// end resample the amplitudes
// convert the amplitude to bytes
byte[] bytes;
if (bytePerSample == 1) {
bytes = new byte[targetLength];
for (int i = 0; i < targetLength; i++) {
bytes[i] = (byte) targetSample[i];
}
} else {
// suppose bytePerSample==2
bytes = new byte[targetLength * 2];
for (int i = 0; i < targetSample.length; i++) {
// little endian
bytes[i * 2] = (byte) (targetSample[i] & 0xff);
bytes[i * 2 + 1] = (byte) ((targetSample[i] >> 8) & 0xff);
}
}
// end convert the amplitude to bytes
return bytes;
}
}

View File

@ -1,95 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.dsp;
public class WindowFunction {
public static final int RECTANGULAR = 0;
public static final int BARTLETT = 1;
public static final int HANNING = 2;
public static final int HAMMING = 3;
public static final int BLACKMAN = 4;
int windowType = 0; // defaults to rectangular window
public WindowFunction() {}
public void setWindowType(int wt) {
windowType = wt;
}
public void setWindowType(String w) {
if (w.toUpperCase().equals("RECTANGULAR"))
windowType = RECTANGULAR;
if (w.toUpperCase().equals("BARTLETT"))
windowType = BARTLETT;
if (w.toUpperCase().equals("HANNING"))
windowType = HANNING;
if (w.toUpperCase().equals("HAMMING"))
windowType = HAMMING;
if (w.toUpperCase().equals("BLACKMAN"))
windowType = BLACKMAN;
}
public int getWindowType() {
return windowType;
}
/**
* Generate a window
*
* @param nSamples size of the window
* @return window in array
*/
public double[] generate(int nSamples) {
// generate nSamples window function values
// for index values 0 .. nSamples - 1
int m = nSamples / 2;
double r;
double pi = Math.PI;
double[] w = new double[nSamples];
switch (windowType) {
case BARTLETT: // Bartlett (triangular) window
for (int n = 0; n < nSamples; n++)
w[n] = 1.0f - Math.abs(n - m) / m;
break;
case HANNING: // Hanning window
r = pi / (m + 1);
for (int n = -m; n < m; n++)
w[m + n] = 0.5f + 0.5f * Math.cos(n * r);
break;
case HAMMING: // Hamming window
r = pi / m;
for (int n = -m; n < m; n++)
w[m + n] = 0.54f + 0.46f * Math.cos(n * r);
break;
case BLACKMAN: // Blackman window
r = pi / m;
for (int n = -m; n < m; n++)
w[m + n] = 0.42f + 0.5f * Math.cos(n * r) + 0.08f * Math.cos(2 * n * r);
break;
default: // Rectangular window function
for (int n = 0; n < nSamples; n++)
w[n] = 1.0f;
}
return w;
}
}

View File

@ -1,21 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.dsp;

View File

@ -1,67 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.extension;
import org.datavec.audio.Wave;
public class NormalizedSampleAmplitudes {
private Wave wave;
private double[] normalizedAmplitudes; // normalizedAmplitudes[sampleNumber]=normalizedAmplitudeInTheFrame
public NormalizedSampleAmplitudes(Wave wave) {
this.wave = wave;
}
/**
*
* Get normalized amplitude of each frame
*
* @return array of normalized amplitudes(signed 16 bit): normalizedAmplitudes[frame]=amplitude
*/
public double[] getNormalizedAmplitudes() {
if (normalizedAmplitudes == null) {
boolean signed = true;
// usually 8bit is unsigned
if (wave.getWaveHeader().getBitsPerSample() == 8) {
signed = false;
}
short[] amplitudes = wave.getSampleAmplitudes();
int numSamples = amplitudes.length;
int maxAmplitude = 1 << (wave.getWaveHeader().getBitsPerSample() - 1);
if (!signed) { // one more bit for unsigned value
maxAmplitude <<= 1;
}
normalizedAmplitudes = new double[numSamples];
for (int i = 0; i < numSamples; i++) {
normalizedAmplitudes[i] = (double) amplitudes[i] / maxAmplitude;
}
}
return normalizedAmplitudes;
}
}

View File

@ -1,214 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.extension;
import org.datavec.audio.Wave;
import org.datavec.audio.dsp.FastFourierTransform;
import org.datavec.audio.dsp.WindowFunction;
public class Spectrogram {
public static final int SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE = 1024;
public static final int SPECTROGRAM_DEFAULT_OVERLAP_FACTOR = 0; // 0 for no overlapping
private Wave wave;
private double[][] spectrogram; // relative spectrogram
private double[][] absoluteSpectrogram; // absolute spectrogram
private int fftSampleSize; // number of sample in fft, the value needed to be a number to power of 2
private int overlapFactor; // 1/overlapFactor overlapping, e.g. 1/4=25% overlapping
private int numFrames; // number of frames of the spectrogram
private int framesPerSecond; // frame per second of the spectrogram
private int numFrequencyUnit; // number of y-axis unit
private double unitFrequency; // frequency per y-axis unit
/**
* Constructor
*
* @param wave
*/
public Spectrogram(Wave wave) {
this.wave = wave;
// default
this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE;
this.overlapFactor = SPECTROGRAM_DEFAULT_OVERLAP_FACTOR;
buildSpectrogram();
}
/**
* Constructor
*
* @param wave
* @param fftSampleSize number of sample in fft, the value needed to be a number to power of 2
* @param overlapFactor 1/overlapFactor overlapping, e.g. 1/4=25% overlapping, 0 for no overlapping
*/
public Spectrogram(Wave wave, int fftSampleSize, int overlapFactor) {
this.wave = wave;
if (Integer.bitCount(fftSampleSize) == 1) {
this.fftSampleSize = fftSampleSize;
} else {
System.err.print("The input number must be a power of 2");
this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE;
}
this.overlapFactor = overlapFactor;
buildSpectrogram();
}
/**
* Build spectrogram
*/
private void buildSpectrogram() {
short[] amplitudes = wave.getSampleAmplitudes();
int numSamples = amplitudes.length;
int pointer = 0;
// overlapping
if (overlapFactor > 1) {
int numOverlappedSamples = numSamples * overlapFactor;
int backSamples = fftSampleSize * (overlapFactor - 1) / overlapFactor;
short[] overlapAmp = new short[numOverlappedSamples];
pointer = 0;
for (int i = 0; i < amplitudes.length; i++) {
overlapAmp[pointer++] = amplitudes[i];
if (pointer % fftSampleSize == 0) {
// overlap
i -= backSamples;
}
}
numSamples = numOverlappedSamples;
amplitudes = overlapAmp;
}
// end overlapping
numFrames = numSamples / fftSampleSize;
framesPerSecond = (int) (numFrames / wave.length());
// set signals for fft
WindowFunction window = new WindowFunction();
window.setWindowType("Hamming");
double[] win = window.generate(fftSampleSize);
double[][] signals = new double[numFrames][];
for (int f = 0; f < numFrames; f++) {
signals[f] = new double[fftSampleSize];
int startSample = f * fftSampleSize;
for (int n = 0; n < fftSampleSize; n++) {
signals[f][n] = amplitudes[startSample + n] * win[n];
}
}
// end set signals for fft
absoluteSpectrogram = new double[numFrames][];
// for each frame in signals, do fft on it
FastFourierTransform fft = new FastFourierTransform();
for (int i = 0; i < numFrames; i++) {
absoluteSpectrogram[i] = fft.getMagnitudes(signals[i], false);
}
if (absoluteSpectrogram.length > 0) {
numFrequencyUnit = absoluteSpectrogram[0].length;
unitFrequency = (double) wave.getWaveHeader().getSampleRate() / 2 / numFrequencyUnit; // frequency could be caught within the half of nSamples according to Nyquist theory
// normalization of absoultSpectrogram
spectrogram = new double[numFrames][numFrequencyUnit];
// set max and min amplitudes
double maxAmp = Double.MIN_VALUE;
double minAmp = Double.MAX_VALUE;
for (int i = 0; i < numFrames; i++) {
for (int j = 0; j < numFrequencyUnit; j++) {
if (absoluteSpectrogram[i][j] > maxAmp) {
maxAmp = absoluteSpectrogram[i][j];
} else if (absoluteSpectrogram[i][j] < minAmp) {
minAmp = absoluteSpectrogram[i][j];
}
}
}
// end set max and min amplitudes
// normalization
// avoiding divided by zero
double minValidAmp = 0.00000000001F;
if (minAmp == 0) {
minAmp = minValidAmp;
}
double diff = Math.log10(maxAmp / minAmp); // perceptual difference
for (int i = 0; i < numFrames; i++) {
for (int j = 0; j < numFrequencyUnit; j++) {
if (absoluteSpectrogram[i][j] < minValidAmp) {
spectrogram[i][j] = 0;
} else {
spectrogram[i][j] = (Math.log10(absoluteSpectrogram[i][j] / minAmp)) / diff;
}
}
}
// end normalization
}
}
/**
* Get spectrogram: spectrogram[time][frequency]=intensity
*
* @return logarithm normalized spectrogram
*/
public double[][] getNormalizedSpectrogramData() {
return spectrogram;
}
/**
* Get spectrogram: spectrogram[time][frequency]=intensity
*
* @return absolute spectrogram
*/
public double[][] getAbsoluteSpectrogramData() {
return absoluteSpectrogram;
}
public int getNumFrames() {
return numFrames;
}
public int getFramesPerSecond() {
return framesPerSecond;
}
public int getNumFrequencyUnit() {
return numFrequencyUnit;
}
public double getUnitFrequency() {
return unitFrequency;
}
public int getFftSampleSize() {
return fftSampleSize;
}
public int getOverlapFactor() {
return overlapFactor;
}
}

View File

@ -1,272 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import lombok.extern.slf4j.Slf4j;
import org.datavec.audio.Wave;
import org.datavec.audio.WaveHeader;
import org.datavec.audio.dsp.Resampler;
import org.datavec.audio.extension.Spectrogram;
import org.datavec.audio.processor.TopManyPointsProcessorChain;
import org.datavec.audio.properties.FingerprintProperties;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
@Slf4j
public class FingerprintManager {
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
private int sampleSizePerFrame = fingerprintProperties.getSampleSizePerFrame();
private int overlapFactor = fingerprintProperties.getOverlapFactor();
private int numRobustPointsPerFrame = fingerprintProperties.getNumRobustPointsPerFrame();
private int numFilterBanks = fingerprintProperties.getNumFilterBanks();
/**
* Constructor
*/
public FingerprintManager() {
}
/**
* Extract fingerprint from Wave object
*
* @param wave Wave Object to be extracted fingerprint
* @return fingerprint in bytes
*/
public byte[] extractFingerprint(Wave wave) {
int[][] coordinates; // coordinates[x][0..3]=y0..y3
byte[] fingerprint = new byte[0];
// resample to target rate
Resampler resampler = new Resampler();
int sourceRate = wave.getWaveHeader().getSampleRate();
int targetRate = fingerprintProperties.getSampleRate();
byte[] resampledWaveData = resampler.reSample(wave.getBytes(), wave.getWaveHeader().getBitsPerSample(),
sourceRate, targetRate);
// update the wave header
WaveHeader resampledWaveHeader = wave.getWaveHeader();
resampledWaveHeader.setSampleRate(targetRate);
// make resampled wave
Wave resampledWave = new Wave(resampledWaveHeader, resampledWaveData);
// end resample to target rate
// get spectrogram's data
Spectrogram spectrogram = resampledWave.getSpectrogram(sampleSizePerFrame, overlapFactor);
double[][] spectorgramData = spectrogram.getNormalizedSpectrogramData();
List<Integer>[] pointsLists = getRobustPointList(spectorgramData);
int numFrames = pointsLists.length;
// prepare fingerprint bytes
coordinates = new int[numFrames][numRobustPointsPerFrame];
for (int x = 0; x < numFrames; x++) {
if (pointsLists[x].size() == numRobustPointsPerFrame) {
Iterator<Integer> pointsListsIterator = pointsLists[x].iterator();
for (int y = 0; y < numRobustPointsPerFrame; y++) {
coordinates[x][y] = pointsListsIterator.next();
}
} else {
// use -1 to fill the empty byte
for (int y = 0; y < numRobustPointsPerFrame; y++) {
coordinates[x][y] = -1;
}
}
}
// end make fingerprint
// for each valid coordinate, append with its intensity
List<Byte> byteList = new LinkedList<Byte>();
for (int i = 0; i < numFrames; i++) {
for (int j = 0; j < numRobustPointsPerFrame; j++) {
if (coordinates[i][j] != -1) {
// first 2 bytes is x
byteList.add((byte) (i >> 8));
byteList.add((byte) i);
// next 2 bytes is y
int y = coordinates[i][j];
byteList.add((byte) (y >> 8));
byteList.add((byte) y);
// next 4 bytes is intensity
int intensity = (int) (spectorgramData[i][y] * Integer.MAX_VALUE); // spectorgramData is ranged from 0~1
byteList.add((byte) (intensity >> 24));
byteList.add((byte) (intensity >> 16));
byteList.add((byte) (intensity >> 8));
byteList.add((byte) intensity);
}
}
}
// end for each valid coordinate, append with its intensity
fingerprint = new byte[byteList.size()];
Iterator<Byte> byteListIterator = byteList.iterator();
int pointer = 0;
while (byteListIterator.hasNext()) {
fingerprint[pointer++] = byteListIterator.next();
}
return fingerprint;
}
/**
* Get bytes from fingerprint file
*
* @param fingerprintFile fingerprint filename
* @return fingerprint in bytes
*/
public byte[] getFingerprintFromFile(String fingerprintFile) {
byte[] fingerprint = null;
try {
InputStream fis = new FileInputStream(fingerprintFile);
fingerprint = getFingerprintFromInputStream(fis);
fis.close();
} catch (IOException e) {
log.error("",e);
}
return fingerprint;
}
/**
* Get bytes from fingerprint inputstream
*
* @param inputStream fingerprint inputstream
* @return fingerprint in bytes
*/
public byte[] getFingerprintFromInputStream(InputStream inputStream) {
byte[] fingerprint = null;
try {
fingerprint = new byte[inputStream.available()];
inputStream.read(fingerprint);
} catch (IOException e) {
log.error("",e);
}
return fingerprint;
}
/**
* Save fingerprint to a file
*
* @param fingerprint fingerprint bytes
* @param filename fingerprint filename
* @see FingerprintManager file saved
*/
public void saveFingerprintAsFile(byte[] fingerprint, String filename) {
FileOutputStream fileOutputStream;
try {
fileOutputStream = new FileOutputStream(filename);
fileOutputStream.write(fingerprint);
fileOutputStream.close();
} catch (IOException e) {
log.error("",e);
}
}
// robustLists[x]=y1,y2,y3,...
private List<Integer>[] getRobustPointList(double[][] spectrogramData) {
int numX = spectrogramData.length;
int numY = spectrogramData[0].length;
double[][] allBanksIntensities = new double[numX][numY];
int bandwidthPerBank = numY / numFilterBanks;
for (int b = 0; b < numFilterBanks; b++) {
double[][] bankIntensities = new double[numX][bandwidthPerBank];
for (int i = 0; i < numX; i++) {
System.arraycopy(spectrogramData[i], b * bandwidthPerBank, bankIntensities[i], 0, bandwidthPerBank);
}
// get the most robust point in each filter bank
TopManyPointsProcessorChain processorChain = new TopManyPointsProcessorChain(bankIntensities, 1);
double[][] processedIntensities = processorChain.getIntensities();
for (int i = 0; i < numX; i++) {
System.arraycopy(processedIntensities[i], 0, allBanksIntensities[i], b * bandwidthPerBank,
bandwidthPerBank);
}
}
List<int[]> robustPointList = new LinkedList<int[]>();
// find robust points
for (int i = 0; i < allBanksIntensities.length; i++) {
for (int j = 0; j < allBanksIntensities[i].length; j++) {
if (allBanksIntensities[i][j] > 0) {
int[] point = new int[] {i, j};
//System.out.println(i+","+frequency);
robustPointList.add(point);
}
}
}
// end find robust points
List<Integer>[] robustLists = new LinkedList[spectrogramData.length];
for (int i = 0; i < robustLists.length; i++) {
robustLists[i] = new LinkedList<>();
}
// robustLists[x]=y1,y2,y3,...
for (int[] coor : robustPointList) {
robustLists[coor[0]].add(coor[1]);
}
// return the list per frame
return robustLists;
}
/**
* Number of frames in a fingerprint
* Each frame lengths 8 bytes
* Usually there is more than one point in each frame, so it cannot simply divide the bytes length by 8
* Last 8 byte of thisFingerprint is the last frame of this wave
* First 2 byte of the last 8 byte is the x position of this wave, i.e. (number_of_frames-1) of this wave
*
* @param fingerprint fingerprint bytes
* @return number of frames of the fingerprint
*/
public static int getNumFrames(byte[] fingerprint) {
if (fingerprint.length < 8) {
return 0;
}
// get the last x-coordinate (length-8&length-7)bytes from fingerprint
return ((fingerprint[fingerprint.length - 8] & 0xff) << 8 | (fingerprint[fingerprint.length - 7] & 0xff)) + 1;
}
}

View File

@ -1,107 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import org.datavec.audio.properties.FingerprintProperties;
public class FingerprintSimilarity {
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
private int mostSimilarFramePosition;
private float score;
private float similarity;
/**
* Constructor
*/
public FingerprintSimilarity() {
mostSimilarFramePosition = Integer.MIN_VALUE;
score = -1;
similarity = -1;
}
/**
* Get the most similar position in terms of frame number
*
* @return most similar frame position
*/
public int getMostSimilarFramePosition() {
return mostSimilarFramePosition;
}
/**
* Set the most similar position in terms of frame number
*
* @param mostSimilarFramePosition
*/
public void setMostSimilarFramePosition(int mostSimilarFramePosition) {
this.mostSimilarFramePosition = mostSimilarFramePosition;
}
/**
* Get the similarity of the fingerprints
* similarity from 0~1, which 0 means no similar feature is found and 1 means in average there is at least one match in every frame
*
* @return fingerprints similarity
*/
public float getSimilarity() {
return similarity;
}
/**
* Set the similarity of the fingerprints
*
* @param similarity similarity
*/
public void setSimilarity(float similarity) {
this.similarity = similarity;
}
/**
* Get the similarity score of the fingerprints
* Number of features found in the fingerprints per frame
*
* @return fingerprints similarity score
*/
public float getScore() {
return score;
}
/**
* Set the similarity score of the fingerprints
*
* @param score
*/
public void setScore(float score) {
this.score = score;
}
/**
* Get the most similar position in terms of time in second
*
* @return most similar starting time
*/
public float getsetMostSimilarTimePosition() {
return (float) mostSimilarFramePosition / fingerprintProperties.getNumRobustPointsPerFrame()
/ fingerprintProperties.getFps();
}
}

View File

@ -1,134 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import java.util.HashMap;
import java.util.List;
public class FingerprintSimilarityComputer {
private FingerprintSimilarity fingerprintSimilarity;
byte[] fingerprint1, fingerprint2;
/**
* Constructor, ready to compute the similarity of two fingerprints
*
* @param fingerprint1
* @param fingerprint2
*/
public FingerprintSimilarityComputer(byte[] fingerprint1, byte[] fingerprint2) {
this.fingerprint1 = fingerprint1;
this.fingerprint2 = fingerprint2;
fingerprintSimilarity = new FingerprintSimilarity();
}
/**
* Get fingerprint similarity of inout fingerprints
*
* @return fingerprint similarity object
*/
public FingerprintSimilarity getFingerprintsSimilarity() {
HashMap<Integer, Integer> offset_Score_Table = new HashMap<>(); // offset_Score_Table<offset,count>
int numFrames;
float score = 0;
int mostSimilarFramePosition = Integer.MIN_VALUE;
// one frame may contain several points, use the shorter one be the denominator
if (fingerprint1.length > fingerprint2.length) {
numFrames = FingerprintManager.getNumFrames(fingerprint2);
} else {
numFrames = FingerprintManager.getNumFrames(fingerprint1);
}
// get the pairs
PairManager pairManager = new PairManager();
HashMap<Integer, List<Integer>> this_Pair_PositionList_Table =
pairManager.getPair_PositionList_Table(fingerprint1);
HashMap<Integer, List<Integer>> compareWave_Pair_PositionList_Table =
pairManager.getPair_PositionList_Table(fingerprint2);
for (Integer compareWaveHashNumber : compareWave_Pair_PositionList_Table.keySet()) {
// if the compareWaveHashNumber doesn't exist in both tables, no need to compare
if (!this_Pair_PositionList_Table.containsKey(compareWaveHashNumber)
|| !compareWave_Pair_PositionList_Table.containsKey(compareWaveHashNumber)) {
continue;
}
// for each compare hash number, get the positions
List<Integer> wavePositionList = this_Pair_PositionList_Table.get(compareWaveHashNumber);
List<Integer> compareWavePositionList = compareWave_Pair_PositionList_Table.get(compareWaveHashNumber);
for (Integer thisPosition : wavePositionList) {
for (Integer compareWavePosition : compareWavePositionList) {
int offset = thisPosition - compareWavePosition;
if (offset_Score_Table.containsKey(offset)) {
offset_Score_Table.put(offset, offset_Score_Table.get(offset) + 1);
} else {
offset_Score_Table.put(offset, 1);
}
}
}
}
// map rank
MapRank mapRank = new MapRankInteger(offset_Score_Table, false);
// get the most similar positions and scores
List<Integer> orderedKeyList = mapRank.getOrderedKeyList(100, true);
if (orderedKeyList.size() > 0) {
int key = orderedKeyList.get(0);
// get the highest score position
mostSimilarFramePosition = key;
score = offset_Score_Table.get(key);
// accumulate the scores from neighbours
if (offset_Score_Table.containsKey(key - 1)) {
score += offset_Score_Table.get(key - 1) / 2;
}
if (offset_Score_Table.containsKey(key + 1)) {
score += offset_Score_Table.get(key + 1) / 2;
}
}
/*
Iterator<Integer> orderedKeyListIterator=orderedKeyList.iterator();
while (orderedKeyListIterator.hasNext()){
int offset=orderedKeyListIterator.next();
System.out.println(offset+": "+offset_Score_Table.get(offset));
}
*/
score /= numFrames;
float similarity = score;
// similarity >1 means in average there is at least one match in every frame
if (similarity > 1) {
similarity = 1;
}
fingerprintSimilarity.setMostSimilarFramePosition(mostSimilarFramePosition);
fingerprintSimilarity.setScore(score);
fingerprintSimilarity.setSimilarity(similarity);
return fingerprintSimilarity;
}
}

View File

@ -1,27 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import java.util.List;
public interface MapRank {
public List getOrderedKeyList(int numKeys, boolean sharpLimit);
}

View File

@ -1,179 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import java.util.*;
import java.util.Map.Entry;
public class MapRankDouble implements MapRank {
private Map map;
private boolean acsending = true;
public MapRankDouble(Map<?, Double> map, boolean acsending) {
this.map = map;
this.acsending = acsending;
}
public List getOrderedKeyList(int numKeys, boolean sharpLimit) { // if sharp limited, will return sharp numKeys, otherwise will return until the values not equals the exact key's value
Set mapEntrySet = map.entrySet();
List keyList = new LinkedList();
// if the numKeys is larger than map size, limit it
if (numKeys > map.size()) {
numKeys = map.size();
}
// end if the numKeys is larger than map size, limit it
if (map.size() > 0) {
double[] array = new double[map.size()];
int count = 0;
// get the pass values
Iterator<Entry> mapIterator = mapEntrySet.iterator();
while (mapIterator.hasNext()) {
Entry entry = mapIterator.next();
array[count++] = (Double) entry.getValue();
}
// end get the pass values
int targetindex;
if (acsending) {
targetindex = numKeys;
} else {
targetindex = array.length - numKeys;
}
double passValue = getOrderedValue(array, targetindex); // this value is the value of the numKey-th element
// get the passed keys and values
Map passedMap = new HashMap();
List<Double> valueList = new LinkedList<Double>();
mapIterator = mapEntrySet.iterator();
while (mapIterator.hasNext()) {
Entry entry = mapIterator.next();
double value = (Double) entry.getValue();
if ((acsending && value <= passValue) || (!acsending && value >= passValue)) {
passedMap.put(entry.getKey(), value);
valueList.add(value);
}
}
// end get the passed keys and values
// sort the value list
Double[] listArr = new Double[valueList.size()];
valueList.toArray(listArr);
Arrays.sort(listArr);
// end sort the value list
// get the list of keys
int resultCount = 0;
int index;
if (acsending) {
index = 0;
} else {
index = listArr.length - 1;
}
if (!sharpLimit) {
numKeys = listArr.length;
}
while (true) {
double targetValue = (Double) listArr[index];
Iterator<Entry> passedMapIterator = passedMap.entrySet().iterator();
while (passedMapIterator.hasNext()) {
Entry entry = passedMapIterator.next();
if ((Double) entry.getValue() == targetValue) {
keyList.add(entry.getKey());
passedMapIterator.remove();
resultCount++;
break;
}
}
if (acsending) {
index++;
} else {
index--;
}
if (resultCount >= numKeys) {
break;
}
}
// end get the list of keys
}
return keyList;
}
private double getOrderedValue(double[] array, int index) {
locate(array, 0, array.length - 1, index);
return array[index];
}
// sort the partitions by quick sort, and locate the target index
private void locate(double[] array, int left, int right, int index) {
int mid = (left + right) / 2;
//System.out.println(left+" to "+right+" ("+mid+")");
if (right == left) {
//System.out.println("* "+array[targetIndex]);
//result=array[targetIndex];
return;
}
if (left < right) {
double s = array[mid];
int i = left - 1;
int j = right + 1;
while (true) {
while (array[++i] < s);
while (array[--j] > s);
if (i >= j)
break;
swap(array, i, j);
}
//System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right);
if (i > index) {
// the target index in the left partition
//System.out.println("left partition");
locate(array, left, i - 1, index);
} else {
// the target index in the right partition
//System.out.println("right partition");
locate(array, j + 1, right, index);
}
}
}
private void swap(double[] array, int i, int j) {
double t = array[i];
array[i] = array[j];
array[j] = t;
}
}

View File

@ -1,179 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import java.util.*;
import java.util.Map.Entry;
public class MapRankInteger implements MapRank {
private Map map;
private boolean acsending = true;
public MapRankInteger(Map<?, Integer> map, boolean acsending) {
this.map = map;
this.acsending = acsending;
}
public List getOrderedKeyList(int numKeys, boolean sharpLimit) { // if sharp limited, will return sharp numKeys, otherwise will return until the values not equals the exact key's value
Set mapEntrySet = map.entrySet();
List keyList = new LinkedList();
// if the numKeys is larger than map size, limit it
if (numKeys > map.size()) {
numKeys = map.size();
}
// end if the numKeys is larger than map size, limit it
if (map.size() > 0) {
int[] array = new int[map.size()];
int count = 0;
// get the pass values
Iterator<Entry> mapIterator = mapEntrySet.iterator();
while (mapIterator.hasNext()) {
Entry entry = mapIterator.next();
array[count++] = (Integer) entry.getValue();
}
// end get the pass values
int targetindex;
if (acsending) {
targetindex = numKeys;
} else {
targetindex = array.length - numKeys;
}
int passValue = getOrderedValue(array, targetindex); // this value is the value of the numKey-th element
// get the passed keys and values
Map passedMap = new HashMap();
List<Integer> valueList = new LinkedList<Integer>();
mapIterator = mapEntrySet.iterator();
while (mapIterator.hasNext()) {
Entry entry = mapIterator.next();
int value = (Integer) entry.getValue();
if ((acsending && value <= passValue) || (!acsending && value >= passValue)) {
passedMap.put(entry.getKey(), value);
valueList.add(value);
}
}
// end get the passed keys and values
// sort the value list
Integer[] listArr = new Integer[valueList.size()];
valueList.toArray(listArr);
Arrays.sort(listArr);
// end sort the value list
// get the list of keys
int resultCount = 0;
int index;
if (acsending) {
index = 0;
} else {
index = listArr.length - 1;
}
if (!sharpLimit) {
numKeys = listArr.length;
}
while (true) {
int targetValue = (Integer) listArr[index];
Iterator<Entry> passedMapIterator = passedMap.entrySet().iterator();
while (passedMapIterator.hasNext()) {
Entry entry = passedMapIterator.next();
if ((Integer) entry.getValue() == targetValue) {
keyList.add(entry.getKey());
passedMapIterator.remove();
resultCount++;
break;
}
}
if (acsending) {
index++;
} else {
index--;
}
if (resultCount >= numKeys) {
break;
}
}
// end get the list of keys
}
return keyList;
}
private int getOrderedValue(int[] array, int index) {
locate(array, 0, array.length - 1, index);
return array[index];
}
// sort the partitions by quick sort, and locate the target index
private void locate(int[] array, int left, int right, int index) {
int mid = (left + right) / 2;
//System.out.println(left+" to "+right+" ("+mid+")");
if (right == left) {
//System.out.println("* "+array[targetIndex]);
//result=array[targetIndex];
return;
}
if (left < right) {
int s = array[mid];
int i = left - 1;
int j = right + 1;
while (true) {
while (array[++i] < s);
while (array[--j] > s);
if (i >= j)
break;
swap(array, i, j);
}
//System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right);
if (i > index) {
// the target index in the left partition
//System.out.println("left partition");
locate(array, left, i - 1, index);
} else {
// the target index in the right partition
//System.out.println("right partition");
locate(array, j + 1, right, index);
}
}
}
private void swap(int[] array, int i, int j) {
int t = array[i];
array[i] = array[j];
array[j] = t;
}
}

View File

@ -1,230 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import org.datavec.audio.properties.FingerprintProperties;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
public class PairManager {
FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
private int numFilterBanks = fingerprintProperties.getNumFilterBanks();
private int bandwidthPerBank = fingerprintProperties.getNumFrequencyUnits() / numFilterBanks;
private int anchorPointsIntervalLength = fingerprintProperties.getAnchorPointsIntervalLength();
private int numAnchorPointsPerInterval = fingerprintProperties.getNumAnchorPointsPerInterval();
private int maxTargetZoneDistance = fingerprintProperties.getMaxTargetZoneDistance();
private int numFrequencyUnits = fingerprintProperties.getNumFrequencyUnits();
private int maxPairs;
private boolean isReferencePairing;
private HashMap<Integer, Boolean> stopPairTable = new HashMap<>();
/**
* Constructor
*/
public PairManager() {
maxPairs = fingerprintProperties.getRefMaxActivePairs();
isReferencePairing = true;
}
/**
* Constructor, number of pairs of robust points depends on the parameter isReferencePairing
* no. of pairs of reference and sample can be different due to environmental influence of source
* @param isReferencePairing
*/
public PairManager(boolean isReferencePairing) {
if (isReferencePairing) {
maxPairs = fingerprintProperties.getRefMaxActivePairs();
} else {
maxPairs = fingerprintProperties.getSampleMaxActivePairs();
}
this.isReferencePairing = isReferencePairing;
}
/**
* Get a pair-positionList table
* It's a hash map which the key is the hashed pair, and the value is list of positions
* That means the table stores the positions which have the same hashed pair
*
* @param fingerprint fingerprint bytes
* @return pair-positionList HashMap
*/
public HashMap<Integer, List<Integer>> getPair_PositionList_Table(byte[] fingerprint) {
List<int[]> pairPositionList = getPairPositionList(fingerprint);
// table to store pair:pos,pos,pos,...;pair2:pos,pos,pos,....
HashMap<Integer, List<Integer>> pair_positionList_table = new HashMap<>();
// get all pair_positions from list, use a table to collect the data group by pair hashcode
for (int[] pair_position : pairPositionList) {
//System.out.println(pair_position[0]+","+pair_position[1]);
// group by pair-hashcode, i.e.: <pair,List<position>>
if (pair_positionList_table.containsKey(pair_position[0])) {
pair_positionList_table.get(pair_position[0]).add(pair_position[1]);
} else {
List<Integer> positionList = new LinkedList<>();
positionList.add(pair_position[1]);
pair_positionList_table.put(pair_position[0], positionList);
}
// end group by pair-hashcode, i.e.: <pair,List<position>>
}
// end get all pair_positions from list, use a table to collect the data group by pair hashcode
return pair_positionList_table;
}
// this return list contains: int[0]=pair_hashcode, int[1]=position
private List<int[]> getPairPositionList(byte[] fingerprint) {
int numFrames = FingerprintManager.getNumFrames(fingerprint);
// table for paired frames
byte[] pairedFrameTable = new byte[numFrames / anchorPointsIntervalLength + 1]; // each second has numAnchorPointsPerSecond pairs only
// end table for paired frames
List<int[]> pairList = new LinkedList<>();
List<int[]> sortedCoordinateList = getSortedCoordinateList(fingerprint);
for (int[] anchorPoint : sortedCoordinateList) {
int anchorX = anchorPoint[0];
int anchorY = anchorPoint[1];
int numPairs = 0;
for (int[] aSortedCoordinateList : sortedCoordinateList) {
if (numPairs >= maxPairs) {
break;
}
if (isReferencePairing && pairedFrameTable[anchorX
/ anchorPointsIntervalLength] >= numAnchorPointsPerInterval) {
break;
}
int targetX = aSortedCoordinateList[0];
int targetY = aSortedCoordinateList[1];
if (anchorX == targetX && anchorY == targetY) {
continue;
}
// pair up the points
int x1, y1, x2, y2; // x2 always >= x1
if (targetX >= anchorX) {
x2 = targetX;
y2 = targetY;
x1 = anchorX;
y1 = anchorY;
} else {
x2 = anchorX;
y2 = anchorY;
x1 = targetX;
y1 = targetY;
}
// check target zone
if ((x2 - x1) > maxTargetZoneDistance) {
continue;
}
// end check target zone
// check filter bank zone
if (!(y1 / bandwidthPerBank == y2 / bandwidthPerBank)) {
continue; // same filter bank should have equal value
}
// end check filter bank zone
int pairHashcode = (x2 - x1) * numFrequencyUnits * numFrequencyUnits + y2 * numFrequencyUnits + y1;
// stop list applied on sample pairing only
if (!isReferencePairing && stopPairTable.containsKey(pairHashcode)) {
numPairs++; // no reservation
continue; // escape this point only
}
// end stop list applied on sample pairing only
// pass all rules
pairList.add(new int[] {pairHashcode, anchorX});
pairedFrameTable[anchorX / anchorPointsIntervalLength]++;
numPairs++;
// end pair up the points
}
}
return pairList;
}
private List<int[]> getSortedCoordinateList(byte[] fingerprint) {
// each point data is 8 bytes
// first 2 bytes is x
// next 2 bytes is y
// next 4 bytes is intensity
// get all intensities
int numCoordinates = fingerprint.length / 8;
int[] intensities = new int[numCoordinates];
for (int i = 0; i < numCoordinates; i++) {
int pointer = i * 8 + 4;
int intensity = (fingerprint[pointer] & 0xff) << 24 | (fingerprint[pointer + 1] & 0xff) << 16
| (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff);
intensities[i] = intensity;
}
QuickSortIndexPreserved quicksort = new QuickSortIndexPreserved(intensities);
int[] sortIndexes = quicksort.getSortIndexes();
List<int[]> sortedCoordinateList = new LinkedList<>();
for (int i = sortIndexes.length - 1; i >= 0; i--) {
int pointer = sortIndexes[i] * 8;
int x = (fingerprint[pointer] & 0xff) << 8 | (fingerprint[pointer + 1] & 0xff);
int y = (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff);
sortedCoordinateList.add(new int[] {x, y});
}
return sortedCoordinateList;
}
/**
* Convert hashed pair to bytes
*
* @param pairHashcode hashed pair
* @return byte array
*/
public static byte[] pairHashcodeToBytes(int pairHashcode) {
return new byte[] {(byte) (pairHashcode >> 8), (byte) pairHashcode};
}
/**
* Convert bytes to hased pair
*
* @param pairBytes
* @return hashed pair
*/
public static int pairBytesToHashcode(byte[] pairBytes) {
return (pairBytes[0] & 0xFF) << 8 | (pairBytes[1] & 0xFF);
}
}

View File

@ -1,25 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
public abstract class QuickSort {
public abstract int[] getSortIndexes();
}

View File

@ -1,79 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
public class QuickSortDouble extends QuickSort {
private int[] indexes;
private double[] array;
public QuickSortDouble(double[] array) {
this.array = array;
indexes = new int[array.length];
for (int i = 0; i < indexes.length; i++) {
indexes[i] = i;
}
}
public int[] getSortIndexes() {
sort();
return indexes;
}
private void sort() {
quicksort(array, indexes, 0, indexes.length - 1);
}
// quicksort a[left] to a[right]
private void quicksort(double[] a, int[] indexes, int left, int right) {
if (right <= left)
return;
int i = partition(a, indexes, left, right);
quicksort(a, indexes, left, i - 1);
quicksort(a, indexes, i + 1, right);
}
// partition a[left] to a[right], assumes left < right
private int partition(double[] a, int[] indexes, int left, int right) {
int i = left - 1;
int j = right;
while (true) {
while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
if (j == left)
break; // don't go out-of-bounds
}
if (i >= j)
break; // check if pointers cross
swap(a, indexes, i, j); // swap two elements into place
}
swap(a, indexes, i, right); // swap with partition element
return i;
}
// exchange a[i] and a[j]
private void swap(double[] a, int[] indexes, int i, int j) {
int swap = indexes[i];
indexes[i] = indexes[j];
indexes[j] = swap;
}
}

View File

@ -1,43 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
public class QuickSortIndexPreserved {
private QuickSort quickSort;
public QuickSortIndexPreserved(int[] array) {
quickSort = new QuickSortInteger(array);
}
public QuickSortIndexPreserved(double[] array) {
quickSort = new QuickSortDouble(array);
}
public QuickSortIndexPreserved(short[] array) {
quickSort = new QuickSortShort(array);
}
public int[] getSortIndexes() {
return quickSort.getSortIndexes();
}
}

View File

@ -1,79 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
public class QuickSortInteger extends QuickSort {
private int[] indexes;
private int[] array;
public QuickSortInteger(int[] array) {
this.array = array;
indexes = new int[array.length];
for (int i = 0; i < indexes.length; i++) {
indexes[i] = i;
}
}
public int[] getSortIndexes() {
sort();
return indexes;
}
private void sort() {
quicksort(array, indexes, 0, indexes.length - 1);
}
// quicksort a[left] to a[right]
private void quicksort(int[] a, int[] indexes, int left, int right) {
if (right <= left)
return;
int i = partition(a, indexes, left, right);
quicksort(a, indexes, left, i - 1);
quicksort(a, indexes, i + 1, right);
}
// partition a[left] to a[right], assumes left < right
private int partition(int[] a, int[] indexes, int left, int right) {
int i = left - 1;
int j = right;
while (true) {
while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
if (j == left)
break; // don't go out-of-bounds
}
if (i >= j)
break; // check if pointers cross
swap(a, indexes, i, j); // swap two elements into place
}
swap(a, indexes, i, right); // swap with partition element
return i;
}
// exchange a[i] and a[j]
private void swap(int[] a, int[] indexes, int i, int j) {
int swap = indexes[i];
indexes[i] = indexes[j];
indexes[j] = swap;
}
}

View File

@ -1,79 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
public class QuickSortShort extends QuickSort {
private int[] indexes;
private short[] array;
public QuickSortShort(short[] array) {
this.array = array;
indexes = new int[array.length];
for (int i = 0; i < indexes.length; i++) {
indexes[i] = i;
}
}
public int[] getSortIndexes() {
sort();
return indexes;
}
private void sort() {
quicksort(array, indexes, 0, indexes.length - 1);
}
// quicksort a[left] to a[right]
private void quicksort(short[] a, int[] indexes, int left, int right) {
if (right <= left)
return;
int i = partition(a, indexes, left, right);
quicksort(a, indexes, left, i - 1);
quicksort(a, indexes, i + 1, right);
}
// partition a[left] to a[right], assumes left < right
private int partition(short[] a, int[] indexes, int left, int right) {
int i = left - 1;
int j = right;
while (true) {
while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
if (j == left)
break; // don't go out-of-bounds
}
if (i >= j)
break; // check if pointers cross
swap(a, indexes, i, j); // swap two elements into place
}
swap(a, indexes, i, right); // swap with partition element
return i;
}
// exchange a[i] and a[j]
private void swap(short[] a, int[] indexes, int i, int j) {
int swap = indexes[i];
indexes[i] = indexes[j];
indexes[j] = swap;
}
}

View File

@ -1,51 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.formats.input;
import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.BaseInputFormat;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.audio.recordreader.WavFileRecordReader;
import java.io.IOException;
/**
*
* Wave file input format
*
* @author Adam Gibson
*/
public class WavInputFormat extends BaseInputFormat {
@Override
public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
return createReader(split);
}
@Override
public RecordReader createReader(InputSplit split) throws IOException, InterruptedException {
RecordReader waveRecordReader = new WavFileRecordReader();
waveRecordReader.initialize(split);
return waveRecordReader;
}
}

View File

@ -1,36 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.formats.output;
import org.datavec.api.conf.Configuration;
import org.datavec.api.exceptions.DataVecException;
import org.datavec.api.formats.output.OutputFormat;
import org.datavec.api.records.writer.RecordWriter;
/**
* @author Adam Gibson
*/
public class WaveOutputFormat implements OutputFormat {
@Override
public RecordWriter createWriter(Configuration conf) throws DataVecException {
return null;
}
}

View File

@ -1,139 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.processor;
public class ArrayRankDouble {
/**
* Get the index position of maximum value the given array
* @param array an array
* @return index of the max value in array
*/
public int getMaxValueIndex(double[] array) {
int index = 0;
double max = Integer.MIN_VALUE;
for (int i = 0; i < array.length; i++) {
if (array[i] > max) {
max = array[i];
index = i;
}
}
return index;
}
/**
* Get the index position of minimum value in the given array
* @param array an array
* @return index of the min value in array
*/
public int getMinValueIndex(double[] array) {
int index = 0;
double min = Integer.MAX_VALUE;
for (int i = 0; i < array.length; i++) {
if (array[i] < min) {
min = array[i];
index = i;
}
}
return index;
}
/**
* Get the n-th value in the array after sorted
* @param array an array
* @param n position in array
* @param ascending is ascending order or not
* @return value at nth position of array
*/
public double getNthOrderedValue(double[] array, int n, boolean ascending) {
if (n > array.length) {
n = array.length;
}
int targetindex;
if (ascending) {
targetindex = n;
} else {
targetindex = array.length - n;
}
// this value is the value of the numKey-th element
return getOrderedValue(array, targetindex);
}
private double getOrderedValue(double[] array, int index) {
locate(array, 0, array.length - 1, index);
return array[index];
}
// sort the partitions by quick sort, and locate the target index
private void locate(double[] array, int left, int right, int index) {
int mid = (left + right) / 2;
// System.out.println(left+" to "+right+" ("+mid+")");
if (right == left) {
// System.out.println("* "+array[targetIndex]);
// result=array[targetIndex];
return;
}
if (left < right) {
double s = array[mid];
int i = left - 1;
int j = right + 1;
while (true) {
while (array[++i] < s);
while (array[--j] > s);
if (i >= j)
break;
swap(array, i, j);
}
// System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right);
if (i > index) {
// the target index in the left partition
// System.out.println("left partition");
locate(array, left, i - 1, index);
} else {
// the target index in the right partition
// System.out.println("right partition");
locate(array, j + 1, right, index);
}
}
}
private void swap(double[] array, int i, int j) {
double t = array[i];
array[i] = array[j];
array[j] = t;
}
}

View File

@ -1,28 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.processor;
public interface IntensityProcessor {
public void execute();
public double[][] getIntensities();
}

View File

@ -1,48 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.processor;
import java.util.LinkedList;
import java.util.List;
public class ProcessorChain {
private double[][] intensities;
List<IntensityProcessor> processorList = new LinkedList<IntensityProcessor>();
public ProcessorChain(double[][] intensities) {
this.intensities = intensities;
RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, 1);
processorList.add(robustProcessor);
process();
}
private void process() {
for (IntensityProcessor processor : processorList) {
processor.execute();
intensities = processor.getIntensities();
}
}
public double[][] getIntensities() {
return intensities;
}
}

View File

@ -1,61 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.processor;
public class RobustIntensityProcessor implements IntensityProcessor {
private double[][] intensities;
private int numPointsPerFrame;
public RobustIntensityProcessor(double[][] intensities, int numPointsPerFrame) {
this.intensities = intensities;
this.numPointsPerFrame = numPointsPerFrame;
}
public void execute() {
int numX = intensities.length;
int numY = intensities[0].length;
double[][] processedIntensities = new double[numX][numY];
for (int i = 0; i < numX; i++) {
double[] tmpArray = new double[numY];
System.arraycopy(intensities[i], 0, tmpArray, 0, numY);
// pass value is the last some elements in sorted array
ArrayRankDouble arrayRankDouble = new ArrayRankDouble();
double passValue = arrayRankDouble.getNthOrderedValue(tmpArray, numPointsPerFrame, false);
// only passed elements will be assigned a value
for (int j = 0; j < numY; j++) {
if (intensities[i][j] >= passValue) {
processedIntensities[i][j] = intensities[i][j];
}
}
}
intensities = processedIntensities;
}
public double[][] getIntensities() {
return intensities;
}
}

View File

@ -1,49 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.processor;
import java.util.LinkedList;
import java.util.List;
public class TopManyPointsProcessorChain {
private double[][] intensities;
List<IntensityProcessor> processorList = new LinkedList<>();
public TopManyPointsProcessorChain(double[][] intensities, int numPoints) {
this.intensities = intensities;
RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, numPoints);
processorList.add(robustProcessor);
process();
}
private void process() {
for (IntensityProcessor processor : processorList) {
processor.execute();
intensities = processor.getIntensities();
}
}
public double[][] getIntensities() {
return intensities;
}
}

View File

@ -1,121 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.properties;
public class FingerprintProperties {
protected static FingerprintProperties instance = null;
private int numRobustPointsPerFrame = 4; // number of points in each frame, i.e. top 4 intensities in fingerprint
private int sampleSizePerFrame = 2048; // number of audio samples in a frame, it is suggested to be the FFT Size
private int overlapFactor = 4; // 8 means each move 1/8 nSample length. 1 means no overlap, better 1,2,4,8 ... 32
private int numFilterBanks = 4;
private int upperBoundedFrequency = 1500; // low pass
private int lowerBoundedFrequency = 400; // high pass
private int fps = 5; // in order to have 5fps with 2048 sampleSizePerFrame, wave's sample rate need to be 10240 (sampleSizePerFrame*fps)
private int sampleRate = sampleSizePerFrame * fps; // the audio's sample rate needed to resample to this in order to fit the sampleSizePerFrame and fps
private int numFramesInOneSecond = overlapFactor * fps; // since the overlap factor affects the actual number of fps, so this value is used to evaluate how many frames in one second eventually
private int refMaxActivePairs = 1; // max. active pairs per anchor point for reference songs
private int sampleMaxActivePairs = 10; // max. active pairs per anchor point for sample clip
private int numAnchorPointsPerInterval = 10;
private int anchorPointsIntervalLength = 4; // in frames (5fps,4 overlap per second)
private int maxTargetZoneDistance = 4; // in frame (5fps,4 overlap per second)
private int numFrequencyUnits = (upperBoundedFrequency - lowerBoundedFrequency + 1) / fps + 1; // num frequency units
public static FingerprintProperties getInstance() {
if (instance == null) {
synchronized (FingerprintProperties.class) {
if (instance == null) {
instance = new FingerprintProperties();
}
}
}
return instance;
}
public int getNumRobustPointsPerFrame() {
return numRobustPointsPerFrame;
}
public int getSampleSizePerFrame() {
return sampleSizePerFrame;
}
public int getOverlapFactor() {
return overlapFactor;
}
public int getNumFilterBanks() {
return numFilterBanks;
}
public int getUpperBoundedFrequency() {
return upperBoundedFrequency;
}
public int getLowerBoundedFrequency() {
return lowerBoundedFrequency;
}
public int getFps() {
return fps;
}
public int getRefMaxActivePairs() {
return refMaxActivePairs;
}
public int getSampleMaxActivePairs() {
return sampleMaxActivePairs;
}
public int getNumAnchorPointsPerInterval() {
return numAnchorPointsPerInterval;
}
public int getAnchorPointsIntervalLength() {
return anchorPointsIntervalLength;
}
public int getMaxTargetZoneDistance() {
return maxTargetZoneDistance;
}
public int getNumFrequencyUnits() {
return numFrequencyUnits;
}
public int getMaxPossiblePairHashcode() {
return maxTargetZoneDistance * numFrequencyUnits * numFrequencyUnits + numFrequencyUnits * numFrequencyUnits
+ numFrequencyUnits;
}
public int getSampleRate() {
return sampleRate;
}
public int getNumFramesInOneSecond() {
return numFramesInOneSecond;
}
}

View File

@ -1,225 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.recordreader;
import org.apache.commons.io.FileUtils;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.BaseRecordReader;
import org.datavec.api.split.BaseInputSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
/**
* Base audio file loader
* @author Adam Gibson
*/
public abstract class BaseAudioRecordReader extends BaseRecordReader {
private Iterator<File> iter;
private List<Writable> record;
private boolean hitImage = false;
private boolean appendLabel = false;
private List<String> labels = new ArrayList<>();
private Configuration conf;
protected InputSplit inputSplit;
public BaseAudioRecordReader() {}
public BaseAudioRecordReader(boolean appendLabel, List<String> labels) {
this.appendLabel = appendLabel;
this.labels = labels;
}
public BaseAudioRecordReader(List<String> labels) {
this.labels = labels;
}
public BaseAudioRecordReader(boolean appendLabel) {
this.appendLabel = appendLabel;
}
protected abstract List<Writable> loadData(File file, InputStream inputStream) throws IOException;
@Override
public void initialize(InputSplit split) throws IOException, InterruptedException {
inputSplit = split;
if (split instanceof BaseInputSplit) {
URI[] locations = split.locations();
if (locations != null && locations.length >= 1) {
if (locations.length > 1) {
List<File> allFiles = new ArrayList<>();
for (URI location : locations) {
File iter = new File(location);
if (iter.isDirectory()) {
Iterator<File> allFiles2 = FileUtils.iterateFiles(iter, null, true);
while (allFiles2.hasNext())
allFiles.add(allFiles2.next());
}
else
allFiles.add(iter);
}
iter = allFiles.iterator();
} else {
File curr = new File(locations[0]);
if (curr.isDirectory())
iter = FileUtils.iterateFiles(curr, null, true);
else
iter = Collections.singletonList(curr).iterator();
}
}
}
else if (split instanceof InputStreamInputSplit) {
record = new ArrayList<>();
InputStreamInputSplit split2 = (InputStreamInputSplit) split;
InputStream is = split2.getIs();
URI[] locations = split2.locations();
if (appendLabel) {
Path path = Paths.get(locations[0]);
String parent = path.getParent().toString();
record.add(new DoubleWritable(labels.indexOf(parent)));
}
is.close();
}
}
@Override
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
this.conf = conf;
this.appendLabel = conf.getBoolean(APPEND_LABEL, false);
this.labels = new ArrayList<>(conf.getStringCollection(LABELS));
initialize(split);
}
@Override
public List<Writable> next() {
if (iter != null) {
File next = iter.next();
invokeListeners(next);
try {
return loadData(next, null);
} catch (Exception e) {
throw new RuntimeException(e);
}
} else if (record != null) {
hitImage = true;
return record;
}
throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
}
@Override
public boolean hasNext() {
if (iter != null) {
return iter.hasNext();
} else if (record != null) {
return !hitImage;
}
throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
}
@Override
public void close() throws IOException {
}
@Override
public void setConf(Configuration conf) {
this.conf = conf;
}
@Override
public Configuration getConf() {
return conf;
}
@Override
public List<String> getLabels() {
return null;
}
@Override
public void reset() {
if (inputSplit == null)
throw new UnsupportedOperationException("Cannot reset without first initializing");
try {
initialize(inputSplit);
} catch (Exception e) {
throw new RuntimeException("Error during LineRecordReader reset", e);
}
}
@Override
public boolean resetSupported(){
if(inputSplit == null){
return false;
}
return inputSplit.resetSupported();
}
@Override
public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
invokeListeners(uri);
try {
return loadData(null, dataInputStream);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public Record nextRecord() {
return new org.datavec.api.records.impl.Record(next(), null);
}
@Override
public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
throw new UnsupportedOperationException("Loading from metadata not yet implemented");
}
@Override
public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
throw new UnsupportedOperationException("Loading from metadata not yet implemented");
}
}

View File

@ -1,76 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.recordreader;
import org.bytedeco.javacv.FFmpegFrameGrabber;
import org.bytedeco.javacv.Frame;
import org.datavec.api.writable.FloatWritable;
import org.datavec.api.writable.Writable;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.List;
import static org.bytedeco.ffmpeg.global.avutil.AV_SAMPLE_FMT_FLT;
/**
* Native audio file loader using FFmpeg.
*
* @author saudet
*/
public class NativeAudioRecordReader extends BaseAudioRecordReader {
public NativeAudioRecordReader() {}
public NativeAudioRecordReader(boolean appendLabel, List<String> labels) {
super(appendLabel, labels);
}
public NativeAudioRecordReader(List<String> labels) {
super(labels);
}
public NativeAudioRecordReader(boolean appendLabel) {
super(appendLabel);
}
protected List<Writable> loadData(File file, InputStream inputStream) throws IOException {
List<Writable> ret = new ArrayList<>();
try (FFmpegFrameGrabber grabber = inputStream != null ? new FFmpegFrameGrabber(inputStream)
: new FFmpegFrameGrabber(file.getAbsolutePath())) {
grabber.setSampleFormat(AV_SAMPLE_FMT_FLT);
grabber.start();
Frame frame;
while ((frame = grabber.grab()) != null) {
while (frame.samples != null && frame.samples[0].hasRemaining()) {
for (int i = 0; i < frame.samples.length; i++) {
ret.add(new FloatWritable(((FloatBuffer) frame.samples[i]).get()));
}
}
}
}
return ret;
}
}

View File

@ -1,57 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.recordreader;
import org.datavec.api.util.RecordUtils;
import org.datavec.api.writable.Writable;
import org.datavec.audio.Wave;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
/**
* Wav file loader
* @author Adam Gibson
*/
public class WavFileRecordReader extends BaseAudioRecordReader {
public WavFileRecordReader() {}
public WavFileRecordReader(boolean appendLabel, List<String> labels) {
super(appendLabel, labels);
}
public WavFileRecordReader(List<String> labels) {
super(labels);
}
public WavFileRecordReader(boolean appendLabel) {
super(appendLabel);
}
protected List<Writable> loadData(File file, InputStream inputStream) throws IOException {
Wave wave = inputStream != null ? new Wave(inputStream) : new Wave(file.getAbsolutePath());
return RecordUtils.toRecord(wave.getNormalizedAmplitudes());
}
}

View File

@ -1,51 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
public long getTimeoutMilliseconds() {
return 60000;
}
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.audio";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -1,68 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio;
import org.bytedeco.javacv.FFmpegFrameRecorder;
import org.bytedeco.javacv.Frame;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.audio.recordreader.NativeAudioRecordReader;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
import java.nio.ShortBuffer;
import java.util.List;
import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_VORBIS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* @author saudet
*/
public class AudioReaderTest extends BaseND4JTest {
@Ignore
@Test
public void testNativeAudioReader() throws Exception {
File tempFile = File.createTempFile("testNativeAudioReader", ".ogg");
FFmpegFrameRecorder recorder = new FFmpegFrameRecorder(tempFile, 2);
recorder.setAudioCodec(AV_CODEC_ID_VORBIS);
recorder.setSampleRate(44100);
recorder.start();
Frame audioFrame = new Frame();
ShortBuffer audioBuffer = ShortBuffer.allocate(64 * 1024);
audioFrame.sampleRate = 44100;
audioFrame.audioChannels = 2;
audioFrame.samples = new ShortBuffer[] {audioBuffer};
recorder.record(audioFrame);
recorder.stop();
recorder.release();
RecordReader reader = new NativeAudioRecordReader();
reader.initialize(new FileSplit(tempFile));
assertTrue(reader.hasNext());
List<Writable> record = reader.next();
assertEquals(audioBuffer.limit(), record.size());
}
}

View File

@ -1,69 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio;
import org.datavec.audio.dsp.FastFourierTransform;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.common.tests.BaseND4JTest;
public class TestFastFourierTransform extends BaseND4JTest {
@Test
public void testFastFourierTransformComplex() {
FastFourierTransform fft = new FastFourierTransform();
double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6};
double[] frequencies = fft.getMagnitudes(amplitudes);
Assert.assertEquals(2, frequencies.length);
Assert.assertArrayEquals(new double[] {21.335, 18.513}, frequencies, 0.005);
}
@Test
public void testFastFourierTransformComplexLong() {
FastFourierTransform fft = new FastFourierTransform();
double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6};
double[] frequencies = fft.getMagnitudes(amplitudes, true);
Assert.assertEquals(4, frequencies.length);
Assert.assertArrayEquals(new double[] {21.335, 18.5132, 14.927, 7.527}, frequencies, 0.005);
}
@Test
public void testFastFourierTransformReal() {
FastFourierTransform fft = new FastFourierTransform();
double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6};
double[] frequencies = fft.getMagnitudes(amplitudes, false);
Assert.assertEquals(4, frequencies.length);
Assert.assertArrayEquals(new double[] {28.8, 2.107, 14.927, 19.874}, frequencies, 0.005);
}
@Test
public void testFastFourierTransformRealOddSize() {
FastFourierTransform fft = new FastFourierTransform();
double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5};
double[] frequencies = fft.getMagnitudes(amplitudes, false);
Assert.assertEquals(3, frequencies.length);
Assert.assertArrayEquals(new double[] {24.2, 3.861, 16.876}, frequencies, 0.005);
}
}

View File

@ -1,71 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.datavec</groupId>
<artifactId>datavec-data</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>datavec-data-codec</artifactId>
<name>datavec-data-codec</name>
<dependencies>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.jcodec</groupId>
<artifactId>jcodec</artifactId>
<version>0.1.5</version>
</dependency>
<!-- Do not depend on FFmpeg by default due to licensing concerns. -->
<!--
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>ffmpeg-platform</artifactId>
<version>${ffmpeg.version}-${javacpp-presets.version}</version>
</dependency>
-->
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

View File

@ -1,41 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.codec.format.input;
import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.BaseInputFormat;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.codec.reader.CodecRecordReader;
import java.io.IOException;
/**
* @author Adam Gibson
*/
public class CodecInputFormat extends BaseInputFormat {
@Override
public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
RecordReader reader = new CodecRecordReader();
reader.initialize(conf, split);
return reader;
}
}

View File

@ -1,144 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.codec.reader;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.FileRecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable;
import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public abstract class BaseCodecRecordReader extends FileRecordReader implements SequenceRecordReader {
protected int startFrame = 0;
protected int numFrames = -1;
protected int totalFrames = -1;
protected double framesPerSecond = -1;
protected double videoLength = -1;
protected int rows = 28, cols = 28;
protected boolean ravel = false;
public final static String NAME_SPACE = "org.datavec.codec.reader";
public final static String ROWS = NAME_SPACE + ".rows";
public final static String COLUMNS = NAME_SPACE + ".columns";
public final static String START_FRAME = NAME_SPACE + ".startframe";
public final static String TOTAL_FRAMES = NAME_SPACE + ".frames";
public final static String TIME_SLICE = NAME_SPACE + ".time";
public final static String RAVEL = NAME_SPACE + ".ravel";
public final static String VIDEO_DURATION = NAME_SPACE + ".duration";
@Override
public List<List<Writable>> sequenceRecord() {
URI next = locationsIterator.next();
try (InputStream s = streamCreatorFn.apply(next)){
return loadData(null, s);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public List<List<Writable>> sequenceRecord(URI uri, DataInputStream dataInputStream) throws IOException {
return loadData(null, dataInputStream);
}
protected abstract List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException;
@Override
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
setConf(conf);
initialize(split);
}
@Override
public List<Writable> next() {
throw new UnsupportedOperationException("next() not supported for CodecRecordReader (use: sequenceRecord)");
}
@Override
public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
throw new UnsupportedOperationException("record(URI,DataInputStream) not supported for CodecRecordReader");
}
@Override
public void setConf(Configuration conf) {
super.setConf(conf);
startFrame = conf.getInt(START_FRAME, 0);
numFrames = conf.getInt(TOTAL_FRAMES, -1);
rows = conf.getInt(ROWS, 28);
cols = conf.getInt(COLUMNS, 28);
framesPerSecond = conf.getFloat(TIME_SLICE, -1);
videoLength = conf.getFloat(VIDEO_DURATION, -1);
ravel = conf.getBoolean(RAVEL, false);
totalFrames = conf.getInt(TOTAL_FRAMES, -1);
}
@Override
public Configuration getConf() {
return super.getConf();
}
@Override
public SequenceRecord nextSequence() {
URI next = locationsIterator.next();
List<List<Writable>> list;
try (InputStream s = streamCreatorFn.apply(next)){
list = loadData(null, s);
} catch (IOException e) {
throw new RuntimeException(e);
}
return new org.datavec.api.records.impl.SequenceRecord(list,
new RecordMetaDataURI(next, CodecRecordReader.class));
}
@Override
public SequenceRecord loadSequenceFromMetaData(RecordMetaData recordMetaData) throws IOException {
return loadSequenceFromMetaData(Collections.singletonList(recordMetaData)).get(0);
}
@Override
public List<SequenceRecord> loadSequenceFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
List<SequenceRecord> out = new ArrayList<>();
for (RecordMetaData meta : recordMetaDatas) {
try (InputStream s = streamCreatorFn.apply(meta.getURI())){
List<List<Writable>> list = loadData(null, s);
out.add(new org.datavec.api.records.impl.SequenceRecord(list, meta));
}
}
return out;
}
}

View File

@ -1,138 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.codec.reader;
import org.apache.commons.compress.utils.IOUtils;
import org.datavec.api.conf.Configuration;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.ImageLoader;
import org.jcodec.api.FrameGrab;
import org.jcodec.api.JCodecException;
import org.jcodec.common.ByteBufferSeekableByteChannel;
import org.jcodec.common.NIOUtils;
import org.jcodec.common.SeekableByteChannel;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
public class CodecRecordReader extends BaseCodecRecordReader {
private ImageLoader imageLoader;
@Override
public void setConf(Configuration conf) {
super.setConf(conf);
imageLoader = new ImageLoader(rows, cols);
}
@Override
protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException {
SeekableByteChannel seekableByteChannel;
if (inputStream != null) {
//Reading video from DataInputStream: Need data from this stream in a SeekableByteChannel
//Approach used here: load entire video into memory -> ByteBufferSeekableByteChanel
byte[] data = IOUtils.toByteArray(inputStream);
ByteBuffer bb = ByteBuffer.wrap(data);
seekableByteChannel = new FixedByteBufferSeekableByteChannel(bb);
} else {
seekableByteChannel = NIOUtils.readableFileChannel(file);
}
List<List<Writable>> record = new ArrayList<>();
if (numFrames >= 1) {
FrameGrab fg;
try {
fg = new FrameGrab(seekableByteChannel);
if (startFrame != 0)
fg.seekToFramePrecise(startFrame);
} catch (JCodecException e) {
throw new RuntimeException(e);
}
for (int i = startFrame; i < startFrame + numFrames; i++) {
try {
BufferedImage grab = fg.getFrame();
if (ravel)
record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab)));
else
record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab)));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
} else {
if (framesPerSecond < 1)
throw new IllegalStateException("No frames or frame time intervals specified");
else {
for (double i = 0; i < videoLength; i += framesPerSecond) {
try {
BufferedImage grab = FrameGrab.getFrame(seekableByteChannel, i);
if (ravel)
record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab)));
else
record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab)));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
}
return record;
}
/** Ugly workaround to a bug in JCodec: https://github.com/jcodec/jcodec/issues/24 */
private static class FixedByteBufferSeekableByteChannel extends ByteBufferSeekableByteChannel {
private ByteBuffer backing;
public FixedByteBufferSeekableByteChannel(ByteBuffer backing) {
super(backing);
try {
Field f = this.getClass().getSuperclass().getDeclaredField("maxPos");
f.setAccessible(true);
f.set(this, backing.limit());
} catch (Exception e) {
throw new RuntimeException(e);
}
this.backing = backing;
}
@Override
public int read(ByteBuffer dst) throws IOException {
if (!backing.hasRemaining())
return -1;
return super.read(dst);
}
}
}

View File

@ -1,82 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.codec.reader;
import org.bytedeco.javacv.FFmpegFrameGrabber;
import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.api.conf.Configuration;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
public class NativeCodecRecordReader extends BaseCodecRecordReader {
private OpenCVFrameConverter.ToMat converter;
private NativeImageLoader imageLoader;
@Override
public void setConf(Configuration conf) {
super.setConf(conf);
converter = new OpenCVFrameConverter.ToMat();
imageLoader = new NativeImageLoader(rows, cols);
}
@Override
protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException {
List<List<Writable>> record = new ArrayList<>();
try (FFmpegFrameGrabber fg =
inputStream != null ? new FFmpegFrameGrabber(inputStream) : new FFmpegFrameGrabber(file)) {
if (numFrames >= 1) {
fg.start();
if (startFrame != 0)
fg.setFrameNumber(startFrame);
for (int i = startFrame; i < startFrame + numFrames; i++) {
Frame grab = fg.grabImage();
record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab))));
}
} else {
if (framesPerSecond < 1)
throw new IllegalStateException("No frames or frame time intervals specified");
else {
fg.start();
for (double i = 0; i < videoLength; i += framesPerSecond) {
fg.setTimestamp(Math.round(i * 1000000L));
Frame grab = fg.grabImage();
record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab))));
}
}
}
}
return record;
}
}

View File

@ -1,46 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.codec.reader;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.codec.reader";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -1,212 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.codec.reader;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.ArrayWritable;
import org.datavec.api.writable.Writable;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.common.io.ClassPathResource;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.util.Iterator;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* @author Adam Gibson
*/
public class CodecReaderTest {
@Test
public void testCodecReader() throws Exception {
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
SequenceRecordReader reader = new CodecRecordReader();
Configuration conf = new Configuration();
conf.set(CodecRecordReader.RAVEL, "true");
conf.set(CodecRecordReader.START_FRAME, "160");
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
reader.initialize(new FileSplit(file));
reader.setConf(conf);
assertTrue(reader.hasNext());
List<List<Writable>> record = reader.sequenceRecord();
// System.out.println(record.size());
Iterator<List<Writable>> it = record.iterator();
List<Writable> first = it.next();
// System.out.println(first);
//Expected size: 80x46x3
assertEquals(1, first.size());
assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length());
}
@Test
public void testCodecReaderMeta() throws Exception {
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
SequenceRecordReader reader = new CodecRecordReader();
Configuration conf = new Configuration();
conf.set(CodecRecordReader.RAVEL, "true");
conf.set(CodecRecordReader.START_FRAME, "160");
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
reader.initialize(new FileSplit(file));
reader.setConf(conf);
assertTrue(reader.hasNext());
List<List<Writable>> record = reader.sequenceRecord();
assertEquals(500, record.size()); //500 frames
reader.reset();
SequenceRecord seqR = reader.nextSequence();
assertEquals(record, seqR.getSequenceRecord());
RecordMetaData meta = seqR.getMetaData();
// System.out.println(meta);
assertTrue(meta.getURI().toString().endsWith(file.getName()));
SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta);
assertEquals(seqR, fromMeta);
}
@Test
public void testViaDataInputStream() throws Exception {
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
SequenceRecordReader reader = new CodecRecordReader();
Configuration conf = new Configuration();
conf.set(CodecRecordReader.RAVEL, "true");
conf.set(CodecRecordReader.START_FRAME, "160");
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
Configuration conf2 = new Configuration(conf);
reader.initialize(new FileSplit(file));
reader.setConf(conf);
assertTrue(reader.hasNext());
List<List<Writable>> expected = reader.sequenceRecord();
SequenceRecordReader reader2 = new CodecRecordReader();
reader2.setConf(conf2);
DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file));
List<List<Writable>> actual = reader2.sequenceRecord(null, dataInputStream);
assertEquals(expected, actual);
}
@Ignore
@Test
public void testNativeCodecReader() throws Exception {
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
SequenceRecordReader reader = new NativeCodecRecordReader();
Configuration conf = new Configuration();
conf.set(CodecRecordReader.RAVEL, "true");
conf.set(CodecRecordReader.START_FRAME, "160");
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
reader.initialize(new FileSplit(file));
reader.setConf(conf);
assertTrue(reader.hasNext());
List<List<Writable>> record = reader.sequenceRecord();
// System.out.println(record.size());
Iterator<List<Writable>> it = record.iterator();
List<Writable> first = it.next();
// System.out.println(first);
//Expected size: 80x46x3
assertEquals(1, first.size());
assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length());
}
@Ignore
@Test
public void testNativeCodecReaderMeta() throws Exception {
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
SequenceRecordReader reader = new NativeCodecRecordReader();
Configuration conf = new Configuration();
conf.set(CodecRecordReader.RAVEL, "true");
conf.set(CodecRecordReader.START_FRAME, "160");
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
reader.initialize(new FileSplit(file));
reader.setConf(conf);
assertTrue(reader.hasNext());
List<List<Writable>> record = reader.sequenceRecord();
assertEquals(500, record.size()); //500 frames
reader.reset();
SequenceRecord seqR = reader.nextSequence();
assertEquals(record, seqR.getSequenceRecord());
RecordMetaData meta = seqR.getMetaData();
// System.out.println(meta);
assertTrue(meta.getURI().toString().endsWith("fire_lowres.mp4"));
SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta);
assertEquals(seqR, fromMeta);
}
@Ignore
@Test
public void testNativeViaDataInputStream() throws Exception {
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
SequenceRecordReader reader = new NativeCodecRecordReader();
Configuration conf = new Configuration();
conf.set(CodecRecordReader.RAVEL, "true");
conf.set(CodecRecordReader.START_FRAME, "160");
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
Configuration conf2 = new Configuration(conf);
reader.initialize(new FileSplit(file));
reader.setConf(conf);
assertTrue(reader.hasNext());
List<List<Writable>> expected = reader.sequenceRecord();
SequenceRecordReader reader2 = new NativeCodecRecordReader();
reader2.setConf(conf2);
DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file));
List<List<Writable>> actual = reader2.sequenceRecord(null, dataInputStream);
assertEquals(expected, actual);
}
}

View File

@ -1,77 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.datavec</groupId>
<artifactId>datavec-data</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>datavec-data-nlp</artifactId>
<name>datavec-data-nlp</name>
<properties>
<cleartk.version>2.0.0</cleartk.version>
</properties>
<dependencies>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-local</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
<dependency>
<groupId>org.cleartk</groupId>
<artifactId>cleartk-snowball</artifactId>
<version>${cleartk.version}</version>
</dependency>
<dependency>
<groupId>org.cleartk</groupId>
<artifactId>cleartk-opennlp-tools</artifactId>
<version>${cleartk.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

View File

@ -1,237 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.annotator;
import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSTaggerME;
import opennlp.uima.postag.POSModelResource;
import opennlp.uima.postag.POSModelResourceImpl;
import opennlp.uima.util.AnnotationComboIterator;
import opennlp.uima.util.AnnotationIteratorPair;
import opennlp.uima.util.AnnotatorUtil;
import opennlp.uima.util.UimaUtil;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.Feature;
import org.apache.uima.cas.Type;
import org.apache.uima.cas.TypeSystem;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.fit.component.CasAnnotator_ImplBase;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.ExternalResourceFactory;
import org.apache.uima.resource.ResourceAccessException;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.util.Level;
import org.apache.uima.util.Logger;
import org.cleartk.token.type.Sentence;
import org.cleartk.token.type.Token;
import org.datavec.nlp.movingwindow.Util;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
public class PoStagger extends CasAnnotator_ImplBase {
static {
//UIMA logging
Util.disableLogging();
}
private POSTaggerME posTagger;
private Type sentenceType;
private Type tokenType;
private Feature posFeature;
private Feature probabilityFeature;
private UimaContext context;
private Logger logger;
/**
* Initializes a new instance.
*
* Note: Use {@link #initialize(org.apache.uima.UimaContext) } to initialize this instance. Not use the
* constructor.
*/
public PoStagger() {
// must not be implemented !
}
/**
* Initializes the current instance with the given context.
*
* Note: Do all initialization in this method, do not use the constructor.
*/
@Override
public void initialize(UimaContext context) throws ResourceInitializationException {
super.initialize(context);
this.context = context;
this.logger = context.getLogger();
if (this.logger.isLoggable(Level.INFO)) {
this.logger.log(Level.INFO, "Initializing the OpenNLP " + "Part of Speech annotator.");
}
POSModel model;
try {
POSModelResource modelResource = (POSModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER);
model = modelResource.getModel();
} catch (ResourceAccessException e) {
throw new ResourceInitializationException(e);
}
Integer beamSize = AnnotatorUtil.getOptionalIntegerParameter(context, UimaUtil.BEAM_SIZE_PARAMETER);
if (beamSize == null)
beamSize = POSTaggerME.DEFAULT_BEAM_SIZE;
this.posTagger = new POSTaggerME(model, beamSize, 0);
}
/**
* Initializes the type system.
*/
@Override
public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException {
// sentence type
this.sentenceType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem,
UimaUtil.SENTENCE_TYPE_PARAMETER);
// token type
this.tokenType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem,
UimaUtil.TOKEN_TYPE_PARAMETER);
// pos feature
this.posFeature = AnnotatorUtil.getRequiredFeatureParameter(this.context, this.tokenType,
UimaUtil.POS_FEATURE_PARAMETER, CAS.TYPE_NAME_STRING);
this.probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(this.context, this.tokenType,
UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE);
}
/**
* Performs pos-tagging on the given tcas object.
*/
@Override
public synchronized void process(CAS tcas) {
final AnnotationComboIterator comboIterator =
new AnnotationComboIterator(tcas, this.sentenceType, this.tokenType);
for (AnnotationIteratorPair annotationIteratorPair : comboIterator) {
final List<AnnotationFS> sentenceTokenAnnotationList = new LinkedList<AnnotationFS>();
final List<String> sentenceTokenList = new LinkedList<String>();
for (AnnotationFS tokenAnnotation : annotationIteratorPair.getSubIterator()) {
sentenceTokenAnnotationList.add(tokenAnnotation);
sentenceTokenList.add(tokenAnnotation.getCoveredText());
}
final List<String> posTags = this.posTagger.tag(sentenceTokenList);
double posProbabilities[] = null;
if (this.probabilityFeature != null) {
posProbabilities = this.posTagger.probs();
}
final Iterator<String> posTagIterator = posTags.iterator();
final Iterator<AnnotationFS> sentenceTokenIterator = sentenceTokenAnnotationList.iterator();
int index = 0;
while (posTagIterator.hasNext() && sentenceTokenIterator.hasNext()) {
final String posTag = posTagIterator.next();
final AnnotationFS tokenAnnotation = sentenceTokenIterator.next();
tokenAnnotation.setStringValue(this.posFeature, posTag);
if (posProbabilities != null) {
tokenAnnotation.setDoubleValue(this.posFeature, posProbabilities[index]);
}
index++;
}
// log tokens with pos
if (this.logger.isLoggable(Level.FINER)) {
final StringBuilder sentenceWithPos = new StringBuilder();
sentenceWithPos.append("\"");
for (final Iterator<AnnotationFS> it = sentenceTokenAnnotationList.iterator(); it.hasNext();) {
final AnnotationFS token = it.next();
sentenceWithPos.append(token.getCoveredText());
sentenceWithPos.append('\\');
sentenceWithPos.append(token.getStringValue(this.posFeature));
sentenceWithPos.append(' ');
}
// delete last whitespace
if (sentenceWithPos.length() > 1) // not 0 because it contains already the " char
sentenceWithPos.setLength(sentenceWithPos.length() - 1);
sentenceWithPos.append("\"");
this.logger.log(Level.FINER, sentenceWithPos.toString());
}
}
}
/**
* Releases allocated resources.
*/
@Override
public void destroy() {
this.posTagger = null;
}
public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException {
String modelPath = String.format("/models/%s-pos-maxent.bin", languageCode);
return AnalysisEngineFactory.createEngineDescription(PoStagger.class, UimaUtil.MODEL_PARAMETER,
ExternalResourceFactory.createExternalResourceDescription(POSModelResourceImpl.class,
PoStagger.class.getResource(modelPath).toString()),
UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), UimaUtil.TOKEN_TYPE_PARAMETER,
Token.class.getName(), UimaUtil.POS_FEATURE_PARAMETER, "pos");
}
}

View File

@ -1,52 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.annotator;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.util.ParamUtil;
import org.datavec.nlp.movingwindow.Util;
public class SentenceAnnotator extends org.cleartk.opennlp.tools.SentenceAnnotator {
static {
//UIMA logging
Util.disableLogging();
}
public static AnalysisEngineDescription getDescription() throws ResourceInitializationException {
return AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.class, PARAM_SENTENCE_MODEL_PATH,
ParamUtil.getParameterValue(PARAM_SENTENCE_MODEL_PATH, "/models/en-sent.bin"),
PARAM_WINDOW_CLASS_NAMES, ParamUtil.getParameterValue(PARAM_WINDOW_CLASS_NAMES, null));
}
@Override
public synchronized void process(JCas jCas) throws AnalysisEngineProcessException {
super.process(jCas);
}
}

View File

@ -1,58 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.annotator;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.snowball.SnowballStemmer;
import org.cleartk.token.type.Token;
public class StemmerAnnotator extends SnowballStemmer<Token> {
public static AnalysisEngineDescription getDescription() throws ResourceInitializationException {
return getDescription("English");
}
public static AnalysisEngineDescription getDescription(String language) throws ResourceInitializationException {
return AnalysisEngineFactory.createEngineDescription(StemmerAnnotator.class, SnowballStemmer.PARAM_STEMMER_NAME,
language);
}
@SuppressWarnings("unchecked")
@Override
public synchronized void process(JCas jCas) throws AnalysisEngineProcessException {
super.process(jCas);
}
@Override
public void setStem(Token token, String stem) {
token.setStem(stem);
}
}

View File

@ -1,70 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.annotator;
import opennlp.uima.tokenize.TokenizerModelResourceImpl;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.ExternalResourceFactory;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.opennlp.tools.Tokenizer;
import org.cleartk.token.type.Sentence;
import org.cleartk.token.type.Token;
import org.datavec.nlp.movingwindow.Util;
import org.datavec.nlp.tokenization.tokenizer.ConcurrentTokenizer;
/**
* Overrides OpenNLP tokenizer to be thread safe
*/
public class TokenizerAnnotator extends Tokenizer {
static {
//UIMA logging
Util.disableLogging();
}
public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException {
String modelPath = String.format("/models/%s-token.bin", languageCode);
return AnalysisEngineFactory.createEngineDescription(ConcurrentTokenizer.class,
opennlp.uima.util.UimaUtil.MODEL_PARAMETER,
ExternalResourceFactory.createExternalResourceDescription(TokenizerModelResourceImpl.class,
ConcurrentTokenizer.class.getResource(modelPath).toString()),
opennlp.uima.util.UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(),
opennlp.uima.util.UimaUtil.TOKEN_TYPE_PARAMETER, Token.class.getName());
}
public static AnalysisEngineDescription getDescription() throws ResourceInitializationException {
String modelPath = String.format("/models/%s-token.bin", "en");
return AnalysisEngineFactory.createEngineDescription(ConcurrentTokenizer.class,
opennlp.uima.util.UimaUtil.MODEL_PARAMETER,
ExternalResourceFactory.createExternalResourceDescription(TokenizerModelResourceImpl.class,
ConcurrentTokenizer.class.getResource(modelPath).toString()),
opennlp.uima.util.UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(),
opennlp.uima.util.UimaUtil.TOKEN_TYPE_PARAMETER, Token.class.getName());
}
}

View File

@ -1,41 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.input;
import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.BaseInputFormat;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.nlp.reader.TfidfRecordReader;
import java.io.IOException;
/**
* @author Adam Gibson
*/
public class TextInputFormat extends BaseInputFormat {
@Override
public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
RecordReader reader = new TfidfRecordReader();
reader.initialize(conf, split);
return reader;
}
}

View File

@ -1,148 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.metadata;
import org.nd4j.common.primitives.Counter;
import org.datavec.api.conf.Configuration;
import org.datavec.nlp.vectorizer.TextVectorizer;
import org.nd4j.common.util.MathUtils;
import org.nd4j.common.util.Index;
public class DefaultVocabCache implements VocabCache {
private Counter<String> wordFrequencies = new Counter<>();
private Counter<String> docFrequencies = new Counter<>();
private int minWordFrequency;
private Index vocabWords = new Index();
private double numDocs = 0;
/**
* Instantiate with a given min word frequency
* @param minWordFrequency
*/
public DefaultVocabCache(int minWordFrequency) {
this.minWordFrequency = minWordFrequency;
}
/*
* Constructor for use with initialize()
*/
public DefaultVocabCache() {
}
@Override
public void incrementNumDocs(double by) {
numDocs += by;
}
@Override
public double numDocs() {
return numDocs;
}
@Override
public String wordAt(int i) {
return vocabWords.get(i).toString();
}
@Override
public int wordIndex(String word) {
return vocabWords.indexOf(word);
}
@Override
public void initialize(Configuration conf) {
minWordFrequency = conf.getInt(TextVectorizer.MIN_WORD_FREQUENCY, 5);
}
@Override
public double wordFrequency(String word) {
return wordFrequencies.getCount(word);
}
@Override
public int minWordFrequency() {
return minWordFrequency;
}
@Override
public Index vocabWords() {
return vocabWords;
}
@Override
public void incrementDocCount(String word) {
incrementDocCount(word, 1.0);
}
@Override
public void incrementDocCount(String word, double by) {
docFrequencies.incrementCount(word, by);
}
@Override
public void incrementCount(String word) {
incrementCount(word, 1.0);
}
@Override
public void incrementCount(String word, double by) {
wordFrequencies.incrementCount(word, by);
if (wordFrequencies.getCount(word) >= minWordFrequency && vocabWords.indexOf(word) < 0)
vocabWords.add(word);
}
@Override
public double idf(String word) {
return docFrequencies.getCount(word);
}
@Override
public double tfidf(String word, double frequency, boolean smoothIdf) {
double tf = tf((int) frequency);
double docFreq = docFrequencies.getCount(word);
double idf = idf(numDocs, docFreq, smoothIdf);
double tfidf = MathUtils.tfidf(tf, idf);
return tfidf;
}
public double idf(double totalDocs, double numTimesWordAppearedInADocument, boolean smooth) {
if(smooth){
return Math.log((1 + totalDocs) / (1 + numTimesWordAppearedInADocument)) + 1.0;
} else {
return Math.log(totalDocs / numTimesWordAppearedInADocument) + 1.0;
}
}
public static double tf(int count) {
return count;
}
public int getMinWordFrequency() {
return minWordFrequency;
}
public void setMinWordFrequency(int minWordFrequency) {
this.minWordFrequency = minWordFrequency;
}
}

View File

@ -1,121 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.metadata;
import org.datavec.api.conf.Configuration;
import org.nd4j.common.util.Index;
public interface VocabCache {
/**
* Increment the number of documents
* @param by
*/
void incrementNumDocs(double by);
/**
* Number of documents
* @return the number of documents
*/
double numDocs();
/**
* Returns a word in the vocab at a particular index
* @param i the index to get
* @return the word at that index in the vocab
*/
String wordAt(int i);
int wordIndex(String word);
/**
* Configuration for initializing
* @param conf the configuration to initialize with
*/
void initialize(Configuration conf);
/**
* Get the word frequency for a word
* @param word the word to get frequency for
* @return the frequency for a given word
*/
double wordFrequency(String word);
/**
* The min word frequency
* needed to be included in the vocab
* (default 5)
* @return the min word frequency to
* be included in the vocab
*/
int minWordFrequency();
/**
* All of the vocab words (ordered)
* note that these are not all the possible tokens
* @return the list of vocab words
*/
Index vocabWords();
/**
* Increment the doc count for a word by 1
* @param word the word to increment the count for
*/
void incrementDocCount(String word);
/**
* Increment the document count for a particular word
* @param word the word to increment the count for
* @param by the amount to increment by
*/
void incrementDocCount(String word, double by);
/**
* Increment a word count by 1
* @param word the word to increment the count for
*/
void incrementCount(String word);
/**
* Increment count for a word
* @param word the word to increment the count for
* @param by the amount to increment by
*/
void incrementCount(String word, double by);
/**
* Number of documents word has occurred in
* @param word the word to get the idf for
*/
double idf(String word);
/**
* Calculate the tfidf of the word given the document frequency
* @param word the word to get frequency for
* @param frequency the frequency
* @return the tfidf for a word
*/
double tfidf(String word, double frequency, boolean smoothIdf);
}

View File

@ -1,125 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.movingwindow;
import org.apache.commons.lang3.StringUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.collection.MultiDimensionalMap;
import org.nd4j.common.primitives.Pair;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import java.util.ArrayList;
import java.util.List;
public class ContextLabelRetriever {
private static String BEGIN_LABEL = "<([A-Za-z]+|\\d+)>";
private static String END_LABEL = "</([A-Za-z]+|\\d+)>";
private ContextLabelRetriever() {}
/**
* Returns a stripped sentence with the indices of words
* with certain kinds of labels.
*
* @param sentence the sentence to process
* @return a pair of a post processed sentence
* with labels stripped and the spans of
* the labels
*/
public static Pair<String, MultiDimensionalMap<Integer, Integer, String>> stringWithLabels(String sentence,
TokenizerFactory tokenizerFactory) {
MultiDimensionalMap<Integer, Integer, String> map = MultiDimensionalMap.newHashBackedMap();
Tokenizer t = tokenizerFactory.create(sentence);
List<String> currTokens = new ArrayList<>();
String currLabel = null;
String endLabel = null;
List<Pair<String, List<String>>> tokensWithSameLabel = new ArrayList<>();
while (t.hasMoreTokens()) {
String token = t.nextToken();
if (token.matches(BEGIN_LABEL)) {
currLabel = token;
//no labels; add these as NONE and begin the new label
if (!currTokens.isEmpty()) {
tokensWithSameLabel.add(new Pair<>("NONE", (List<String>) new ArrayList<>(currTokens)));
currTokens.clear();
}
} else if (token.matches(END_LABEL)) {
if (currLabel == null)
throw new IllegalStateException("Found an ending label with no matching begin label");
endLabel = token;
} else
currTokens.add(token);
if (currLabel != null && endLabel != null) {
currLabel = currLabel.replaceAll("[<>/]", "");
endLabel = endLabel.replaceAll("[<>/]", "");
Preconditions.checkState(!currLabel.isEmpty(), "Current label is empty!");
Preconditions.checkState(!endLabel.isEmpty(), "End label is empty!");
Preconditions.checkState(currLabel.equals(endLabel), "Current label begin and end did not match for the parse. Was: %s ending with %s",
currLabel, endLabel);
tokensWithSameLabel.add(new Pair<>(currLabel, (List<String>) new ArrayList<>(currTokens)));
currTokens.clear();
//clear out the tokens
currLabel = null;
endLabel = null;
}
}
//no labels; add these as NONE and begin the new label
if (!currTokens.isEmpty()) {
tokensWithSameLabel.add(new Pair<>("none", (List<String>) new ArrayList<>(currTokens)));
currTokens.clear();
}
//now join the output
StringBuilder strippedSentence = new StringBuilder();
for (Pair<String, List<String>> tokensWithLabel : tokensWithSameLabel) {
String joinedSentence = StringUtils.join(tokensWithLabel.getSecond(), " ");
//spaces between separate parts of the sentence
if (!(strippedSentence.length() < 1))
strippedSentence.append(" ");
strippedSentence.append(joinedSentence);
int begin = strippedSentence.toString().indexOf(joinedSentence);
int end = begin + joinedSentence.length();
map.put(begin, end, tokensWithLabel.getFirst());
}
return new Pair<>(strippedSentence.toString(), map);
}
}

View File

@ -1,60 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.movingwindow;
import org.nd4j.common.primitives.Counter;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
public class Util {
/**
* Returns a thread safe counter
*
* @return
*/
public static Counter<String> parallelCounter() {
return new Counter<>();
}
public static boolean matchesAnyStopWord(List<String> stopWords, String word) {
for (String s : stopWords)
if (s.equalsIgnoreCase(word))
return true;
return false;
}
public static Level disableLogging() {
Logger logger = Logger.getLogger("org.apache.uima");
while (logger.getLevel() == null) {
logger = logger.getParent();
}
Level level = logger.getLevel();
logger.setLevel(Level.OFF);
return level;
}
}

View File

@ -1,177 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.movingwindow;
import org.apache.commons.lang3.StringUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
public class Window implements Serializable {
/**
*
*/
private static final long serialVersionUID = 6359906393699230579L;
private List<String> words;
private String label = "NONE";
private boolean beginLabel;
private boolean endLabel;
private int median;
private static String BEGIN_LABEL = "<([A-Z]+|\\d+)>";
private static String END_LABEL = "</([A-Z]+|\\d+)>";
private int begin, end;
/**
* Creates a window with a context of size 3
* @param words a collection of strings of size 3
*/
public Window(Collection<String> words, int begin, int end) {
this(words, 5, begin, end);
}
public String asTokens() {
return StringUtils.join(words, " ");
}
/**
* Initialize a window with the given size
* @param words the words to use
* @param windowSize the size of the window
* @param begin the begin index for the window
* @param end the end index for the window
*/
public Window(Collection<String> words, int windowSize, int begin, int end) {
if (words == null)
throw new IllegalArgumentException("Words must be a list of size 3");
this.words = new ArrayList<>(words);
int windowSize1 = windowSize;
this.begin = begin;
this.end = end;
initContext();
}
private void initContext() {
int median = (int) Math.floor(words.size() / 2);
List<String> begin = words.subList(0, median);
List<String> after = words.subList(median + 1, words.size());
for (String s : begin) {
if (s.matches(BEGIN_LABEL)) {
this.label = s.replaceAll("(<|>)", "").replace("/", "");
beginLabel = true;
} else if (s.matches(END_LABEL)) {
endLabel = true;
this.label = s.replaceAll("(<|>|/)", "").replace("/", "");
}
}
for (String s1 : after) {
if (s1.matches(BEGIN_LABEL)) {
this.label = s1.replaceAll("(<|>)", "").replace("/", "");
beginLabel = true;
}
if (s1.matches(END_LABEL)) {
endLabel = true;
this.label = s1.replaceAll("(<|>)", "");
}
}
this.median = median;
}
@Override
public String toString() {
return words.toString();
}
public List<String> getWords() {
return words;
}
public void setWords(List<String> words) {
this.words = words;
}
public String getWord(int i) {
return words.get(i);
}
public String getFocusWord() {
return words.get(median);
}
public boolean isBeginLabel() {
return !label.equals("NONE") && beginLabel;
}
public boolean isEndLabel() {
return !label.equals("NONE") && endLabel;
}
public String getLabel() {
return label.replace("/", "");
}
public int getWindowSize() {
return words.size();
}
public int getMedian() {
return median;
}
public void setLabel(String label) {
this.label = label;
}
public int getBegin() {
return begin;
}
public void setBegin(int begin) {
this.begin = begin;
}
public int getEnd() {
return end;
}
public void setEnd(int end) {
this.end = end;
}
}

View File

@ -1,188 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.movingwindow;
import org.apache.commons.lang3.StringUtils;
import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
public class Windows {
/**
* Constructs a list of window of size windowSize.
* Note that padding for each window is created as well.
* @param words the words to tokenize and construct windows from
* @param windowSize the window size to generate
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(InputStream words, int windowSize) {
Tokenizer tokenizer = new DefaultStreamTokenizer(words);
List<String> list = new ArrayList<>();
while (tokenizer.hasMoreTokens())
list.add(tokenizer.nextToken());
return windows(list, windowSize);
}
/**
* Constructs a list of window of size windowSize.
* Note that padding for each window is created as well.
* @param words the words to tokenize and construct windows from
* @param tokenizerFactory tokenizer factory to use
* @param windowSize the window size to generate
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(InputStream words, TokenizerFactory tokenizerFactory, int windowSize) {
Tokenizer tokenizer = tokenizerFactory.create(words);
List<String> list = new ArrayList<>();
while (tokenizer.hasMoreTokens())
list.add(tokenizer.nextToken());
if (list.isEmpty())
throw new IllegalStateException("No tokens found for windows");
return windows(list, windowSize);
}
/**
* Constructs a list of window of size windowSize.
* Note that padding for each window is created as well.
* @param words the words to tokenize and construct windows from
* @param windowSize the window size to generate
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(String words, int windowSize) {
StringTokenizer tokenizer = new StringTokenizer(words);
List<String> list = new ArrayList<String>();
while (tokenizer.hasMoreTokens())
list.add(tokenizer.nextToken());
return windows(list, windowSize);
}
/**
* Constructs a list of window of size windowSize.
* Note that padding for each window is created as well.
* @param words the words to tokenize and construct windows from
* @param tokenizerFactory tokenizer factory to use
* @param windowSize the window size to generate
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(String words, TokenizerFactory tokenizerFactory, int windowSize) {
Tokenizer tokenizer = tokenizerFactory.create(words);
List<String> list = new ArrayList<>();
while (tokenizer.hasMoreTokens())
list.add(tokenizer.nextToken());
if (list.isEmpty())
throw new IllegalStateException("No tokens found for windows");
return windows(list, windowSize);
}
/**
* Constructs a list of window of size windowSize.
* Note that padding for each window is created as well.
* @param words the words to tokenize and construct windows from
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(String words) {
StringTokenizer tokenizer = new StringTokenizer(words);
List<String> list = new ArrayList<String>();
while (tokenizer.hasMoreTokens())
list.add(tokenizer.nextToken());
return windows(list, 5);
}
/**
* Constructs a list of window of size windowSize.
* Note that padding for each window is created as well.
* @param words the words to tokenize and construct windows from
* @param tokenizerFactory tokenizer factory to use
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(String words, TokenizerFactory tokenizerFactory) {
Tokenizer tokenizer = tokenizerFactory.create(words);
List<String> list = new ArrayList<>();
while (tokenizer.hasMoreTokens())
list.add(tokenizer.nextToken());
return windows(list, 5);
}
/**
* Creates a sliding window from text
* @param windowSize the window size to use
* @param wordPos the position of the word to center
* @param sentence the sentence to createComplex a window for
* @return a window based on the given sentence
*/
public static Window windowForWordInPosition(int windowSize, int wordPos, List<String> sentence) {
List<String> window = new ArrayList<>();
List<String> onlyTokens = new ArrayList<>();
int contextSize = (int) Math.floor((windowSize - 1) / 2);
for (int i = wordPos - contextSize; i <= wordPos + contextSize; i++) {
if (i < 0)
window.add("<s>");
else if (i >= sentence.size())
window.add("</s>");
else {
onlyTokens.add(sentence.get(i));
window.add(sentence.get(i));
}
}
String wholeSentence = StringUtils.join(sentence);
String window2 = StringUtils.join(onlyTokens);
int begin = wholeSentence.indexOf(window2);
int end = begin + window2.length();
return new Window(window, begin, end);
}
/**
* Constructs a list of window of size windowSize
* @param words the words to construct windows from
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(List<String> words, int windowSize) {
List<Window> ret = new ArrayList<>();
for (int i = 0; i < words.size(); i++)
ret.add(windowForWordInPosition(windowSize, i, words));
return ret;
}
}

View File

@ -1,189 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.reader;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.impl.FileRecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.vector.Vectorizer;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.vectorizer.TfidfVectorizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.IOException;
import java.util.*;
public class TfidfRecordReader extends FileRecordReader {
private TfidfVectorizer tfidfVectorizer;
private List<Record> records = new ArrayList<>();
private Iterator<Record> recordIter;
private int numFeatures;
private boolean initialized = false;
@Override
public void initialize(InputSplit split) throws IOException, InterruptedException {
initialize(new Configuration(), split);
}
@Override
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
super.initialize(conf, split);
//train a new one since it hasn't been specified
if (tfidfVectorizer == null) {
tfidfVectorizer = new TfidfVectorizer();
tfidfVectorizer.initialize(conf);
//clear out old strings
records.clear();
INDArray ret = tfidfVectorizer.fitTransform(this, new Vectorizer.RecordCallBack() {
@Override
public void onRecord(Record fullRecord) {
records.add(fullRecord);
}
});
//cache the number of features used for each document
numFeatures = ret.columns();
recordIter = records.iterator();
} else {
records = new ArrayList<>();
//the record reader has 2 phases, we are skipping the
//document frequency phase and just using the super() to get the file contents
//and pass it to the already existing vectorizer.
while (super.hasNext()) {
Record fileContents = super.nextRecord();
INDArray transform = tfidfVectorizer.transform(fileContents);
org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record(
new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))),
new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class));
if (appendLabel)
record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1));
records.add(record);
}
recordIter = records.iterator();
}
this.initialized = true;
}
@Override
public void reset() {
if (inputSplit == null)
throw new UnsupportedOperationException("Cannot reset without first initializing");
recordIter = records.iterator();
}
@Override
public Record nextRecord() {
if (recordIter == null)
return super.nextRecord();
return recordIter.next();
}
@Override
public List<Writable> next() {
return nextRecord().getRecord();
}
@Override
public boolean hasNext() {
//we aren't done vectorizing yet
if (recordIter == null)
return super.hasNext();
return recordIter.hasNext();
}
@Override
public void close() throws IOException {
}
@Override
public void setConf(Configuration conf) {
this.conf = conf;
}
@Override
public Configuration getConf() {
return conf;
}
public TfidfVectorizer getTfidfVectorizer() {
return tfidfVectorizer;
}
public void setTfidfVectorizer(TfidfVectorizer tfidfVectorizer) {
if (initialized) {
throw new IllegalArgumentException(
"Setting TfidfVectorizer after TfidfRecordReader initialization doesn't have an effect");
}
this.tfidfVectorizer = tfidfVectorizer;
}
public int getNumFeatures() {
return numFeatures;
}
public void shuffle() {
this.shuffle(new Random());
}
public void shuffle(Random random) {
Collections.shuffle(this.records, random);
this.reset();
}
@Override
public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
return loadFromMetaData(Collections.singletonList(recordMetaData)).get(0);
}
@Override
public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
List<Record> out = new ArrayList<>();
for (Record fileContents : super.loadFromMetaData(recordMetaDatas)) {
INDArray transform = tfidfVectorizer.transform(fileContents);
org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record(
new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))),
new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class));
if (appendLabel)
record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1));
out.add(record);
}
return out;
}
}

View File

@ -1,44 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.stopwords;
import org.apache.commons.io.IOUtils;
import java.io.IOException;
import java.util.List;
public class StopWords {
private static List<String> stopWords;
@SuppressWarnings("unchecked")
public static List<String> getStopWords() {
try {
if (stopWords == null)
stopWords = IOUtils.readLines(StopWords.class.getResourceAsStream("/stopwords"));
} catch (IOException e) {
throw new RuntimeException(e);
}
return stopWords;
}
}

View File

@ -1,125 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
import opennlp.tools.tokenize.TokenizerME;
import opennlp.tools.tokenize.TokenizerModel;
import opennlp.tools.util.Span;
import opennlp.uima.tokenize.AbstractTokenizer;
import opennlp.uima.tokenize.TokenizerModelResource;
import opennlp.uima.util.AnnotatorUtil;
import opennlp.uima.util.UimaUtil;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.Feature;
import org.apache.uima.cas.TypeSystem;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.resource.ResourceAccessException;
import org.apache.uima.resource.ResourceInitializationException;
public class ConcurrentTokenizer extends AbstractTokenizer {
/**
* The OpenNLP tokenizer.
*/
private TokenizerME tokenizer;
private Feature probabilityFeature;
@Override
public synchronized void process(CAS cas) throws AnalysisEngineProcessException {
super.process(cas);
}
/**
* Initializes a new instance.
*
* Note: Use {@link #initialize(UimaContext) } to initialize
* this instance. Not use the constructor.
*/
public ConcurrentTokenizer() {
super("OpenNLP Tokenizer");
// must not be implemented !
}
/**
* Initializes the current instance with the given context.
*
* Note: Do all initialization in this method, do not use the constructor.
*/
public void initialize(UimaContext context) throws ResourceInitializationException {
super.initialize(context);
TokenizerModel model;
try {
TokenizerModelResource modelResource =
(TokenizerModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER);
model = modelResource.getModel();
} catch (ResourceAccessException e) {
throw new ResourceInitializationException(e);
}
tokenizer = new TokenizerME(model);
}
/**
* Initializes the type system.
*/
public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException {
super.typeSystemInit(typeSystem);
probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(context, tokenType,
UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE);
}
@Override
protected Span[] tokenize(CAS cas, AnnotationFS sentence) {
return tokenizer.tokenizePos(sentence.getCoveredText());
}
@Override
protected void postProcessAnnotations(Span[] tokens, AnnotationFS[] tokenAnnotations) {
// if interest
if (probabilityFeature != null) {
double tokenProbabilties[] = tokenizer.getTokenProbabilities();
for (int i = 0; i < tokenAnnotations.length; i++) {
tokenAnnotations[i].setDoubleValue(probabilityFeature, tokenProbabilties[i]);
}
}
}
/**
* Releases allocated resources.
*/
public void destroy() {
// dereference model to allow garbage collection
tokenizer = null;
}
}

View File

@ -1,107 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
/**
* Tokenizer based on the {@link java.io.StreamTokenizer}
* @author Adam Gibson
*
*/
public class DefaultStreamTokenizer implements Tokenizer {
private StreamTokenizer streamTokenizer;
private TokenPreProcess tokenPreProcess;
public DefaultStreamTokenizer(InputStream is) {
Reader r = new BufferedReader(new InputStreamReader(is));
streamTokenizer = new StreamTokenizer(r);
}
@Override
public boolean hasMoreTokens() {
if (streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
try {
streamTokenizer.nextToken();
} catch (IOException e1) {
throw new RuntimeException(e1);
}
}
return streamTokenizer.ttype != StreamTokenizer.TT_EOF && streamTokenizer.ttype != -1;
}
@Override
public int countTokens() {
return getTokens().size();
}
@Override
public String nextToken() {
StringBuilder sb = new StringBuilder();
if (streamTokenizer.ttype == StreamTokenizer.TT_WORD) {
sb.append(streamTokenizer.sval);
} else if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
sb.append(streamTokenizer.nval);
} else if (streamTokenizer.ttype == StreamTokenizer.TT_EOL) {
try {
while (streamTokenizer.ttype == StreamTokenizer.TT_EOL)
streamTokenizer.nextToken();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
else if (hasMoreTokens())
return nextToken();
String ret = sb.toString();
if (tokenPreProcess != null)
ret = tokenPreProcess.preProcess(ret);
return ret;
}
@Override
public List<String> getTokens() {
List<String> tokens = new ArrayList<>();
while (hasMoreTokens()) {
tokens.add(nextToken());
}
return tokens;
}
@Override
public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
this.tokenPreProcess = tokenPreProcessor;
}
}

View File

@ -1,75 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
/**
* Default tokenizer
* @author Adam Gibson
*/
public class DefaultTokenizer implements Tokenizer {
public DefaultTokenizer(String tokens) {
tokenizer = new StringTokenizer(tokens);
}
private StringTokenizer tokenizer;
private TokenPreProcess tokenPreProcess;
@Override
public boolean hasMoreTokens() {
return tokenizer.hasMoreTokens();
}
@Override
public int countTokens() {
return tokenizer.countTokens();
}
@Override
public String nextToken() {
String base = tokenizer.nextToken();
if (tokenPreProcess != null)
base = tokenPreProcess.preProcess(base);
return base;
}
@Override
public List<String> getTokens() {
List<String> tokens = new ArrayList<>();
while (hasMoreTokens()) {
tokens.add(nextToken());
}
return tokens;
}
@Override
public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
this.tokenPreProcess = tokenPreProcessor;
}
}

View File

@ -1,136 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.cas.CAS;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.util.JCasUtil;
import org.cleartk.token.type.Sentence;
import org.cleartk.token.type.Token;
import org.datavec.nlp.annotator.PoStagger;
import org.datavec.nlp.annotator.SentenceAnnotator;
import org.datavec.nlp.annotator.StemmerAnnotator;
import org.datavec.nlp.annotator.TokenizerAnnotator;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
public class PosUimaTokenizer implements Tokenizer {
private static AnalysisEngine engine;
private List<String> tokens;
private Collection<String> allowedPosTags;
private int index;
private static CAS cas;
public PosUimaTokenizer(String tokens, AnalysisEngine engine, Collection<String> allowedPosTags) {
if (engine == null)
PosUimaTokenizer.engine = engine;
this.allowedPosTags = allowedPosTags;
this.tokens = new ArrayList<>();
try {
if (cas == null)
cas = engine.newCAS();
cas.reset();
cas.setDocumentText(tokens);
PosUimaTokenizer.engine.process(cas);
for (Sentence s : JCasUtil.select(cas.getJCas(), Sentence.class)) {
for (Token t : JCasUtil.selectCovered(Token.class, s)) {
//add NONE for each invalid token
if (valid(t))
if (t.getLemma() != null)
this.tokens.add(t.getLemma());
else if (t.getStem() != null)
this.tokens.add(t.getStem());
else
this.tokens.add(t.getCoveredText());
else
this.tokens.add("NONE");
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private boolean valid(Token token) {
String check = token.getCoveredText();
if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>"))
return false;
else if (token.getPos() != null && !this.allowedPosTags.contains(token.getPos()))
return false;
return true;
}
@Override
public boolean hasMoreTokens() {
return index < tokens.size();
}
@Override
public int countTokens() {
return tokens.size();
}
@Override
public String nextToken() {
String ret = tokens.get(index);
index++;
return ret;
}
@Override
public List<String> getTokens() {
List<String> tokens = new ArrayList<String>();
while (hasMoreTokens()) {
tokens.add(nextToken());
}
return tokens;
}
public static AnalysisEngine defaultAnalysisEngine() {
try {
return AnalysisEngineFactory.createEngine(AnalysisEngineFactory.createEngineDescription(
SentenceAnnotator.getDescription(), TokenizerAnnotator.getDescription(),
PoStagger.getDescription("en"), StemmerAnnotator.getDescription("English")));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
// TODO Auto-generated method stub
}
}

View File

@ -1,34 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
public interface TokenPreProcess {
/**
* Pre process a token
* @param token the token to pre process
* @return the preprocessed token
*/
String preProcess(String token);
}

View File

@ -1,61 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
import java.util.List;
public interface Tokenizer {
/**
* An iterator for tracking whether
* more tokens are left in the iterator not
* @return whether there is anymore tokens
* to iterate over
*/
boolean hasMoreTokens();
/**
* The number of tokens in the tokenizer
* @return the number of tokens
*/
int countTokens();
/**
* The next token (word usually) in the string
* @return the next token in the string if any
*/
String nextToken();
/**
* Returns a list of all the tokens
* @return a list of all the tokens
*/
List<String> getTokens();
/**
* Set the token pre process
* @param tokenPreProcessor the token pre processor to set
*/
void setTokenPreProcessor(TokenPreProcess tokenPreProcessor);
}

View File

@ -1,123 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
import org.apache.uima.cas.CAS;
import org.apache.uima.fit.util.JCasUtil;
import org.cleartk.token.type.Token;
import org.datavec.nlp.uima.UimaResource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
/**
* Tokenizer based on the passed in analysis engine
* @author Adam Gibson
*
*/
public class UimaTokenizer implements Tokenizer {
private List<String> tokens;
private int index;
private static Logger log = LoggerFactory.getLogger(UimaTokenizer.class);
private boolean checkForLabel;
private TokenPreProcess tokenPreProcessor;
public UimaTokenizer(String tokens, UimaResource resource, boolean checkForLabel) {
this.checkForLabel = checkForLabel;
this.tokens = new ArrayList<>();
try {
CAS cas = resource.process(tokens);
Collection<Token> tokenList = JCasUtil.select(cas.getJCas(), Token.class);
for (Token t : tokenList) {
if (!checkForLabel || valid(t.getCoveredText()))
if (t.getLemma() != null)
this.tokens.add(t.getLemma());
else if (t.getStem() != null)
this.tokens.add(t.getStem());
else
this.tokens.add(t.getCoveredText());
}
resource.release(cas);
} catch (Exception e) {
log.error("",e);
throw new RuntimeException(e);
}
}
private boolean valid(String check) {
if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>"))
return false;
return true;
}
@Override
public boolean hasMoreTokens() {
return index < tokens.size();
}
@Override
public int countTokens() {
return tokens.size();
}
@Override
public String nextToken() {
String ret = tokens.get(index);
index++;
if (tokenPreProcessor != null) {
ret = tokenPreProcessor.preProcess(ret);
}
return ret;
}
@Override
public List<String> getTokens() {
List<String> tokens = new ArrayList<>();
while (hasMoreTokens()) {
tokens.add(nextToken());
}
return tokens;
}
@Override
public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
this.tokenPreProcessor = tokenPreProcessor;
}
}

View File

@ -1,47 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer.preprocessor;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
/**
* Gets rid of endings:
*
* ed,ing, ly, s, .
* @author Adam Gibson
*/
public class EndingPreProcessor implements TokenPreProcess {
@Override
public String preProcess(String token) {
if (token.endsWith("s") && !token.endsWith("ss"))
token = token.substring(0, token.length() - 1);
if (token.endsWith("."))
token = token.substring(0, token.length() - 1);
if (token.endsWith("ed"))
token = token.substring(0, token.length() - 2);
if (token.endsWith("ing"))
token = token.substring(0, token.length() - 3);
if (token.endsWith("ly"))
token = token.substring(0, token.length() - 2);
return token;
}
}

View File

@ -1,30 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer.preprocessor;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
public class LowerCasePreProcessor implements TokenPreProcess {
@Override
public String preProcess(String token) {
return token.toLowerCase();
}
}

View File

@ -1,60 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizerfactory;
import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer;
import org.datavec.nlp.tokenization.tokenizer.DefaultTokenizer;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import java.io.InputStream;
/**
* Default tokenizer based on string tokenizer or stream tokenizer
* @author Adam Gibson
*/
public class DefaultTokenizerFactory implements TokenizerFactory {
private TokenPreProcess tokenPreProcess;
@Override
public Tokenizer create(String toTokenize) {
DefaultTokenizer t = new DefaultTokenizer(toTokenize);
t.setTokenPreProcessor(tokenPreProcess);
return t;
}
@Override
public Tokenizer create(InputStream toTokenize) {
Tokenizer t = new DefaultStreamTokenizer(toTokenize);
t.setTokenPreProcessor(tokenPreProcess);
return t;
}
@Override
public void setTokenPreProcessor(TokenPreProcess preProcessor) {
this.tokenPreProcess = preProcessor;
}
}

View File

@ -1,85 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizerfactory;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.datavec.nlp.annotator.PoStagger;
import org.datavec.nlp.annotator.SentenceAnnotator;
import org.datavec.nlp.annotator.StemmerAnnotator;
import org.datavec.nlp.annotator.TokenizerAnnotator;
import org.datavec.nlp.tokenization.tokenizer.PosUimaTokenizer;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import java.io.InputStream;
import java.util.Collection;
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngine;
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription;
public class PosUimaTokenizerFactory implements TokenizerFactory {
private AnalysisEngine tokenizer;
private Collection<String> allowedPoSTags;
private TokenPreProcess tokenPreProcess;
public PosUimaTokenizerFactory(Collection<String> allowedPoSTags) {
this(defaultAnalysisEngine(), allowedPoSTags);
}
public PosUimaTokenizerFactory(AnalysisEngine tokenizer, Collection<String> allowedPosTags) {
this.tokenizer = tokenizer;
this.allowedPoSTags = allowedPosTags;
}
public static AnalysisEngine defaultAnalysisEngine() {
try {
return createEngine(createEngineDescription(SentenceAnnotator.getDescription(),
TokenizerAnnotator.getDescription(), PoStagger.getDescription("en"),
StemmerAnnotator.getDescription("English")));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public Tokenizer create(String toTokenize) {
PosUimaTokenizer t = new PosUimaTokenizer(toTokenize, tokenizer, allowedPoSTags);
t.setTokenPreProcessor(tokenPreProcess);
return t;
}
@Override
public Tokenizer create(InputStream toTokenize) {
throw new UnsupportedOperationException();
}
@Override
public void setTokenPreProcessor(TokenPreProcess preProcessor) {
this.tokenPreProcess = preProcessor;
}
}

View File

@ -1,62 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizerfactory;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.io.InputStream;
/**
* Generates a tokenizer for a given string
* @author Adam Gibson
*
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface TokenizerFactory {
/**
* The tokenizer to createComplex
* @param toTokenize the string to createComplex the tokenizer with
* @return the new tokenizer
*/
Tokenizer create(String toTokenize);
/**
* Create a tokenizer based on an input stream
* @param toTokenize
* @return
*/
Tokenizer create(InputStream toTokenize);
/**
* Sets a token pre processor to be used
* with every tokenizer
* @param preProcessor the token pre processor to use
*/
void setTokenPreProcessor(TokenPreProcess preProcessor);
}

View File

@ -1,138 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizerfactory;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.resource.ResourceInitializationException;
import org.datavec.nlp.annotator.SentenceAnnotator;
import org.datavec.nlp.annotator.TokenizerAnnotator;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizer.UimaTokenizer;
import org.datavec.nlp.uima.UimaResource;
import java.io.InputStream;
/**
* Uses a uima {@link AnalysisEngine} to
* tokenize text.
*
*
* @author Adam Gibson
*
*/
public class UimaTokenizerFactory implements TokenizerFactory {
private UimaResource uimaResource;
private boolean checkForLabel;
private static AnalysisEngine defaultAnalysisEngine;
private TokenPreProcess preProcess;
public UimaTokenizerFactory() throws ResourceInitializationException {
this(defaultAnalysisEngine(), true);
}
public UimaTokenizerFactory(UimaResource resource) {
this(resource, true);
}
public UimaTokenizerFactory(AnalysisEngine tokenizer) {
this(tokenizer, true);
}
public UimaTokenizerFactory(UimaResource resource, boolean checkForLabel) {
this.uimaResource = resource;
this.checkForLabel = checkForLabel;
}
public UimaTokenizerFactory(boolean checkForLabel) throws ResourceInitializationException {
this(defaultAnalysisEngine(), checkForLabel);
}
public UimaTokenizerFactory(AnalysisEngine tokenizer, boolean checkForLabel) {
super();
this.checkForLabel = checkForLabel;
try {
this.uimaResource = new UimaResource(tokenizer);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public Tokenizer create(String toTokenize) {
if (toTokenize == null || toTokenize.isEmpty())
throw new IllegalArgumentException("Unable to proceed; on sentence to tokenize");
Tokenizer ret = new UimaTokenizer(toTokenize, uimaResource, checkForLabel);
ret.setTokenPreProcessor(preProcess);
return ret;
}
public UimaResource getUimaResource() {
return uimaResource;
}
/**
* Creates a tokenization,/stemming pipeline
* @return a tokenization/stemming pipeline
*/
public static AnalysisEngine defaultAnalysisEngine() {
try {
if (defaultAnalysisEngine == null)
defaultAnalysisEngine = AnalysisEngineFactory.createEngine(
AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.getDescription(),
TokenizerAnnotator.getDescription()));
return defaultAnalysisEngine;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public Tokenizer create(InputStream toTokenize) {
throw new UnsupportedOperationException();
}
@Override
public void setTokenPreProcessor(TokenPreProcess preProcessor) {
this.preProcess = preProcessor;
}
}

View File

@ -1,69 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
import org.datavec.api.transform.Transform;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.util.List;
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface BagOfWordsTransform extends Transform {
/**
* The output shape of the transform (usually 1 x number of words)
* @return
*/
long[] outputShape();
/**
* The vocab words in the transform.
* This is the words that were accumulated
* when building a vocabulary.
* (This is generally associated with some form of
* mininmum words frequency scanning to build a vocab
* you then map on to a list of vocab words as a list)
* @return the vocab words for the transform
*/
List<String> vocabWords();
/**
* Transform for a list of tokens
* that are objects. This is to allow loose
* typing for tokens that are unique (non string)
* @param tokens the token objects to transform
* @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array)
*/
INDArray transformFromObject(List<List<Object>> tokens);
/**
* Transform for a list of tokens
* that are {@link Writable} (Generally {@link org.datavec.api.writable.Text}
* @param tokens the token objects to transform
* @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array)
*/
INDArray transformFrom(List<List<Writable>> tokens);
}

View File

@ -1,24 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
public class BaseWordMapTransform {
}

View File

@ -1,146 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@Data
@EqualsAndHashCode(callSuper = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonIgnoreProperties({"gazeteer"})
public class GazeteerTransform extends BaseColumnTransform implements BagOfWordsTransform {
private String newColumnName;
private List<String> wordList;
private Set<String> gazeteer;
@JsonCreator
public GazeteerTransform(@JsonProperty("columnName") String columnName,
@JsonProperty("newColumnName")String newColumnName,
@JsonProperty("wordList") List<String> wordList) {
super(columnName);
this.newColumnName = newColumnName;
this.wordList = wordList;
this.gazeteer = new HashSet<>(wordList);
}
@Override
public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
return new NDArrayMetaData(newName,new long[]{wordList.size()});
}
@Override
public Writable map(Writable columnWritable) {
throw new UnsupportedOperationException();
}
@Override
public Object mapSequence(Object sequence) {
List<List<Object>> sequenceInput = (List<List<Object>>) sequence;
INDArray ret = Nd4j.create(DataType.FLOAT, wordList.size());
for(List<Object> list : sequenceInput) {
for(Object token : list) {
String s = token.toString();
if(gazeteer.contains(s)) {
ret.putScalar(wordList.indexOf(s),1);
}
}
}
return ret;
}
@Override
public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
INDArray arr = (INDArray) mapSequence((Object) sequence);
return Collections.singletonList(Collections.<Writable>singletonList(new NDArrayWritable(arr)));
}
@Override
public String toString() {
return newColumnName;
}
@Override
public Object map(Object input) {
return gazeteer.contains(input.toString());
}
@Override
public String outputColumnName() {
return newColumnName;
}
@Override
public String[] outputColumnNames() {
return new String[]{newColumnName};
}
@Override
public String[] columnNames() {
return new String[]{columnName()};
}
@Override
public String columnName() {
return columnName;
}
@Override
public long[] outputShape() {
return new long[]{wordList.size()};
}
@Override
public List<String> vocabWords() {
return wordList;
}
@Override
public INDArray transformFromObject(List<List<Object>> tokens) {
return (INDArray) mapSequence(tokens);
}
@Override
public INDArray transformFrom(List<List<Writable>> tokens) {
return (INDArray) mapSequence((Object) tokens);
}
}

View File

@ -1,150 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.list.NDArrayList;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Collections;
import java.util.List;
public class MultiNlpTransform extends BaseColumnTransform implements BagOfWordsTransform {
private BagOfWordsTransform[] transforms;
private String newColumnName;
private List<String> vocabWords;
/**
*
* @param columnName
* @param transforms
* @param newColumnName
*/
@JsonCreator
public MultiNlpTransform(@JsonProperty("columnName") String columnName,
@JsonProperty("transforms") BagOfWordsTransform[] transforms,
@JsonProperty("newColumnName") String newColumnName) {
super(columnName);
this.transforms = transforms;
this.vocabWords = transforms[0].vocabWords();
if(transforms.length > 1) {
for(int i = 1; i < transforms.length; i++) {
if(!transforms[i].vocabWords().equals(vocabWords)) {
throw new IllegalArgumentException("Vocab words not consistent across transforms!");
}
}
}
this.newColumnName = newColumnName;
}
@Override
public Object mapSequence(Object sequence) {
NDArrayList ndArrayList = new NDArrayList();
for(BagOfWordsTransform bagofWordsTransform : transforms) {
ndArrayList.addAll(new NDArrayList(bagofWordsTransform.transformFromObject((List<List<Object>>) sequence)));
}
return ndArrayList.array();
}
@Override
public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
return Collections.singletonList(Collections.<Writable>singletonList(new NDArrayWritable(transformFrom(sequence))));
}
@Override
public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
return new NDArrayMetaData(newName,outputShape());
}
@Override
public Writable map(Writable columnWritable) {
throw new UnsupportedOperationException("Only able to add for time series");
}
@Override
public String toString() {
return newColumnName;
}
@Override
public Object map(Object input) {
throw new UnsupportedOperationException("Only able to add for time series");
}
@Override
public long[] outputShape() {
long[] ret = new long[transforms[0].outputShape().length];
int validatedRank = transforms[0].outputShape().length;
for(int i = 1; i < transforms.length; i++) {
if(transforms[i].outputShape().length != validatedRank) {
throw new IllegalArgumentException("Inconsistent shape length at transform " + i + " , should have been: " + validatedRank);
}
}
for(int i = 0; i < transforms.length; i++) {
for(int j = 0; j < validatedRank; j++)
ret[j] += transforms[i].outputShape()[j];
}
return ret;
}
@Override
public List<String> vocabWords() {
return vocabWords;
}
@Override
public INDArray transformFromObject(List<List<Object>> tokens) {
NDArrayList ndArrayList = new NDArrayList();
for(BagOfWordsTransform bagofWordsTransform : transforms) {
INDArray arr2 = bagofWordsTransform.transformFromObject(tokens);
arr2 = arr2.reshape(arr2.length());
NDArrayList newList = new NDArrayList(arr2,(int) arr2.length());
ndArrayList.addAll(newList); }
return ndArrayList.array();
}
@Override
public INDArray transformFrom(List<List<Writable>> tokens) {
NDArrayList ndArrayList = new NDArrayList();
for(BagOfWordsTransform bagofWordsTransform : transforms) {
INDArray arr2 = bagofWordsTransform.transformFrom(tokens);
arr2 = arr2.reshape(arr2.length());
NDArrayList newList = new NDArrayList(arr2,(int) arr2.length());
ndArrayList.addAll(newList);
}
return ndArrayList.array();
}
}

View File

@ -1,226 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Counter;
import org.nd4j.common.util.MathUtils;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = true, exclude = {"tokenizerFactory"})
@JsonInclude(JsonInclude.Include.NON_NULL)
public class TokenizerBagOfWordsTermSequenceIndexTransform extends BaseColumnTransform {
private String newColumName;
private Map<String,Integer> wordIndexMap;
private Map<String,Double> weightMap;
private boolean exceptionOnUnknown;
private String tokenizerFactoryClass;
private String preprocessorClass;
private TokenizerFactory tokenizerFactory;
@JsonCreator
public TokenizerBagOfWordsTermSequenceIndexTransform(@JsonProperty("columnName") String columnName,
@JsonProperty("newColumnName") String newColumnName,
@JsonProperty("wordIndexMap") Map<String,Integer> wordIndexMap,
@JsonProperty("idfMap") Map<String,Double> idfMap,
@JsonProperty("exceptionOnUnknown") boolean exceptionOnUnknown,
@JsonProperty("tokenizerFactoryClass") String tokenizerFactoryClass,
@JsonProperty("preprocessorClass") String preprocessorClass) {
super(columnName);
this.newColumName = newColumnName;
this.wordIndexMap = wordIndexMap;
this.exceptionOnUnknown = exceptionOnUnknown;
this.weightMap = idfMap;
this.tokenizerFactoryClass = tokenizerFactoryClass;
this.preprocessorClass = preprocessorClass;
if(this.tokenizerFactoryClass == null) {
this.tokenizerFactoryClass = DefaultTokenizerFactory.class.getName();
}
try {
tokenizerFactory = (TokenizerFactory) Class.forName(this.tokenizerFactoryClass).newInstance();
} catch (Exception e) {
throw new IllegalStateException("Unable to instantiate tokenizer factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?");
}
if(preprocessorClass != null){
try {
TokenPreProcess tpp = (TokenPreProcess) Class.forName(this.preprocessorClass).newInstance();
tokenizerFactory.setTokenPreProcessor(tpp);
} catch (Exception e){
throw new IllegalStateException("Unable to instantiate preprocessor factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?");
}
}
}
@Override
public List<Writable> map(List<Writable> writables) {
Text text = (Text) writables.get(inputSchema.getIndexOfColumn(columnName));
List<Writable> ret = new ArrayList<>(writables);
ret.set(inputSchema.getIndexOfColumn(columnName),new NDArrayWritable(convert(text.toString())));
return ret;
}
@Override
public Object map(Object input) {
return convert(input.toString());
}
@Override
public Object mapSequence(Object sequence) {
return convert(sequence.toString());
}
@Override
public Schema transform(Schema inputSchema) {
Schema.Builder newSchema = new Schema.Builder();
for(int i = 0; i < inputSchema.numColumns(); i++) {
if(inputSchema.getName(i).equals(this.columnName)) {
newSchema.addColumnNDArray(newColumName,new long[]{1,wordIndexMap.size()});
}
else {
newSchema.addColumn(inputSchema.getMetaData(i));
}
}
return newSchema.build();
}
/**
* Convert the given text
* in to an {@link INDArray}
* using the {@link TokenizerFactory}
* specified in the constructor.
* @param text the text to transform
* @return the created {@link INDArray}
* based on the {@link #wordIndexMap} for the column indices
* of the word.
*/
public INDArray convert(String text) {
Tokenizer tokenizer = tokenizerFactory.create(text);
List<String> tokens = tokenizer.getTokens();
INDArray create = Nd4j.create(1,wordIndexMap.size());
Counter<String> tokenizedCounter = new Counter<>();
for(int i = 0; i < tokens.size(); i++) {
tokenizedCounter.incrementCount(tokens.get(i),1.0);
}
for(int i = 0; i < tokens.size(); i++) {
if(wordIndexMap.containsKey(tokens.get(i))) {
int idx = wordIndexMap.get(tokens.get(i));
int count = (int) tokenizedCounter.getCount(tokens.get(i));
double weight = tfidfWord(tokens.get(i),count,tokens.size());
create.putScalar(idx,weight);
}
}
return create;
}
/**
* Calculate the tifdf for a word
* given the word, word count, and document length
* @param word the word to calculate
* @param wordCount the word frequency
* @param documentLength the number of words in the document
* @return the tfidf weight for a given word
*/
public double tfidfWord(String word, long wordCount, long documentLength) {
double tf = tfForWord(wordCount, documentLength);
double idf = idfForWord(word);
return MathUtils.tfidf(tf, idf);
}
/**
* Calculate the weight term frequency for a given
* word normalized by the dcoument length
* @param wordCount the word frequency
* @param documentLength the number of words in the edocument
* @return
*/
private double tfForWord(long wordCount, long documentLength) {
return wordCount;
}
private double idfForWord(String word) {
if(weightMap.containsKey(word))
return weightMap.get(word);
return 0;
}
@Override
public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
return new NDArrayMetaData(outputColumnName(),new long[]{1,wordIndexMap.size()});
}
@Override
public String outputColumnName() {
return newColumName;
}
@Override
public String[] outputColumnNames() {
return new String[]{newColumName};
}
@Override
public String[] columnNames() {
return new String[]{columnName()};
}
@Override
public String columnName() {
return columnName;
}
@Override
public Writable map(Writable columnWritable) {
return new NDArrayWritable(convert(columnWritable.toString()));
}
}

View File

@ -1,107 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.uima;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.util.CasPool;
public class UimaResource {
private AnalysisEngine analysisEngine;
private CasPool casPool;
public UimaResource(AnalysisEngine analysisEngine) throws ResourceInitializationException {
this.analysisEngine = analysisEngine;
this.casPool = new CasPool(Runtime.getRuntime().availableProcessors() * 10, analysisEngine);
}
public UimaResource(AnalysisEngine analysisEngine, CasPool casPool) {
this.analysisEngine = analysisEngine;
this.casPool = casPool;
}
public AnalysisEngine getAnalysisEngine() {
return analysisEngine;
}
public void setAnalysisEngine(AnalysisEngine analysisEngine) {
this.analysisEngine = analysisEngine;
}
public CasPool getCasPool() {
return casPool;
}
public void setCasPool(CasPool casPool) {
this.casPool = casPool;
}
/**
* Use the given analysis engine and process the given text
* You must release the return cas yourself
* @param text the text to rpocess
* @return the processed cas
*/
public CAS process(String text) {
CAS cas = retrieve();
cas.setDocumentText(text);
try {
analysisEngine.process(cas);
} catch (AnalysisEngineProcessException e) {
if (text != null && !text.isEmpty())
return process(text);
throw new RuntimeException(e);
}
return cas;
}
public CAS retrieve() {
CAS ret = casPool.getCas();
try {
return ret == null ? analysisEngine.newCAS() : ret;
} catch (ResourceInitializationException e) {
throw new RuntimeException(e);
}
}
public void release(CAS cas) {
casPool.releaseCas(cas);
}
}

Some files were not shown because too many files have changed in this diff Show More