Merge branch 'master' into sa_tvm
commit
ad12d2148d
|
@ -0,0 +1,11 @@
|
|||
name: Download dl4j test resources
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Initial install
|
||||
shell: bash
|
||||
run: |
|
||||
wget https://github.com/KonduitAI/dl4j-test-resources/archive/master.zip && unzip master.zip
|
||||
cd dl4j-test-resources-master
|
||||
mvn clean install -DskipTests
|
||||
echo "Extracted test resources"
|
|
@ -0,0 +1,12 @@
|
|||
name: Download dl4j test resources
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Initial install
|
||||
shell: cmd
|
||||
run: |
|
||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||
wget https://github.com/KonduitAI/dl4j-test-resources/archive/master.zip && unzip master.zip
|
||||
cd dl4j-test-resources-master
|
||||
mvn clean install -DskipTests
|
||||
echo "Extracted test resources"
|
|
@ -0,0 +1,12 @@
|
|||
name: Download dl4j test resources
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Initial install
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt install git gcc-8-aarch64-linux-gnu g++-8-aarch64-linux-gnu libc6-armel-cross libc6-dev-armel-cross binutils-arm-linux-gnueabi libncurses5-dev build-essential bison flex libssl-dev bc \
|
||||
gcc-arm-linux-gnueabihf g++-arm-linux-gnueabihf crossbuild-essential-arm64
|
||||
mkdir -p /opt/raspberrypi && \
|
||||
cd /opt/raspberrypi && \
|
||||
git clone git://github.com/raspberrypi/tools.git
|
|
@ -0,0 +1,16 @@
|
|||
name: Install protobuf linux
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Install protobuf linux
|
||||
shell: bash
|
||||
run: |
|
||||
curl -fsSL https://github.com/google/protobuf/releases/download/v3.5.1/protobuf-cpp-3.5.1.tar.gz \
|
||||
| tar xz && \
|
||||
cd protobuf-3.5.1 && \
|
||||
./configure --prefix=/opt/protobuf && \
|
||||
make -j2 && \
|
||||
make install && \
|
||||
cd .. && \
|
||||
rm -rf protobuf-3.5.1
|
||||
echo "/opt/protobuf/bin" >> $GITHUB_PATH
|
|
@ -0,0 +1,10 @@
|
|||
name: Setup for msys2
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Initial install
|
||||
shell: cmd
|
||||
run: |
|
||||
C:\msys64\usr\bin\bash -lc "pacman -S --needed --noconfirm base-devel git tar pkg-config unzip p7zip zip autoconf autoconf-archive automake make patch gnupg"
|
||||
C:\msys64\usr\bin\bash -lc "pacman -S --needed --noconfirm mingw-w64-x86_64-nasm mingw-w64-x86_64-toolchain mingw-w64-x86_64-libtool mingw-w64-x86_64-gcc mingw-w64-i686-gcc mingw-w64-x86_64-gcc-fortran mingw-w64-i686-gcc-fortran mingw-w64-x86_64-libwinpthread-git mingw-w64-i686-libwinpthread-git mingw-w64-x86_64-SDL mingw-w64-i686-SDL mingw-w64-x86_64-ragel"
|
||||
echo "C:\msys64\usr\bin" >> $GITHUB_PATH
|
|
@ -0,0 +1,9 @@
|
|||
name: Publish to github packages
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Publish to GitHub Packages
|
||||
run: mvn -Pgithub --batch-mode deploy
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
@ -0,0 +1,48 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
jobs:
|
||||
android-x86_64:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- uses: AutoModality/action-clean@v1
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: nttld/setup-ndk@v1
|
||||
id: setup-ndk
|
||||
with:
|
||||
ndk-version: r18b
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/install-protobuf-linux
|
||||
- name: Set up Java for publishing to GitHub Packages
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 1.8
|
||||
server-id: sonatype-nexus-snapshots
|
||||
server-username: MAVEN_USERNAME
|
||||
server-password: MAVEN_PASSWORD
|
||||
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
|
||||
gpg-passphrase: MAVEN_GPG_PASSPHRASE
|
||||
- name: Build on linux-x86_64
|
||||
env:
|
||||
ANDROID_NDK: ${{ steps.setup-ndk.outputs.ndk-path }}
|
||||
LIBND4J_HOME: "${GITHUB_WORKSPACE}/libnd4j"
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
|
||||
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
|
||||
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
|
||||
run: |
|
||||
echo "Verifying programs on path. Path is $PATH"
|
||||
echo "Path post update is $PATH. Maven is at `which mvn` cmake is at `which cmake` protoc is at `which protoc`"
|
||||
mvn --version
|
||||
cmake --version
|
||||
protoc --version
|
||||
clang --version
|
||||
mvn -X -Dorg.bytedeco.javacpp.logger.debug=true -Possrh -pl ":nd4j-native,:libnd4j" --also-make \
|
||||
-Djavacpp.platform=android-x86_64 \
|
||||
-Dlibnd4j.platform=android-x86_64 -Dlibnd4j.chip=cpu \
|
||||
--batch-mode clean deploy -DskipTests
|
||||
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
jobs:
|
||||
#Note: no -pl here because we publish everything from this branch and use this as the basis for all uploads.
|
||||
android-arm32:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- uses: AutoModality/action-clean@v1
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/install-protobuf-linux
|
||||
- name: Set up Java for publishing to GitHub Packages
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 1.8
|
||||
server-id: sonatype-nexus-snapshots
|
||||
server-username: MAVEN_USERNAME
|
||||
server-password: MAVEN_PASSWORD
|
||||
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
|
||||
gpg-passphrase: MAVEN_GPG_PASSPHRASE
|
||||
- name: Build on android-arm32
|
||||
shell: bash
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
DEPLOY: 1
|
||||
BUILD_USING_MAVEN: 1
|
||||
TARGET_OS: android
|
||||
CURRENT_TARGET: arm32
|
||||
PUBLISH_TO: ossrh
|
||||
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
|
||||
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
|
||||
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
|
||||
run: |
|
||||
mvn --version
|
||||
cmake --version
|
||||
protoc --version
|
||||
${GITHUB_WORKSPACE}/libnd4j/pi_build.sh
|
||||
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
jobs:
|
||||
#Note: no -pl here because we publish everything from this branch and use this as the basis for all uploads.
|
||||
android-arm64:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- uses: AutoModality/action-clean@v1
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/install-protobuf-linux
|
||||
- name: Set up Java for publishing to GitHub Packages
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 1.8
|
||||
server-id: sonatype-nexus-snapshots
|
||||
server-username: MAVEN_USERNAME
|
||||
server-password: MAVEN_PASSWORD
|
||||
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
|
||||
gpg-passphrase: MAVEN_GPG_PASSPHRASE
|
||||
|
||||
- name: Build on android-arm64
|
||||
shell: bash
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
DEPLOY: 1
|
||||
BUILD_USING_MAVEN: 1
|
||||
TARGET_OS: android
|
||||
CURRENT_TARGET: arm64
|
||||
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
|
||||
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
|
||||
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
|
||||
run: |
|
||||
mvn --version
|
||||
cmake --version
|
||||
protoc --version
|
||||
${GITHUB_WORKSPACE}/libnd4j/pi_build.sh
|
||||
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
jobs:
|
||||
#Note: no -pl here because we publish everything from this branch and use this as the basis for all uploads.
|
||||
linux-arm32:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- uses: AutoModality/action-clean@v1
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/install-protobuf-linux
|
||||
- name: Set up Java for publishing to GitHub Packages
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 1.8
|
||||
server-id: sonatype-nexus-snapshots
|
||||
server-username: MAVEN_USERNAME
|
||||
server-password: MAVEN_PASSWORD
|
||||
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
|
||||
gpg-passphrase: MAVEN_GPG_PASSPHRASE
|
||||
- name: Build on linux-arm32
|
||||
shell: bash
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
DEPLOY: 1
|
||||
BUILD_USING_MAVEN: 1
|
||||
TARGET_OS: linux
|
||||
CURRENT_TARGET: arm32
|
||||
PUBLISH_TO: ossrh
|
||||
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
|
||||
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
|
||||
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
|
||||
run: |
|
||||
mvn --version
|
||||
cmake --version
|
||||
protoc --version
|
||||
${GITHUB_WORKSPACE}/libnd4j/pi_build.sh
|
||||
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
jobs:
|
||||
#Note: no -pl here because we publish everything from this branch and use this as the basis for all uploads.
|
||||
linux-arm64:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- uses: AutoModality/action-clean@v1
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/install-protobuf-linux
|
||||
- name: Set up Java for publishing to GitHub Packages
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 1.8
|
||||
server-id: sonatype-nexus-snapshots
|
||||
server-username: MAVEN_USERNAME
|
||||
server-password: MAVEN_PASSWORD
|
||||
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
|
||||
gpg-passphrase: MAVEN_GPG_PASSPHRASE
|
||||
- name: Build on linux-arm64
|
||||
shell: bash
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
DEPLOY: 1
|
||||
BUILD_USING_MAVEN: 1
|
||||
TARGET_OS: linux
|
||||
CURRENT_TARGET: arm64
|
||||
PUBLISH_TO: ossrh
|
||||
run: |
|
||||
mvn --version
|
||||
cmake --version
|
||||
protoc --version
|
||||
${GITHUB_WORKSPACE}/libnd4j/pi_build.sh
|
||||
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
jobs:
|
||||
|
||||
linux-x86_64-cuda_11-0:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- name: Maximize build space
|
||||
uses: easimon/maximize-build-space@master
|
||||
with:
|
||||
root-reserve-mb: 512
|
||||
swap-size-mb: 8192
|
||||
remove-dotnet: 'true'
|
||||
remove-haskell: 'true'
|
||||
- uses: actions/checkout@v2
|
||||
- uses: konduitai/cuda-install/.github/actions/install-cuda-ubuntu@master
|
||||
env:
|
||||
cuda: 11.0.167
|
||||
GCC: 9
|
||||
- uses: ./.github/actions/install-protobuf-linux
|
||||
- name: Set up Java for publishing to GitHub Packages
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 1.8
|
||||
server-id: sonatype-nexus-snapshots
|
||||
server-username: MAVEN_USERNAME
|
||||
server-password: MAVEN_PASSWORD
|
||||
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
|
||||
gpg-passphrase: MAVEN_GPG_PASSPHRASE
|
||||
|
||||
- name: Build cuda
|
||||
shell: bash
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PUBLISH_TO: ossrh
|
||||
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
|
||||
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
|
||||
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
|
||||
run: |
|
||||
export PATH="/usr/local/cuda-11.0/bin:$PATH"
|
||||
mvn --version
|
||||
cmake --version
|
||||
protoc --version
|
||||
nvcc --version
|
||||
sudo apt-get autoremove
|
||||
sudo apt-get clean
|
||||
mvn -Possrh -Djavacpp.platform=linux-x86_64 -Dlibnd4j.compute="5.0 5.2 5.3 6.0 6.2 8.0" -Dlibnd4j.chip=cuda -pl ":nd4j-cuda-11.0,:deeplearning4j-cuda-11.0,:libnd4j" --also-make -Pcuda clean --batch-mode deploy -DskipTests
|
||||
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
jobs:
|
||||
linux-x86_64-cuda-11-2:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- name: Maximize build space
|
||||
uses: easimon/maximize-build-space@master
|
||||
with:
|
||||
root-reserve-mb: 512
|
||||
swap-size-mb: 8192
|
||||
remove-dotnet: 'true'
|
||||
remove-haskell: 'true'
|
||||
- uses: actions/checkout@v2
|
||||
- uses: konduitai/cuda-install/.github/actions/install-cuda-ubuntu@master
|
||||
env:
|
||||
cuda: 11.2.1_461
|
||||
GCC: 9
|
||||
- uses: ./.github/actions/install-protobuf-linux
|
||||
- name: Set up Java for publishing to GitHub Packages
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 1.8
|
||||
server-id: sonatype-nexus-snapshots
|
||||
server-username: MAVEN_USERNAME
|
||||
server-password: MAVEN_PASSWORD
|
||||
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
|
||||
gpg-passphrase: MAVEN_GPG_PASSPHRASE
|
||||
- name: Run cuda compilation on linux-x86_64
|
||||
shell: bash
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PUBLISH_TO: ossrh
|
||||
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
|
||||
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
|
||||
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
|
||||
run: |
|
||||
export PATH="/usr/local/cuda-11.2/bin:$PATH"
|
||||
nvcc --version
|
||||
mvn --version
|
||||
cmake --version
|
||||
protoc --version
|
||||
sudo apt-get autoremove
|
||||
sudo apt-get clean
|
||||
bash ./change-cuda-versions.sh 11.2
|
||||
mvn -Possrh -Djavacpp.platform=linux-x86_64 -Dlibnd4j.compute="5.0 5.2 5.3 6.0 6.2 8.0" -pl ":nd4j-cuda-11.2,:deeplearning4j-cuda-11.2,:libnd4j" --also-make -Dlibnd4j.chip=cuda --batch-mode deploy -DskipTests
|
|
@ -0,0 +1,41 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
jobs:
|
||||
#Note: no -pl here because we publish everything from this branch and use this as the basis for all uploads.
|
||||
linux-x86_64:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/install-protobuf-linux
|
||||
- name: Set up Java for publishing to GitHub Packages
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 1.8
|
||||
server-id: sonatype-nexus-snapshots
|
||||
server-username: MAVEN_USERNAME
|
||||
server-password: MAVEN_PASSWORD
|
||||
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
|
||||
gpg-passphrase: MAVEN_GPG_PASSPHRASE
|
||||
- name: Build on linux-x86_64
|
||||
shell: bash
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PUBLISH_TO: ossrh
|
||||
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
|
||||
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
|
||||
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
|
||||
run: |
|
||||
mvn --version
|
||||
cmake --version
|
||||
protoc --version
|
||||
sudo apt-get autoremove
|
||||
sudo apt-get clean
|
||||
mvn -X -Possrh -Djavacpp.platform=linux-x86_64 -Dlibnd4j.chip=cpu -Pcpu --batch-mode deploy -DskipTests
|
||||
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
jobs:
|
||||
mac-x86_64:
|
||||
runs-on: macos-10.15
|
||||
steps:
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Java for publishing to GitHub Packages
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 1.8
|
||||
server-id: sonatype-nexus-snapshots
|
||||
server-username: MAVEN_USERNAME
|
||||
server-password: MAVEN_PASSWORD
|
||||
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
|
||||
gpg-passphrase: MAVEN_GPG_PASSPHRASE
|
||||
- name: Build and install
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PUBLISH_TO: ossrh
|
||||
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
|
||||
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
|
||||
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
|
||||
run: |
|
||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||
mvn -Possrh -Djavacpp.platform=macosx-x86_64 -Djavacpp.platform=macosx-x86_64 -pl ":nd4j-native,:libnd4j" --also-make -Dlibnd4j.platform=macosx-x86_64 -Dlibnd4j.chip=cpu clean --batch-mode deploy -DskipTests
|
||||
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
jobs:
|
||||
windows-x86_64-cuda-11-0:
|
||||
runs-on: windows-2019
|
||||
steps:
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v2
|
||||
- name: Use existing msys2 to setup environment
|
||||
uses: ./.github/actions/msys2-base-setup
|
||||
- uses: konduitai/cuda-install/.github/actions/install-cuda-windows@master
|
||||
env:
|
||||
cuda: 11.0.167
|
||||
- name: Set up Java for publishing to GitHub Packages
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 1.8
|
||||
server-id: sonatype-nexus-snapshots
|
||||
server-username: MAVEN_USERNAME
|
||||
server-password: MAVEN_PASSWORD
|
||||
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
|
||||
gpg-passphrase: MAVEN_GPG_PASSPHRASE
|
||||
- name: Run windows build
|
||||
shell: cmd
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PUBLISH_TO: ossrh
|
||||
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
|
||||
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
|
||||
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
|
||||
run: |
|
||||
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
|
||||
set MSYSTEM=MINGW64
|
||||
set "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0"
|
||||
which cmake
|
||||
dir "%CUDA_PATH%"
|
||||
dir "%CUDA_PATH%\lib"
|
||||
set "PATH=C:\msys64\usr\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\lib\x64;%PATH%"
|
||||
echo "Running cuda build"
|
||||
mvn -Possrh -Djavacpp.platform=windows-x86_64 -Dlibnd4j.compute="5.0 5.2 5.3 6.0 6.2 8.0" -Djavacpp.platform=windows-x86_64 -pl ":nd4j-cuda-11.0,:deeplearning4j-cuda-11.0,:libnd4j" --also-make -Dlibnd4j.platform=windows-x86_64 -Pcuda -Dlibnd4j.chip=cuda -Pcuda clean --batch-mode deploy -DskipTests
|
||||
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
jobs:
|
||||
windows-x86_64-cuda-11-2:
|
||||
runs-on: windows-2019
|
||||
steps:
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v2
|
||||
- name: Use existing msys2 to setup environment
|
||||
uses: ./.github/actions/msys2-base-setup
|
||||
- uses: konduitai/cuda-install/.github/actions/install-cuda-windows@master
|
||||
env:
|
||||
cuda: 11.2.1
|
||||
- name: Set up Java for publishing to GitHub Packages
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 1.8
|
||||
server-id: sonatype-nexus-snapshots
|
||||
server-username: MAVEN_USERNAME
|
||||
server-password: MAVEN_PASSWORD
|
||||
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
|
||||
gpg-passphrase: MAVEN_GPG_PASSPHRASE
|
||||
- name: Run cuda build
|
||||
shell: cmd
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PUBLISH_TO: ossrh
|
||||
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
|
||||
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
|
||||
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
|
||||
run: |
|
||||
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
|
||||
set MSYSTEM=MINGW64
|
||||
set "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2"
|
||||
dir "%CUDA_PATH%"
|
||||
dir "%CUDA_PATH%\lib"
|
||||
which cmake
|
||||
set "PATH=C:\msys64\usr\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\lib\x64;%PATH%"
|
||||
echo "Running cuda build"
|
||||
bash ./change-cuda-versions.sh 11.2
|
||||
sudo apt-get autoremove
|
||||
sudo apt-get clean
|
||||
mvn -Possrh -Djavacpp.platform=linux-x86_64 -Dlibnd4j.compute="5.0 5.2 5.3 6.0 6.2 8.0" -Djavacpp.platform=windows-x86_64 -pl ":nd4j-cuda-11.2,:libnd4j,:deeplearning4j-cuda-11.2" --also-make -Dlibnd4j.platform=windows-x86_64 -Pcuda -Dlibnd4j.chip=cuda -Pcuda clean --batch-mode deploy -DskipTests
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
jobs:
|
||||
|
||||
windows-x86_64:
|
||||
runs-on: windows-2019
|
||||
steps:
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/msys2-base-setup
|
||||
- name: Set up Java for publishing to GitHub Packages
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 1.8
|
||||
server-id: sonatype-nexus-snapshots
|
||||
server-username: MAVEN_USERNAME
|
||||
server-password: MAVEN_PASSWORD
|
||||
gpg-private-key: ${{ secrets.SONATYPE_GPG_KEY }}
|
||||
gpg-passphrase: MAVEN_GPG_PASSPHRASE
|
||||
- name: Run windows cpu build
|
||||
shell: cmd
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PUBLISH_TO: ossrh
|
||||
MAVEN_USERNAME: ${{ secrets.SONATYPE_USER_1 }}
|
||||
MAVEN_PASSWORD: ${{ secrets.SONATYPE_USER1_PASS }}
|
||||
MAVEN_GPG_PASSPHRASE: ${{ secrets.PACKAGES_GPG_PASS }}
|
||||
run: |
|
||||
set MSYSTEM=MINGW64
|
||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||
mvn -Possrh -Djavacpp.platform=windows-x86_64 -pl ":nd4j-native,:libnd4j" --also-make -Dlibnd4j.platform=windows-x86_64 -Dlibnd4j.chip=cpu deploy -DskipTests
|
|
@ -0,0 +1,52 @@
|
|||
on:
|
||||
push:
|
||||
jobs:
|
||||
linux-x86_64:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/install-protobuf-linux
|
||||
- uses: ./.github/actions/download-dl4j-test-resources-linux
|
||||
- name: Run tests on linux-x86_64
|
||||
shell: bash
|
||||
run: |
|
||||
mvn --version
|
||||
cmake --version
|
||||
protoc --version
|
||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
windows-x86_64:
|
||||
runs-on: windows-2019
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/msys2-base-setup
|
||||
- uses: ./.github/actions/download-dl4j-test-resources-windows
|
||||
- name: Run tests
|
||||
shell: cmd
|
||||
run: |
|
||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
|
||||
|
||||
mac-x86_64:
|
||||
runs-on: macos-10.15
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/download-dl4j-test-resources-linux
|
||||
- name: Install and run tests
|
||||
shell: bash
|
||||
env:
|
||||
VERBOSE: 1
|
||||
run: |
|
||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
on:
|
||||
push:
|
||||
jobs:
|
||||
linux-x86_64:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- name: Cancel Previous Runs
|
||||
uses: styfle/cancel-workflow-action@0.8.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/install-protobuf-linux
|
||||
- uses: ./.github/actions/download-dl4j-test-resources-linux
|
||||
- name: Run tests on linux-x86_64
|
||||
shell: bash
|
||||
run: |
|
||||
mvn --version
|
||||
cmake --version
|
||||
protoc --version
|
||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
windows-x86_64:
|
||||
runs-on: windows-2019
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/msys2-base-setup
|
||||
- uses: ./.github/actions/download-dl4j-test-resources-windows
|
||||
- name: Run tests
|
||||
shell: cmd
|
||||
run: |
|
||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
|
||||
|
||||
mac-x86_64:
|
||||
runs-on: macos-10.15
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/download-dl4j-test-resources-linux
|
||||
- name: Install and run tests
|
||||
shell: bash
|
||||
env:
|
||||
VERBOSE: 1
|
||||
run: |
|
||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
|
|
@ -79,3 +79,5 @@ libnd4j/cmake*
|
|||
|
||||
#vim
|
||||
*.swp
|
||||
|
||||
*.dll
|
|
@ -83,4 +83,8 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return Long.MAX_VALUE;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.nio.Buffer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -60,9 +61,10 @@ public class WritableTest extends BaseND4JTest {
|
|||
public void testBytesWritableIndexing() {
|
||||
byte[] doubleWrite = new byte[16];
|
||||
ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite);
|
||||
Buffer buffer = (Buffer) wrapped;
|
||||
wrapped.putDouble(1.0);
|
||||
wrapped.putDouble(2.0);
|
||||
wrapped.rewind();
|
||||
buffer.rewind();
|
||||
BytesWritable byteWritable = new BytesWritable(doubleWrite);
|
||||
assertEquals(2,byteWritable.getDouble(1),1e-1);
|
||||
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2});
|
||||
|
|
|
@ -1,77 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--
|
||||
~ /* ******************************************************************************
|
||||
~ *
|
||||
~ *
|
||||
~ * This program and the accompanying materials are made available under the
|
||||
~ * terms of the Apache License, Version 2.0 which is available at
|
||||
~ * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~ *
|
||||
~ * See the NOTICE file distributed with this work for additional
|
||||
~ * information regarding copyright ownership.
|
||||
~ * Unless required by applicable law or agreed to in writing, software
|
||||
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ * License for the specific language governing permissions and limitations
|
||||
~ * under the License.
|
||||
~ *
|
||||
~ * SPDX-License-Identifier: Apache-2.0
|
||||
~ ******************************************************************************/
|
||||
-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-data</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<artifactId>datavec-data-audio</artifactId>
|
||||
|
||||
<name>datavec-data-audio</name>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-api</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>javacpp</artifactId>
|
||||
<version>${javacpp.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>javacv</artifactId>
|
||||
<version>${javacv.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.github.wendykierp</groupId>
|
||||
<artifactId>JTransforms</artifactId>
|
||||
<version>${jtransforms.version}</version>
|
||||
<classifier>with-dependencies</classifier>
|
||||
</dependency>
|
||||
<!-- Do not depend on FFmpeg by default due to licensing concerns. -->
|
||||
<!--
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>ffmpeg-platform</artifactId>
|
||||
<version>${ffmpeg.version}-${javacpp-presets.version}</version>
|
||||
</dependency>
|
||||
-->
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -1,329 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio;
|
||||
|
||||
|
||||
import org.datavec.audio.extension.NormalizedSampleAmplitudes;
|
||||
import org.datavec.audio.extension.Spectrogram;
|
||||
import org.datavec.audio.fingerprint.FingerprintManager;
|
||||
import org.datavec.audio.fingerprint.FingerprintSimilarity;
|
||||
import org.datavec.audio.fingerprint.FingerprintSimilarityComputer;
|
||||
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.Serializable;
|
||||
|
||||
public class Wave implements Serializable {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
private WaveHeader waveHeader;
|
||||
private byte[] data; // little endian
|
||||
private byte[] fingerprint;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
*/
|
||||
public Wave() {
|
||||
this.waveHeader = new WaveHeader();
|
||||
this.data = new byte[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param filename
|
||||
* Wave file
|
||||
*/
|
||||
public Wave(String filename) {
|
||||
try {
|
||||
InputStream inputStream = new FileInputStream(filename);
|
||||
initWaveWithInputStream(inputStream);
|
||||
inputStream.close();
|
||||
} catch (IOException e) {
|
||||
System.out.println(e.toString());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param inputStream
|
||||
* Wave file input stream
|
||||
*/
|
||||
public Wave(InputStream inputStream) {
|
||||
initWaveWithInputStream(inputStream);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param waveHeader
|
||||
* @param data
|
||||
*/
|
||||
public Wave(WaveHeader waveHeader, byte[] data) {
|
||||
this.waveHeader = waveHeader;
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
private void initWaveWithInputStream(InputStream inputStream) {
|
||||
// reads the first 44 bytes for header
|
||||
waveHeader = new WaveHeader(inputStream);
|
||||
|
||||
if (waveHeader.isValid()) {
|
||||
// load data
|
||||
try {
|
||||
data = new byte[inputStream.available()];
|
||||
inputStream.read(data);
|
||||
} catch (IOException e) {
|
||||
System.err.println(e.toString());
|
||||
}
|
||||
// end load data
|
||||
} else {
|
||||
System.err.println("Invalid Wave Header");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Trim the wave data
|
||||
*
|
||||
* @param leftTrimNumberOfSample
|
||||
* Number of sample trimmed from beginning
|
||||
* @param rightTrimNumberOfSample
|
||||
* Number of sample trimmed from ending
|
||||
*/
|
||||
public void trim(int leftTrimNumberOfSample, int rightTrimNumberOfSample) {
|
||||
|
||||
long chunkSize = waveHeader.getChunkSize();
|
||||
long subChunk2Size = waveHeader.getSubChunk2Size();
|
||||
|
||||
long totalTrimmed = leftTrimNumberOfSample + rightTrimNumberOfSample;
|
||||
|
||||
if (totalTrimmed > subChunk2Size) {
|
||||
leftTrimNumberOfSample = (int) subChunk2Size;
|
||||
}
|
||||
|
||||
// update wav info
|
||||
chunkSize -= totalTrimmed;
|
||||
subChunk2Size -= totalTrimmed;
|
||||
|
||||
if (chunkSize >= 0 && subChunk2Size >= 0) {
|
||||
waveHeader.setChunkSize(chunkSize);
|
||||
waveHeader.setSubChunk2Size(subChunk2Size);
|
||||
|
||||
byte[] trimmedData = new byte[(int) subChunk2Size];
|
||||
System.arraycopy(data, (int) leftTrimNumberOfSample, trimmedData, 0, (int) subChunk2Size);
|
||||
data = trimmedData;
|
||||
} else {
|
||||
System.err.println("Trim error: Negative length");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Trim the wave data from beginning
|
||||
*
|
||||
* @param numberOfSample
|
||||
* numberOfSample trimmed from beginning
|
||||
*/
|
||||
public void leftTrim(int numberOfSample) {
|
||||
trim(numberOfSample, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Trim the wave data from ending
|
||||
*
|
||||
* @param numberOfSample
|
||||
* numberOfSample trimmed from ending
|
||||
*/
|
||||
public void rightTrim(int numberOfSample) {
|
||||
trim(0, numberOfSample);
|
||||
}
|
||||
|
||||
/**
|
||||
* Trim the wave data
|
||||
*
|
||||
* @param leftTrimSecond
|
||||
* Seconds trimmed from beginning
|
||||
* @param rightTrimSecond
|
||||
* Seconds trimmed from ending
|
||||
*/
|
||||
public void trim(double leftTrimSecond, double rightTrimSecond) {
|
||||
|
||||
int sampleRate = waveHeader.getSampleRate();
|
||||
int bitsPerSample = waveHeader.getBitsPerSample();
|
||||
int channels = waveHeader.getChannels();
|
||||
|
||||
int leftTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * leftTrimSecond);
|
||||
int rightTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * rightTrimSecond);
|
||||
|
||||
trim(leftTrimNumberOfSample, rightTrimNumberOfSample);
|
||||
}
|
||||
|
||||
/**
|
||||
* Trim the wave data from beginning
|
||||
*
|
||||
* @param second
|
||||
* Seconds trimmed from beginning
|
||||
*/
|
||||
public void leftTrim(double second) {
|
||||
trim(second, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Trim the wave data from ending
|
||||
*
|
||||
* @param second
|
||||
* Seconds trimmed from ending
|
||||
*/
|
||||
public void rightTrim(double second) {
|
||||
trim(0, second);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the wave header
|
||||
*
|
||||
* @return waveHeader
|
||||
*/
|
||||
public WaveHeader getWaveHeader() {
|
||||
return waveHeader;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the wave spectrogram
|
||||
*
|
||||
* @return spectrogram
|
||||
*/
|
||||
public Spectrogram getSpectrogram() {
|
||||
return new Spectrogram(this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the wave spectrogram
|
||||
*
|
||||
* @param fftSampleSize number of sample in fft, the value needed to be a number to power of 2
|
||||
* @param overlapFactor 1/overlapFactor overlapping, e.g. 1/4=25% overlapping, 0 for no overlapping
|
||||
*
|
||||
* @return spectrogram
|
||||
*/
|
||||
public Spectrogram getSpectrogram(int fftSampleSize, int overlapFactor) {
|
||||
return new Spectrogram(this, fftSampleSize, overlapFactor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the wave data in bytes
|
||||
*
|
||||
* @return wave data
|
||||
*/
|
||||
public byte[] getBytes() {
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Data byte size of the wave excluding header size
|
||||
*
|
||||
* @return byte size of the wave
|
||||
*/
|
||||
public int size() {
|
||||
return data.length;
|
||||
}
|
||||
|
||||
/**
|
||||
* Length of the wave in second
|
||||
*
|
||||
* @return length in second
|
||||
*/
|
||||
public float length() {
|
||||
return (float) waveHeader.getSubChunk2Size() / waveHeader.getByteRate();
|
||||
}
|
||||
|
||||
/**
|
||||
* Timestamp of the wave length
|
||||
*
|
||||
* @return timestamp
|
||||
*/
|
||||
public String timestamp() {
|
||||
float totalSeconds = this.length();
|
||||
float second = totalSeconds % 60;
|
||||
int minute = (int) totalSeconds / 60 % 60;
|
||||
int hour = (int) (totalSeconds / 3600);
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
if (hour > 0) {
|
||||
sb.append(hour + ":");
|
||||
}
|
||||
if (minute > 0) {
|
||||
sb.append(minute + ":");
|
||||
}
|
||||
sb.append(second);
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the amplitudes of the wave samples (depends on the header)
|
||||
*
|
||||
* @return amplitudes array (signed 16-bit)
|
||||
*/
|
||||
public short[] getSampleAmplitudes() {
|
||||
int bytePerSample = waveHeader.getBitsPerSample() / 8;
|
||||
int numSamples = data.length / bytePerSample;
|
||||
short[] amplitudes = new short[numSamples];
|
||||
|
||||
int pointer = 0;
|
||||
for (int i = 0; i < numSamples; i++) {
|
||||
short amplitude = 0;
|
||||
for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) {
|
||||
// little endian
|
||||
amplitude |= (short) ((data[pointer++] & 0xFF) << (byteNumber * 8));
|
||||
}
|
||||
amplitudes[i] = amplitude;
|
||||
}
|
||||
|
||||
return amplitudes;
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
StringBuilder sb = new StringBuilder(waveHeader.toString());
|
||||
sb.append("\n");
|
||||
sb.append("length: " + timestamp());
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public double[] getNormalizedAmplitudes() {
|
||||
NormalizedSampleAmplitudes amplitudes = new NormalizedSampleAmplitudes(this);
|
||||
return amplitudes.getNormalizedAmplitudes();
|
||||
}
|
||||
|
||||
public byte[] getFingerprint() {
|
||||
if (fingerprint == null) {
|
||||
FingerprintManager fingerprintManager = new FingerprintManager();
|
||||
fingerprint = fingerprintManager.extractFingerprint(this);
|
||||
}
|
||||
return fingerprint;
|
||||
}
|
||||
|
||||
public FingerprintSimilarity getFingerprintSimilarity(Wave wave) {
|
||||
FingerprintSimilarityComputer fingerprintSimilarityComputer =
|
||||
new FingerprintSimilarityComputer(this.getFingerprint(), wave.getFingerprint());
|
||||
return fingerprintSimilarityComputer.getFingerprintsSimilarity();
|
||||
}
|
||||
}
|
|
@ -1,98 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
|
||||
@Slf4j
|
||||
public class WaveFileManager {
|
||||
|
||||
private Wave wave;
|
||||
|
||||
public WaveFileManager() {
|
||||
wave = new Wave();
|
||||
}
|
||||
|
||||
public WaveFileManager(Wave wave) {
|
||||
setWave(wave);
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the wave file
|
||||
*
|
||||
* @param filename
|
||||
* filename to be saved
|
||||
*
|
||||
* @see Wave file saved
|
||||
*/
|
||||
public void saveWaveAsFile(String filename) {
|
||||
|
||||
WaveHeader waveHeader = wave.getWaveHeader();
|
||||
|
||||
int byteRate = waveHeader.getByteRate();
|
||||
int audioFormat = waveHeader.getAudioFormat();
|
||||
int sampleRate = waveHeader.getSampleRate();
|
||||
int bitsPerSample = waveHeader.getBitsPerSample();
|
||||
int channels = waveHeader.getChannels();
|
||||
long chunkSize = waveHeader.getChunkSize();
|
||||
long subChunk1Size = waveHeader.getSubChunk1Size();
|
||||
long subChunk2Size = waveHeader.getSubChunk2Size();
|
||||
int blockAlign = waveHeader.getBlockAlign();
|
||||
|
||||
try {
|
||||
FileOutputStream fos = new FileOutputStream(filename);
|
||||
fos.write(WaveHeader.RIFF_HEADER.getBytes());
|
||||
// little endian
|
||||
fos.write(new byte[] {(byte) (chunkSize), (byte) (chunkSize >> 8), (byte) (chunkSize >> 16),
|
||||
(byte) (chunkSize >> 24)});
|
||||
fos.write(WaveHeader.WAVE_HEADER.getBytes());
|
||||
fos.write(WaveHeader.FMT_HEADER.getBytes());
|
||||
fos.write(new byte[] {(byte) (subChunk1Size), (byte) (subChunk1Size >> 8), (byte) (subChunk1Size >> 16),
|
||||
(byte) (subChunk1Size >> 24)});
|
||||
fos.write(new byte[] {(byte) (audioFormat), (byte) (audioFormat >> 8)});
|
||||
fos.write(new byte[] {(byte) (channels), (byte) (channels >> 8)});
|
||||
fos.write(new byte[] {(byte) (sampleRate), (byte) (sampleRate >> 8), (byte) (sampleRate >> 16),
|
||||
(byte) (sampleRate >> 24)});
|
||||
fos.write(new byte[] {(byte) (byteRate), (byte) (byteRate >> 8), (byte) (byteRate >> 16),
|
||||
(byte) (byteRate >> 24)});
|
||||
fos.write(new byte[] {(byte) (blockAlign), (byte) (blockAlign >> 8)});
|
||||
fos.write(new byte[] {(byte) (bitsPerSample), (byte) (bitsPerSample >> 8)});
|
||||
fos.write(WaveHeader.DATA_HEADER.getBytes());
|
||||
fos.write(new byte[] {(byte) (subChunk2Size), (byte) (subChunk2Size >> 8), (byte) (subChunk2Size >> 16),
|
||||
(byte) (subChunk2Size >> 24)});
|
||||
fos.write(wave.getBytes());
|
||||
fos.close();
|
||||
} catch (IOException e) {
|
||||
log.error("",e);
|
||||
}
|
||||
}
|
||||
|
||||
public Wave getWave() {
|
||||
return wave;
|
||||
}
|
||||
|
||||
public void setWave(Wave wave) {
|
||||
this.wave = wave;
|
||||
}
|
||||
}
|
|
@ -1,281 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
||||
@Slf4j
|
||||
public class WaveHeader {
|
||||
|
||||
public static final String RIFF_HEADER = "RIFF";
|
||||
public static final String WAVE_HEADER = "WAVE";
|
||||
public static final String FMT_HEADER = "fmt ";
|
||||
public static final String DATA_HEADER = "data";
|
||||
public static final int HEADER_BYTE_LENGTH = 44; // 44 bytes for header
|
||||
|
||||
private boolean valid;
|
||||
private String chunkId; // 4 bytes
|
||||
private long chunkSize; // unsigned 4 bytes, little endian
|
||||
private String format; // 4 bytes
|
||||
private String subChunk1Id; // 4 bytes
|
||||
private long subChunk1Size; // unsigned 4 bytes, little endian
|
||||
private int audioFormat; // unsigned 2 bytes, little endian
|
||||
private int channels; // unsigned 2 bytes, little endian
|
||||
private long sampleRate; // unsigned 4 bytes, little endian
|
||||
private long byteRate; // unsigned 4 bytes, little endian
|
||||
private int blockAlign; // unsigned 2 bytes, little endian
|
||||
private int bitsPerSample; // unsigned 2 bytes, little endian
|
||||
private String subChunk2Id; // 4 bytes
|
||||
private long subChunk2Size; // unsigned 4 bytes, little endian
|
||||
|
||||
public WaveHeader() {
|
||||
// init a 8k 16bit mono wav
|
||||
chunkSize = 36;
|
||||
subChunk1Size = 16;
|
||||
audioFormat = 1;
|
||||
channels = 1;
|
||||
sampleRate = 8000;
|
||||
byteRate = 16000;
|
||||
blockAlign = 2;
|
||||
bitsPerSample = 16;
|
||||
subChunk2Size = 0;
|
||||
valid = true;
|
||||
}
|
||||
|
||||
public WaveHeader(InputStream inputStream) {
|
||||
valid = loadHeader(inputStream);
|
||||
}
|
||||
|
||||
private boolean loadHeader(InputStream inputStream) {
|
||||
|
||||
byte[] headerBuffer = new byte[HEADER_BYTE_LENGTH];
|
||||
try {
|
||||
inputStream.read(headerBuffer);
|
||||
|
||||
// read header
|
||||
int pointer = 0;
|
||||
chunkId = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++],
|
||||
headerBuffer[pointer++]});
|
||||
// little endian
|
||||
chunkSize = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
|
||||
| (long) (headerBuffer[pointer++] & 0xff) << 16
|
||||
| (long) (headerBuffer[pointer++] & 0xff << 24);
|
||||
format = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++],
|
||||
headerBuffer[pointer++]});
|
||||
subChunk1Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++],
|
||||
headerBuffer[pointer++], headerBuffer[pointer++]});
|
||||
subChunk1Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
|
||||
| (long) (headerBuffer[pointer++] & 0xff) << 16
|
||||
| (long) (headerBuffer[pointer++] & 0xff) << 24;
|
||||
audioFormat = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
|
||||
channels = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
|
||||
sampleRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
|
||||
| (long) (headerBuffer[pointer++] & 0xff) << 16
|
||||
| (long) (headerBuffer[pointer++] & 0xff) << 24;
|
||||
byteRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
|
||||
| (long) (headerBuffer[pointer++] & 0xff) << 16
|
||||
| (long) (headerBuffer[pointer++] & 0xff) << 24;
|
||||
blockAlign = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
|
||||
bitsPerSample = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
|
||||
subChunk2Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++],
|
||||
headerBuffer[pointer++], headerBuffer[pointer++]});
|
||||
subChunk2Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
|
||||
| (long) (headerBuffer[pointer++] & 0xff) << 16
|
||||
| (long) (headerBuffer[pointer++] & 0xff) << 24;
|
||||
// end read header
|
||||
|
||||
// the inputStream should be closed outside this method
|
||||
|
||||
// dis.close();
|
||||
|
||||
} catch (IOException e) {
|
||||
log.error("",e);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (bitsPerSample != 8 && bitsPerSample != 16) {
|
||||
System.err.println("WaveHeader: only supports bitsPerSample 8 or 16");
|
||||
return false;
|
||||
}
|
||||
|
||||
// check the format is support
|
||||
if (chunkId.toUpperCase().equals(RIFF_HEADER) && format.toUpperCase().equals(WAVE_HEADER) && audioFormat == 1) {
|
||||
return true;
|
||||
} else {
|
||||
System.err.println("WaveHeader: Unsupported header format");
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public boolean isValid() {
|
||||
return valid;
|
||||
}
|
||||
|
||||
public String getChunkId() {
|
||||
return chunkId;
|
||||
}
|
||||
|
||||
public long getChunkSize() {
|
||||
return chunkSize;
|
||||
}
|
||||
|
||||
public String getFormat() {
|
||||
return format;
|
||||
}
|
||||
|
||||
public String getSubChunk1Id() {
|
||||
return subChunk1Id;
|
||||
}
|
||||
|
||||
public long getSubChunk1Size() {
|
||||
return subChunk1Size;
|
||||
}
|
||||
|
||||
public int getAudioFormat() {
|
||||
return audioFormat;
|
||||
}
|
||||
|
||||
public int getChannels() {
|
||||
return channels;
|
||||
}
|
||||
|
||||
public int getSampleRate() {
|
||||
return (int) sampleRate;
|
||||
}
|
||||
|
||||
public int getByteRate() {
|
||||
return (int) byteRate;
|
||||
}
|
||||
|
||||
public int getBlockAlign() {
|
||||
return blockAlign;
|
||||
}
|
||||
|
||||
public int getBitsPerSample() {
|
||||
return bitsPerSample;
|
||||
}
|
||||
|
||||
public String getSubChunk2Id() {
|
||||
return subChunk2Id;
|
||||
}
|
||||
|
||||
public long getSubChunk2Size() {
|
||||
return subChunk2Size;
|
||||
}
|
||||
|
||||
public void setSampleRate(int sampleRate) {
|
||||
int newSubChunk2Size = (int) (this.subChunk2Size * sampleRate / this.sampleRate);
|
||||
// if num bytes for each sample is even, the size of newSubChunk2Size also needed to be in even number
|
||||
if ((bitsPerSample / 8) % 2 == 0) {
|
||||
if (newSubChunk2Size % 2 != 0) {
|
||||
newSubChunk2Size++;
|
||||
}
|
||||
}
|
||||
|
||||
this.sampleRate = sampleRate;
|
||||
this.byteRate = sampleRate * bitsPerSample / 8;
|
||||
this.chunkSize = newSubChunk2Size + 36;
|
||||
this.subChunk2Size = newSubChunk2Size;
|
||||
}
|
||||
|
||||
public void setChunkId(String chunkId) {
|
||||
this.chunkId = chunkId;
|
||||
}
|
||||
|
||||
public void setChunkSize(long chunkSize) {
|
||||
this.chunkSize = chunkSize;
|
||||
}
|
||||
|
||||
public void setFormat(String format) {
|
||||
this.format = format;
|
||||
}
|
||||
|
||||
public void setSubChunk1Id(String subChunk1Id) {
|
||||
this.subChunk1Id = subChunk1Id;
|
||||
}
|
||||
|
||||
public void setSubChunk1Size(long subChunk1Size) {
|
||||
this.subChunk1Size = subChunk1Size;
|
||||
}
|
||||
|
||||
public void setAudioFormat(int audioFormat) {
|
||||
this.audioFormat = audioFormat;
|
||||
}
|
||||
|
||||
public void setChannels(int channels) {
|
||||
this.channels = channels;
|
||||
}
|
||||
|
||||
public void setByteRate(long byteRate) {
|
||||
this.byteRate = byteRate;
|
||||
}
|
||||
|
||||
public void setBlockAlign(int blockAlign) {
|
||||
this.blockAlign = blockAlign;
|
||||
}
|
||||
|
||||
public void setBitsPerSample(int bitsPerSample) {
|
||||
this.bitsPerSample = bitsPerSample;
|
||||
}
|
||||
|
||||
public void setSubChunk2Id(String subChunk2Id) {
|
||||
this.subChunk2Id = subChunk2Id;
|
||||
}
|
||||
|
||||
public void setSubChunk2Size(long subChunk2Size) {
|
||||
this.subChunk2Size = subChunk2Size;
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("chunkId: " + chunkId);
|
||||
sb.append("\n");
|
||||
sb.append("chunkSize: " + chunkSize);
|
||||
sb.append("\n");
|
||||
sb.append("format: " + format);
|
||||
sb.append("\n");
|
||||
sb.append("subChunk1Id: " + subChunk1Id);
|
||||
sb.append("\n");
|
||||
sb.append("subChunk1Size: " + subChunk1Size);
|
||||
sb.append("\n");
|
||||
sb.append("audioFormat: " + audioFormat);
|
||||
sb.append("\n");
|
||||
sb.append("channels: " + channels);
|
||||
sb.append("\n");
|
||||
sb.append("sampleRate: " + sampleRate);
|
||||
sb.append("\n");
|
||||
sb.append("byteRate: " + byteRate);
|
||||
sb.append("\n");
|
||||
sb.append("blockAlign: " + blockAlign);
|
||||
sb.append("\n");
|
||||
sb.append("bitsPerSample: " + bitsPerSample);
|
||||
sb.append("\n");
|
||||
sb.append("subChunk2Id: " + subChunk2Id);
|
||||
sb.append("\n");
|
||||
sb.append("subChunk2Size: " + subChunk2Size);
|
||||
return sb.toString();
|
||||
}
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.dsp;
|
||||
|
||||
import org.jtransforms.fft.DoubleFFT_1D;
|
||||
|
||||
public class FastFourierTransform {
|
||||
|
||||
/**
|
||||
* Get the frequency intensities
|
||||
*
|
||||
* @param amplitudes amplitudes of the signal. Format depends on value of complex
|
||||
* @param complex if true, amplitudes is assumed to be complex interlaced (re = even, im = odd), if false amplitudes
|
||||
* are assumed to be real valued.
|
||||
* @return intensities of each frequency unit: mag[frequency_unit]=intensity
|
||||
*/
|
||||
public double[] getMagnitudes(double[] amplitudes, boolean complex) {
|
||||
|
||||
final int sampleSize = amplitudes.length;
|
||||
final int nrofFrequencyBins = sampleSize / 2;
|
||||
|
||||
|
||||
// call the fft and transform the complex numbers
|
||||
if (complex) {
|
||||
DoubleFFT_1D fft = new DoubleFFT_1D(nrofFrequencyBins);
|
||||
fft.complexForward(amplitudes);
|
||||
} else {
|
||||
DoubleFFT_1D fft = new DoubleFFT_1D(sampleSize);
|
||||
fft.realForward(amplitudes);
|
||||
// amplitudes[1] contains re[sampleSize/2] or im[(sampleSize-1) / 2] (depending on whether sampleSize is odd or even)
|
||||
// Discard it as it is useless without the other part
|
||||
// im part dc bin is always 0 for real input
|
||||
amplitudes[1] = 0;
|
||||
}
|
||||
// end call the fft and transform the complex numbers
|
||||
|
||||
// even indexes (0,2,4,6,...) are real parts
|
||||
// odd indexes (1,3,5,7,...) are img parts
|
||||
double[] mag = new double[nrofFrequencyBins];
|
||||
for (int i = 0; i < nrofFrequencyBins; i++) {
|
||||
final int f = 2 * i;
|
||||
mag[i] = Math.sqrt(amplitudes[f] * amplitudes[f] + amplitudes[f + 1] * amplitudes[f + 1]);
|
||||
}
|
||||
|
||||
return mag;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the frequency intensities. Backwards compatible with previous versions w.r.t to number of frequency bins.
|
||||
* Use getMagnitudes(amplitudes, true) to get all bins.
|
||||
*
|
||||
* @param amplitudes complex-valued signal to transform. Even indexes are real and odd indexes are img
|
||||
* @return intensities of each frequency unit: mag[frequency_unit]=intensity
|
||||
*/
|
||||
public double[] getMagnitudes(double[] amplitudes) {
|
||||
double[] magnitudes = getMagnitudes(amplitudes, true);
|
||||
|
||||
double[] halfOfMagnitudes = new double[magnitudes.length/2];
|
||||
System.arraycopy(magnitudes, 0,halfOfMagnitudes, 0, halfOfMagnitudes.length);
|
||||
return halfOfMagnitudes;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,66 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.dsp;
|
||||
|
||||
public class LinearInterpolation {

    public LinearInterpolation() {

    }

    /**
     * Do interpolation on the samples according to the original and destinated sample rates
     *
     * @param oldSampleRate sample rate of the original samples
     * @param newSampleRate sample rate of the interpolated samples
     * @param samples original samples
     * @return interpolated samples
     */
    public short[] interpolate(int oldSampleRate, int newSampleRate, short[] samples) {

        // nothing to do when the rates already match
        if (oldSampleRate == newSampleRate) {
            return samples;
        }

        final int targetLength = Math.round((float) samples.length / oldSampleRate * newSampleRate);
        final float stretch = (float) targetLength / samples.length;
        final short[] result = new short[targetLength];

        // interpolate each output point with the linear equation y = m*x + c
        for (int target = 0; target < targetLength; target++) {

            // locate the nearest source samples around the interpolated point
            final float sourcePosition = target / stretch;
            final int left = (int) sourcePosition;
            final int right = Math.min(left + 1, samples.length - 1); // clamp at the last sample

            final float gradient = samples[right] - samples[left]; // delta x is 1
            final float offsetFromLeft = sourcePosition - left;

            result[target] = (short) (gradient * offsetFromLeft + samples[left]); // y = m*x + c
        }

        return result;
    }
}
|
|
@ -1,84 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.dsp;
|
||||
|
||||
public class Resampler {
|
||||
|
||||
public Resampler() {}
|
||||
|
||||
/**
|
||||
* Do resampling. Currently the amplitude is stored by short such that maximum bitsPerSample is 16 (bytePerSample is 2)
|
||||
*
|
||||
* @param sourceData The source data in bytes
|
||||
* @param bitsPerSample How many bits represents one sample (currently supports max. bitsPerSample=16)
|
||||
* @param sourceRate Sample rate of the source data
|
||||
* @param targetRate Sample rate of the target data
|
||||
* @return re-sampled data
|
||||
*/
|
||||
public byte[] reSample(byte[] sourceData, int bitsPerSample, int sourceRate, int targetRate) {
|
||||
|
||||
// make the bytes to amplitudes first
|
||||
int bytePerSample = bitsPerSample / 8;
|
||||
int numSamples = sourceData.length / bytePerSample;
|
||||
short[] amplitudes = new short[numSamples]; // 16 bit, use a short to store
|
||||
|
||||
int pointer = 0;
|
||||
for (int i = 0; i < numSamples; i++) {
|
||||
short amplitude = 0;
|
||||
for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) {
|
||||
// little endian
|
||||
amplitude |= (short) ((sourceData[pointer++] & 0xFF) << (byteNumber * 8));
|
||||
}
|
||||
amplitudes[i] = amplitude;
|
||||
}
|
||||
// end make the amplitudes
|
||||
|
||||
// do interpolation
|
||||
LinearInterpolation reSample = new LinearInterpolation();
|
||||
short[] targetSample = reSample.interpolate(sourceRate, targetRate, amplitudes);
|
||||
int targetLength = targetSample.length;
|
||||
// end do interpolation
|
||||
|
||||
// TODO: Remove the high frequency signals with a digital filter, leaving a signal containing only half-sample-rated frequency information, but still sampled at a rate of target sample rate. Usually FIR is used
|
||||
|
||||
// end resample the amplitudes
|
||||
|
||||
// convert the amplitude to bytes
|
||||
byte[] bytes;
|
||||
if (bytePerSample == 1) {
|
||||
bytes = new byte[targetLength];
|
||||
for (int i = 0; i < targetLength; i++) {
|
||||
bytes[i] = (byte) targetSample[i];
|
||||
}
|
||||
} else {
|
||||
// suppose bytePerSample==2
|
||||
bytes = new byte[targetLength * 2];
|
||||
for (int i = 0; i < targetSample.length; i++) {
|
||||
// little endian
|
||||
bytes[i * 2] = (byte) (targetSample[i] & 0xff);
|
||||
bytes[i * 2 + 1] = (byte) ((targetSample[i] >> 8) & 0xff);
|
||||
}
|
||||
}
|
||||
// end convert the amplitude to bytes
|
||||
|
||||
return bytes;
|
||||
}
|
||||
}
|
|
@ -1,95 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.dsp;
|
||||
|
||||
/**
 * Generates common analysis window functions (rectangular, Bartlett, Hanning,
 * Hamming, Blackman) as arrays of per-sample weights.
 */
public class WindowFunction {

    public static final int RECTANGULAR = 0;
    public static final int BARTLETT = 1;
    public static final int HANNING = 2;
    public static final int HAMMING = 3;
    public static final int BLACKMAN = 4;

    int windowType = 0; // defaults to rectangular window

    public WindowFunction() {}

    /**
     * Set the window type by constant.
     *
     * @param wt one of RECTANGULAR, BARTLETT, HANNING, HAMMING, BLACKMAN
     */
    public void setWindowType(int wt) {
        windowType = wt;
    }

    /**
     * Set the window type by case-insensitive name; an unrecognized name leaves
     * the current type unchanged.
     *
     * @param w window name: "RECTANGULAR", "BARTLETT", "HANNING", "HAMMING" or "BLACKMAN"
     */
    public void setWindowType(String w) {
        // equalsIgnoreCase is locale-independent for these ASCII names,
        // unlike the previous toUpperCase().equals(...) comparison.
        if (w.equalsIgnoreCase("RECTANGULAR"))
            windowType = RECTANGULAR;
        if (w.equalsIgnoreCase("BARTLETT"))
            windowType = BARTLETT;
        if (w.equalsIgnoreCase("HANNING"))
            windowType = HANNING;
        if (w.equalsIgnoreCase("HAMMING"))
            windowType = HAMMING;
        if (w.equalsIgnoreCase("BLACKMAN"))
            windowType = BLACKMAN;
    }

    /** @return the currently selected window-type constant */
    public int getWindowType() {
        return windowType;
    }

    /**
     * Generate a window
     *
     * @param nSamples size of the window
     * @return window in array
     */
    public double[] generate(int nSamples) {
        // generate nSamples window function values for index values 0 .. nSamples - 1
        int m = nSamples / 2;
        double r;
        double pi = Math.PI;
        double[] w = new double[nSamples];
        switch (windowType) {
            case BARTLETT: // Bartlett (triangular) window
                for (int n = 0; n < nSamples; n++)
                    // FIX: cast before dividing — the previous integer division
                    // Math.abs(n - m) / m collapsed the triangle to 0/1 values.
                    w[n] = 1.0 - (double) Math.abs(n - m) / m;
                break;
            case HANNING: // Hanning window
                r = pi / (m + 1);
                // NOTE: the -m..m-1 loops fill indices 0..2m-1 only, so for odd
                // nSamples the last element stays 0.0 (preserved behavior).
                for (int n = -m; n < m; n++)
                    w[m + n] = 0.5f + 0.5f * Math.cos(n * r);
                break;
            case HAMMING: // Hamming window
                r = pi / m;
                for (int n = -m; n < m; n++)
                    w[m + n] = 0.54f + 0.46f * Math.cos(n * r);
                break;
            case BLACKMAN: // Blackman window
                r = pi / m;
                for (int n = -m; n < m; n++)
                    w[m + n] = 0.42f + 0.5f * Math.cos(n * r) + 0.08f * Math.cos(2 * n * r);
                break;
            default: // Rectangular window function
                for (int n = 0; n < nSamples; n++)
                    w[n] = 1.0f;
        }
        return w;
    }
}
|
|
@ -1,21 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.dsp;
|
|
@ -1,67 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.extension;
|
||||
|
||||
|
||||
import org.datavec.audio.Wave;
|
||||
|
||||
public class NormalizedSampleAmplitudes {
|
||||
|
||||
private Wave wave;
|
||||
private double[] normalizedAmplitudes; // normalizedAmplitudes[sampleNumber]=normalizedAmplitudeInTheFrame
|
||||
|
||||
public NormalizedSampleAmplitudes(Wave wave) {
|
||||
this.wave = wave;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* Get normalized amplitude of each frame
|
||||
*
|
||||
* @return array of normalized amplitudes(signed 16 bit): normalizedAmplitudes[frame]=amplitude
|
||||
*/
|
||||
public double[] getNormalizedAmplitudes() {
|
||||
|
||||
if (normalizedAmplitudes == null) {
|
||||
|
||||
boolean signed = true;
|
||||
|
||||
// usually 8bit is unsigned
|
||||
if (wave.getWaveHeader().getBitsPerSample() == 8) {
|
||||
signed = false;
|
||||
}
|
||||
|
||||
short[] amplitudes = wave.getSampleAmplitudes();
|
||||
int numSamples = amplitudes.length;
|
||||
int maxAmplitude = 1 << (wave.getWaveHeader().getBitsPerSample() - 1);
|
||||
|
||||
if (!signed) { // one more bit for unsigned value
|
||||
maxAmplitude <<= 1;
|
||||
}
|
||||
|
||||
normalizedAmplitudes = new double[numSamples];
|
||||
for (int i = 0; i < numSamples; i++) {
|
||||
normalizedAmplitudes[i] = (double) amplitudes[i] / maxAmplitude;
|
||||
}
|
||||
}
|
||||
return normalizedAmplitudes;
|
||||
}
|
||||
}
|
|
@ -1,214 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.extension;
|
||||
|
||||
|
||||
import org.datavec.audio.Wave;
|
||||
import org.datavec.audio.dsp.FastFourierTransform;
|
||||
import org.datavec.audio.dsp.WindowFunction;
|
||||
|
||||
public class Spectrogram {
|
||||
|
||||
public static final int SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE = 1024;
|
||||
public static final int SPECTROGRAM_DEFAULT_OVERLAP_FACTOR = 0; // 0 for no overlapping
|
||||
|
||||
private Wave wave;
|
||||
private double[][] spectrogram; // relative spectrogram
|
||||
private double[][] absoluteSpectrogram; // absolute spectrogram
|
||||
private int fftSampleSize; // number of sample in fft, the value needed to be a number to power of 2
|
||||
private int overlapFactor; // 1/overlapFactor overlapping, e.g. 1/4=25% overlapping
|
||||
private int numFrames; // number of frames of the spectrogram
|
||||
private int framesPerSecond; // frame per second of the spectrogram
|
||||
private int numFrequencyUnit; // number of y-axis unit
|
||||
private double unitFrequency; // frequency per y-axis unit
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param wave
|
||||
*/
|
||||
public Spectrogram(Wave wave) {
|
||||
this.wave = wave;
|
||||
// default
|
||||
this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE;
|
||||
this.overlapFactor = SPECTROGRAM_DEFAULT_OVERLAP_FACTOR;
|
||||
buildSpectrogram();
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param wave
|
||||
* @param fftSampleSize number of sample in fft, the value needed to be a number to power of 2
|
||||
* @param overlapFactor 1/overlapFactor overlapping, e.g. 1/4=25% overlapping, 0 for no overlapping
|
||||
*/
|
||||
public Spectrogram(Wave wave, int fftSampleSize, int overlapFactor) {
|
||||
this.wave = wave;
|
||||
|
||||
if (Integer.bitCount(fftSampleSize) == 1) {
|
||||
this.fftSampleSize = fftSampleSize;
|
||||
} else {
|
||||
System.err.print("The input number must be a power of 2");
|
||||
this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE;
|
||||
}
|
||||
|
||||
this.overlapFactor = overlapFactor;
|
||||
|
||||
buildSpectrogram();
|
||||
}
|
||||
|
||||
/**
|
||||
* Build spectrogram
|
||||
*/
|
||||
private void buildSpectrogram() {
|
||||
|
||||
short[] amplitudes = wave.getSampleAmplitudes();
|
||||
int numSamples = amplitudes.length;
|
||||
|
||||
int pointer = 0;
|
||||
// overlapping
|
||||
if (overlapFactor > 1) {
|
||||
int numOverlappedSamples = numSamples * overlapFactor;
|
||||
int backSamples = fftSampleSize * (overlapFactor - 1) / overlapFactor;
|
||||
short[] overlapAmp = new short[numOverlappedSamples];
|
||||
pointer = 0;
|
||||
for (int i = 0; i < amplitudes.length; i++) {
|
||||
overlapAmp[pointer++] = amplitudes[i];
|
||||
if (pointer % fftSampleSize == 0) {
|
||||
// overlap
|
||||
i -= backSamples;
|
||||
}
|
||||
}
|
||||
numSamples = numOverlappedSamples;
|
||||
amplitudes = overlapAmp;
|
||||
}
|
||||
// end overlapping
|
||||
|
||||
numFrames = numSamples / fftSampleSize;
|
||||
framesPerSecond = (int) (numFrames / wave.length());
|
||||
|
||||
// set signals for fft
|
||||
WindowFunction window = new WindowFunction();
|
||||
window.setWindowType("Hamming");
|
||||
double[] win = window.generate(fftSampleSize);
|
||||
|
||||
double[][] signals = new double[numFrames][];
|
||||
for (int f = 0; f < numFrames; f++) {
|
||||
signals[f] = new double[fftSampleSize];
|
||||
int startSample = f * fftSampleSize;
|
||||
for (int n = 0; n < fftSampleSize; n++) {
|
||||
signals[f][n] = amplitudes[startSample + n] * win[n];
|
||||
}
|
||||
}
|
||||
// end set signals for fft
|
||||
|
||||
absoluteSpectrogram = new double[numFrames][];
|
||||
// for each frame in signals, do fft on it
|
||||
FastFourierTransform fft = new FastFourierTransform();
|
||||
for (int i = 0; i < numFrames; i++) {
|
||||
absoluteSpectrogram[i] = fft.getMagnitudes(signals[i], false);
|
||||
}
|
||||
|
||||
if (absoluteSpectrogram.length > 0) {
|
||||
|
||||
numFrequencyUnit = absoluteSpectrogram[0].length;
|
||||
unitFrequency = (double) wave.getWaveHeader().getSampleRate() / 2 / numFrequencyUnit; // frequency could be caught within the half of nSamples according to Nyquist theory
|
||||
|
||||
// normalization of absoultSpectrogram
|
||||
spectrogram = new double[numFrames][numFrequencyUnit];
|
||||
|
||||
// set max and min amplitudes
|
||||
double maxAmp = Double.MIN_VALUE;
|
||||
double minAmp = Double.MAX_VALUE;
|
||||
for (int i = 0; i < numFrames; i++) {
|
||||
for (int j = 0; j < numFrequencyUnit; j++) {
|
||||
if (absoluteSpectrogram[i][j] > maxAmp) {
|
||||
maxAmp = absoluteSpectrogram[i][j];
|
||||
} else if (absoluteSpectrogram[i][j] < minAmp) {
|
||||
minAmp = absoluteSpectrogram[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
// end set max and min amplitudes
|
||||
|
||||
// normalization
|
||||
// avoiding divided by zero
|
||||
double minValidAmp = 0.00000000001F;
|
||||
if (minAmp == 0) {
|
||||
minAmp = minValidAmp;
|
||||
}
|
||||
|
||||
double diff = Math.log10(maxAmp / minAmp); // perceptual difference
|
||||
for (int i = 0; i < numFrames; i++) {
|
||||
for (int j = 0; j < numFrequencyUnit; j++) {
|
||||
if (absoluteSpectrogram[i][j] < minValidAmp) {
|
||||
spectrogram[i][j] = 0;
|
||||
} else {
|
||||
spectrogram[i][j] = (Math.log10(absoluteSpectrogram[i][j] / minAmp)) / diff;
|
||||
}
|
||||
}
|
||||
}
|
||||
// end normalization
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get spectrogram: spectrogram[time][frequency]=intensity
|
||||
*
|
||||
* @return logarithm normalized spectrogram
|
||||
*/
|
||||
public double[][] getNormalizedSpectrogramData() {
|
||||
return spectrogram;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get spectrogram: spectrogram[time][frequency]=intensity
|
||||
*
|
||||
* @return absolute spectrogram
|
||||
*/
|
||||
public double[][] getAbsoluteSpectrogramData() {
|
||||
return absoluteSpectrogram;
|
||||
}
|
||||
|
||||
public int getNumFrames() {
|
||||
return numFrames;
|
||||
}
|
||||
|
||||
public int getFramesPerSecond() {
|
||||
return framesPerSecond;
|
||||
}
|
||||
|
||||
public int getNumFrequencyUnit() {
|
||||
return numFrequencyUnit;
|
||||
}
|
||||
|
||||
public double getUnitFrequency() {
|
||||
return unitFrequency;
|
||||
}
|
||||
|
||||
public int getFftSampleSize() {
|
||||
return fftSampleSize;
|
||||
}
|
||||
|
||||
public int getOverlapFactor() {
|
||||
return overlapFactor;
|
||||
}
|
||||
}
|
|
@ -1,272 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.fingerprint;
|
||||
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;

import org.datavec.audio.Wave;
import org.datavec.audio.WaveHeader;
import org.datavec.audio.dsp.Resampler;
import org.datavec.audio.extension.Spectrogram;
import org.datavec.audio.processor.TopManyPointsProcessorChain;
import org.datavec.audio.properties.FingerprintProperties;

import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class FingerprintManager {
|
||||
|
||||
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
|
||||
private int sampleSizePerFrame = fingerprintProperties.getSampleSizePerFrame();
|
||||
private int overlapFactor = fingerprintProperties.getOverlapFactor();
|
||||
private int numRobustPointsPerFrame = fingerprintProperties.getNumRobustPointsPerFrame();
|
||||
private int numFilterBanks = fingerprintProperties.getNumFilterBanks();
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*/
|
||||
public FingerprintManager() {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract fingerprint from Wave object
|
||||
*
|
||||
* @param wave Wave Object to be extracted fingerprint
|
||||
* @return fingerprint in bytes
|
||||
*/
|
||||
public byte[] extractFingerprint(Wave wave) {
|
||||
|
||||
int[][] coordinates; // coordinates[x][0..3]=y0..y3
|
||||
byte[] fingerprint = new byte[0];
|
||||
|
||||
// resample to target rate
|
||||
Resampler resampler = new Resampler();
|
||||
int sourceRate = wave.getWaveHeader().getSampleRate();
|
||||
int targetRate = fingerprintProperties.getSampleRate();
|
||||
|
||||
byte[] resampledWaveData = resampler.reSample(wave.getBytes(), wave.getWaveHeader().getBitsPerSample(),
|
||||
sourceRate, targetRate);
|
||||
|
||||
// update the wave header
|
||||
WaveHeader resampledWaveHeader = wave.getWaveHeader();
|
||||
resampledWaveHeader.setSampleRate(targetRate);
|
||||
|
||||
// make resampled wave
|
||||
Wave resampledWave = new Wave(resampledWaveHeader, resampledWaveData);
|
||||
// end resample to target rate
|
||||
|
||||
// get spectrogram's data
|
||||
Spectrogram spectrogram = resampledWave.getSpectrogram(sampleSizePerFrame, overlapFactor);
|
||||
double[][] spectorgramData = spectrogram.getNormalizedSpectrogramData();
|
||||
|
||||
List<Integer>[] pointsLists = getRobustPointList(spectorgramData);
|
||||
int numFrames = pointsLists.length;
|
||||
|
||||
// prepare fingerprint bytes
|
||||
coordinates = new int[numFrames][numRobustPointsPerFrame];
|
||||
|
||||
for (int x = 0; x < numFrames; x++) {
|
||||
if (pointsLists[x].size() == numRobustPointsPerFrame) {
|
||||
Iterator<Integer> pointsListsIterator = pointsLists[x].iterator();
|
||||
for (int y = 0; y < numRobustPointsPerFrame; y++) {
|
||||
coordinates[x][y] = pointsListsIterator.next();
|
||||
}
|
||||
} else {
|
||||
// use -1 to fill the empty byte
|
||||
for (int y = 0; y < numRobustPointsPerFrame; y++) {
|
||||
coordinates[x][y] = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
// end make fingerprint
|
||||
|
||||
// for each valid coordinate, append with its intensity
|
||||
List<Byte> byteList = new LinkedList<Byte>();
|
||||
for (int i = 0; i < numFrames; i++) {
|
||||
for (int j = 0; j < numRobustPointsPerFrame; j++) {
|
||||
if (coordinates[i][j] != -1) {
|
||||
// first 2 bytes is x
|
||||
byteList.add((byte) (i >> 8));
|
||||
byteList.add((byte) i);
|
||||
|
||||
// next 2 bytes is y
|
||||
int y = coordinates[i][j];
|
||||
byteList.add((byte) (y >> 8));
|
||||
byteList.add((byte) y);
|
||||
|
||||
// next 4 bytes is intensity
|
||||
int intensity = (int) (spectorgramData[i][y] * Integer.MAX_VALUE); // spectorgramData is ranged from 0~1
|
||||
byteList.add((byte) (intensity >> 24));
|
||||
byteList.add((byte) (intensity >> 16));
|
||||
byteList.add((byte) (intensity >> 8));
|
||||
byteList.add((byte) intensity);
|
||||
}
|
||||
}
|
||||
}
|
||||
// end for each valid coordinate, append with its intensity
|
||||
|
||||
fingerprint = new byte[byteList.size()];
|
||||
Iterator<Byte> byteListIterator = byteList.iterator();
|
||||
int pointer = 0;
|
||||
while (byteListIterator.hasNext()) {
|
||||
fingerprint[pointer++] = byteListIterator.next();
|
||||
}
|
||||
|
||||
return fingerprint;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get bytes from fingerprint file
|
||||
*
|
||||
* @param fingerprintFile fingerprint filename
|
||||
* @return fingerprint in bytes
|
||||
*/
|
||||
public byte[] getFingerprintFromFile(String fingerprintFile) {
|
||||
byte[] fingerprint = null;
|
||||
try {
|
||||
InputStream fis = new FileInputStream(fingerprintFile);
|
||||
fingerprint = getFingerprintFromInputStream(fis);
|
||||
fis.close();
|
||||
} catch (IOException e) {
|
||||
log.error("",e);
|
||||
}
|
||||
return fingerprint;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get bytes from fingerprint inputstream
|
||||
*
|
||||
* @param inputStream fingerprint inputstream
|
||||
* @return fingerprint in bytes
|
||||
*/
|
||||
public byte[] getFingerprintFromInputStream(InputStream inputStream) {
|
||||
byte[] fingerprint = null;
|
||||
try {
|
||||
fingerprint = new byte[inputStream.available()];
|
||||
inputStream.read(fingerprint);
|
||||
} catch (IOException e) {
|
||||
log.error("",e);
|
||||
}
|
||||
return fingerprint;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save fingerprint to a file
|
||||
*
|
||||
* @param fingerprint fingerprint bytes
|
||||
* @param filename fingerprint filename
|
||||
* @see FingerprintManager file saved
|
||||
*/
|
||||
public void saveFingerprintAsFile(byte[] fingerprint, String filename) {
|
||||
|
||||
FileOutputStream fileOutputStream;
|
||||
try {
|
||||
fileOutputStream = new FileOutputStream(filename);
|
||||
fileOutputStream.write(fingerprint);
|
||||
fileOutputStream.close();
|
||||
} catch (IOException e) {
|
||||
log.error("",e);
|
||||
}
|
||||
}
|
||||
|
||||
// robustLists[x]=y1,y2,y3,...
|
||||
private List<Integer>[] getRobustPointList(double[][] spectrogramData) {
|
||||
|
||||
int numX = spectrogramData.length;
|
||||
int numY = spectrogramData[0].length;
|
||||
|
||||
double[][] allBanksIntensities = new double[numX][numY];
|
||||
int bandwidthPerBank = numY / numFilterBanks;
|
||||
|
||||
for (int b = 0; b < numFilterBanks; b++) {
|
||||
|
||||
double[][] bankIntensities = new double[numX][bandwidthPerBank];
|
||||
|
||||
for (int i = 0; i < numX; i++) {
|
||||
System.arraycopy(spectrogramData[i], b * bandwidthPerBank, bankIntensities[i], 0, bandwidthPerBank);
|
||||
}
|
||||
|
||||
// get the most robust point in each filter bank
|
||||
TopManyPointsProcessorChain processorChain = new TopManyPointsProcessorChain(bankIntensities, 1);
|
||||
double[][] processedIntensities = processorChain.getIntensities();
|
||||
|
||||
for (int i = 0; i < numX; i++) {
|
||||
System.arraycopy(processedIntensities[i], 0, allBanksIntensities[i], b * bandwidthPerBank,
|
||||
bandwidthPerBank);
|
||||
}
|
||||
}
|
||||
|
||||
List<int[]> robustPointList = new LinkedList<int[]>();
|
||||
|
||||
// find robust points
|
||||
for (int i = 0; i < allBanksIntensities.length; i++) {
|
||||
for (int j = 0; j < allBanksIntensities[i].length; j++) {
|
||||
if (allBanksIntensities[i][j] > 0) {
|
||||
|
||||
int[] point = new int[] {i, j};
|
||||
//System.out.println(i+","+frequency);
|
||||
robustPointList.add(point);
|
||||
}
|
||||
}
|
||||
}
|
||||
// end find robust points
|
||||
|
||||
List<Integer>[] robustLists = new LinkedList[spectrogramData.length];
|
||||
for (int i = 0; i < robustLists.length; i++) {
|
||||
robustLists[i] = new LinkedList<>();
|
||||
}
|
||||
|
||||
// robustLists[x]=y1,y2,y3,...
|
||||
for (int[] coor : robustPointList) {
|
||||
robustLists[coor[0]].add(coor[1]);
|
||||
}
|
||||
|
||||
// return the list per frame
|
||||
return robustLists;
|
||||
}
|
||||
|
||||
/**
|
||||
* Number of frames in a fingerprint
|
||||
* Each frame lengths 8 bytes
|
||||
* Usually there is more than one point in each frame, so it cannot simply divide the bytes length by 8
|
||||
* Last 8 byte of thisFingerprint is the last frame of this wave
|
||||
* First 2 byte of the last 8 byte is the x position of this wave, i.e. (number_of_frames-1) of this wave
|
||||
*
|
||||
* @param fingerprint fingerprint bytes
|
||||
* @return number of frames of the fingerprint
|
||||
*/
|
||||
public static int getNumFrames(byte[] fingerprint) {
|
||||
|
||||
if (fingerprint.length < 8) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// get the last x-coordinate (length-8&length-7)bytes from fingerprint
|
||||
return ((fingerprint[fingerprint.length - 8] & 0xff) << 8 | (fingerprint[fingerprint.length - 7] & 0xff)) + 1;
|
||||
}
|
||||
}
|
|
@ -1,107 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.fingerprint;
|
||||
|
||||
|
||||
import org.datavec.audio.properties.FingerprintProperties;
|
||||
|
||||
public class FingerprintSimilarity {
|
||||
|
||||
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
|
||||
private int mostSimilarFramePosition;
|
||||
private float score;
|
||||
private float similarity;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*/
|
||||
public FingerprintSimilarity() {
|
||||
mostSimilarFramePosition = Integer.MIN_VALUE;
|
||||
score = -1;
|
||||
similarity = -1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the most similar position in terms of frame number
|
||||
*
|
||||
* @return most similar frame position
|
||||
*/
|
||||
public int getMostSimilarFramePosition() {
|
||||
return mostSimilarFramePosition;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the most similar position in terms of frame number
|
||||
*
|
||||
* @param mostSimilarFramePosition
|
||||
*/
|
||||
public void setMostSimilarFramePosition(int mostSimilarFramePosition) {
|
||||
this.mostSimilarFramePosition = mostSimilarFramePosition;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the similarity of the fingerprints
|
||||
* similarity from 0~1, which 0 means no similar feature is found and 1 means in average there is at least one match in every frame
|
||||
*
|
||||
* @return fingerprints similarity
|
||||
*/
|
||||
public float getSimilarity() {
|
||||
return similarity;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the similarity of the fingerprints
|
||||
*
|
||||
* @param similarity similarity
|
||||
*/
|
||||
public void setSimilarity(float similarity) {
|
||||
this.similarity = similarity;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the similarity score of the fingerprints
|
||||
* Number of features found in the fingerprints per frame
|
||||
*
|
||||
* @return fingerprints similarity score
|
||||
*/
|
||||
public float getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the similarity score of the fingerprints
|
||||
*
|
||||
* @param score
|
||||
*/
|
||||
public void setScore(float score) {
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the most similar position in terms of time in second
|
||||
*
|
||||
* @return most similar starting time
|
||||
*/
|
||||
public float getsetMostSimilarTimePosition() {
|
||||
return (float) mostSimilarFramePosition / fingerprintProperties.getNumRobustPointsPerFrame()
|
||||
/ fingerprintProperties.getFps();
|
||||
}
|
||||
}
|
|
@ -1,134 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.fingerprint;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
public class FingerprintSimilarityComputer {
|
||||
|
||||
private FingerprintSimilarity fingerprintSimilarity;
|
||||
byte[] fingerprint1, fingerprint2;
|
||||
|
||||
/**
|
||||
* Constructor, ready to compute the similarity of two fingerprints
|
||||
*
|
||||
* @param fingerprint1
|
||||
* @param fingerprint2
|
||||
*/
|
||||
public FingerprintSimilarityComputer(byte[] fingerprint1, byte[] fingerprint2) {
|
||||
|
||||
this.fingerprint1 = fingerprint1;
|
||||
this.fingerprint2 = fingerprint2;
|
||||
|
||||
fingerprintSimilarity = new FingerprintSimilarity();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get fingerprint similarity of inout fingerprints
|
||||
*
|
||||
* @return fingerprint similarity object
|
||||
*/
|
||||
public FingerprintSimilarity getFingerprintsSimilarity() {
|
||||
HashMap<Integer, Integer> offset_Score_Table = new HashMap<>(); // offset_Score_Table<offset,count>
|
||||
int numFrames;
|
||||
float score = 0;
|
||||
int mostSimilarFramePosition = Integer.MIN_VALUE;
|
||||
|
||||
// one frame may contain several points, use the shorter one be the denominator
|
||||
if (fingerprint1.length > fingerprint2.length) {
|
||||
numFrames = FingerprintManager.getNumFrames(fingerprint2);
|
||||
} else {
|
||||
numFrames = FingerprintManager.getNumFrames(fingerprint1);
|
||||
}
|
||||
|
||||
// get the pairs
|
||||
PairManager pairManager = new PairManager();
|
||||
HashMap<Integer, List<Integer>> this_Pair_PositionList_Table =
|
||||
pairManager.getPair_PositionList_Table(fingerprint1);
|
||||
HashMap<Integer, List<Integer>> compareWave_Pair_PositionList_Table =
|
||||
pairManager.getPair_PositionList_Table(fingerprint2);
|
||||
|
||||
for (Integer compareWaveHashNumber : compareWave_Pair_PositionList_Table.keySet()) {
|
||||
// if the compareWaveHashNumber doesn't exist in both tables, no need to compare
|
||||
if (!this_Pair_PositionList_Table.containsKey(compareWaveHashNumber)
|
||||
|| !compareWave_Pair_PositionList_Table.containsKey(compareWaveHashNumber)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// for each compare hash number, get the positions
|
||||
List<Integer> wavePositionList = this_Pair_PositionList_Table.get(compareWaveHashNumber);
|
||||
List<Integer> compareWavePositionList = compareWave_Pair_PositionList_Table.get(compareWaveHashNumber);
|
||||
|
||||
for (Integer thisPosition : wavePositionList) {
|
||||
for (Integer compareWavePosition : compareWavePositionList) {
|
||||
int offset = thisPosition - compareWavePosition;
|
||||
if (offset_Score_Table.containsKey(offset)) {
|
||||
offset_Score_Table.put(offset, offset_Score_Table.get(offset) + 1);
|
||||
} else {
|
||||
offset_Score_Table.put(offset, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// map rank
|
||||
MapRank mapRank = new MapRankInteger(offset_Score_Table, false);
|
||||
|
||||
// get the most similar positions and scores
|
||||
List<Integer> orderedKeyList = mapRank.getOrderedKeyList(100, true);
|
||||
if (orderedKeyList.size() > 0) {
|
||||
int key = orderedKeyList.get(0);
|
||||
// get the highest score position
|
||||
mostSimilarFramePosition = key;
|
||||
score = offset_Score_Table.get(key);
|
||||
|
||||
// accumulate the scores from neighbours
|
||||
if (offset_Score_Table.containsKey(key - 1)) {
|
||||
score += offset_Score_Table.get(key - 1) / 2;
|
||||
}
|
||||
if (offset_Score_Table.containsKey(key + 1)) {
|
||||
score += offset_Score_Table.get(key + 1) / 2;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Iterator<Integer> orderedKeyListIterator=orderedKeyList.iterator();
|
||||
while (orderedKeyListIterator.hasNext()){
|
||||
int offset=orderedKeyListIterator.next();
|
||||
System.out.println(offset+": "+offset_Score_Table.get(offset));
|
||||
}
|
||||
*/
|
||||
|
||||
score /= numFrames;
|
||||
float similarity = score;
|
||||
// similarity >1 means in average there is at least one match in every frame
|
||||
if (similarity > 1) {
|
||||
similarity = 1;
|
||||
}
|
||||
|
||||
fingerprintSimilarity.setMostSimilarFramePosition(mostSimilarFramePosition);
|
||||
fingerprintSimilarity.setScore(score);
|
||||
fingerprintSimilarity.setSimilarity(similarity);
|
||||
|
||||
return fingerprintSimilarity;
|
||||
}
|
||||
}
|
|
@ -1,27 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.fingerprint;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface MapRank {
    /**
     * Return up to {@code numKeys} map keys, ordered by their mapped values.
     *
     * @param numKeys    maximum number of keys to return
     * @param sharpLimit if true, cut off strictly at {@code numKeys} keys; if false,
     *                   keep returning keys whose values tie with the value at the
     *                   cut-off position
     * @return keys in rank order (raw {@code List}; element type depends on the map's key type)
     */
    public List getOrderedKeyList(int numKeys, boolean sharpLimit);
}
|
|
@ -1,179 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.fingerprint;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.Map.Entry;
|
||||
|
||||
public class MapRankDouble implements MapRank {
|
||||
|
||||
private Map map;
|
||||
private boolean acsending = true;
|
||||
|
||||
public MapRankDouble(Map<?, Double> map, boolean acsending) {
|
||||
this.map = map;
|
||||
this.acsending = acsending;
|
||||
}
|
||||
|
||||
public List getOrderedKeyList(int numKeys, boolean sharpLimit) { // if sharp limited, will return sharp numKeys, otherwise will return until the values not equals the exact key's value
|
||||
|
||||
Set mapEntrySet = map.entrySet();
|
||||
List keyList = new LinkedList();
|
||||
|
||||
// if the numKeys is larger than map size, limit it
|
||||
if (numKeys > map.size()) {
|
||||
numKeys = map.size();
|
||||
}
|
||||
// end if the numKeys is larger than map size, limit it
|
||||
|
||||
if (map.size() > 0) {
|
||||
double[] array = new double[map.size()];
|
||||
int count = 0;
|
||||
|
||||
// get the pass values
|
||||
Iterator<Entry> mapIterator = mapEntrySet.iterator();
|
||||
while (mapIterator.hasNext()) {
|
||||
Entry entry = mapIterator.next();
|
||||
array[count++] = (Double) entry.getValue();
|
||||
}
|
||||
// end get the pass values
|
||||
|
||||
int targetindex;
|
||||
if (acsending) {
|
||||
targetindex = numKeys;
|
||||
} else {
|
||||
targetindex = array.length - numKeys;
|
||||
}
|
||||
|
||||
double passValue = getOrderedValue(array, targetindex); // this value is the value of the numKey-th element
|
||||
// get the passed keys and values
|
||||
Map passedMap = new HashMap();
|
||||
List<Double> valueList = new LinkedList<Double>();
|
||||
mapIterator = mapEntrySet.iterator();
|
||||
|
||||
while (mapIterator.hasNext()) {
|
||||
Entry entry = mapIterator.next();
|
||||
double value = (Double) entry.getValue();
|
||||
if ((acsending && value <= passValue) || (!acsending && value >= passValue)) {
|
||||
passedMap.put(entry.getKey(), value);
|
||||
valueList.add(value);
|
||||
}
|
||||
}
|
||||
// end get the passed keys and values
|
||||
|
||||
// sort the value list
|
||||
Double[] listArr = new Double[valueList.size()];
|
||||
valueList.toArray(listArr);
|
||||
Arrays.sort(listArr);
|
||||
// end sort the value list
|
||||
|
||||
// get the list of keys
|
||||
int resultCount = 0;
|
||||
int index;
|
||||
if (acsending) {
|
||||
index = 0;
|
||||
} else {
|
||||
index = listArr.length - 1;
|
||||
}
|
||||
|
||||
if (!sharpLimit) {
|
||||
numKeys = listArr.length;
|
||||
}
|
||||
|
||||
while (true) {
|
||||
double targetValue = (Double) listArr[index];
|
||||
Iterator<Entry> passedMapIterator = passedMap.entrySet().iterator();
|
||||
while (passedMapIterator.hasNext()) {
|
||||
Entry entry = passedMapIterator.next();
|
||||
if ((Double) entry.getValue() == targetValue) {
|
||||
keyList.add(entry.getKey());
|
||||
passedMapIterator.remove();
|
||||
resultCount++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (acsending) {
|
||||
index++;
|
||||
} else {
|
||||
index--;
|
||||
}
|
||||
|
||||
if (resultCount >= numKeys) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// end get the list of keys
|
||||
}
|
||||
|
||||
return keyList;
|
||||
}
|
||||
|
||||
private double getOrderedValue(double[] array, int index) {
|
||||
locate(array, 0, array.length - 1, index);
|
||||
return array[index];
|
||||
}
|
||||
|
||||
// sort the partitions by quick sort, and locate the target index
|
||||
private void locate(double[] array, int left, int right, int index) {
|
||||
|
||||
int mid = (left + right) / 2;
|
||||
//System.out.println(left+" to "+right+" ("+mid+")");
|
||||
|
||||
if (right == left) {
|
||||
//System.out.println("* "+array[targetIndex]);
|
||||
//result=array[targetIndex];
|
||||
return;
|
||||
}
|
||||
|
||||
if (left < right) {
|
||||
double s = array[mid];
|
||||
int i = left - 1;
|
||||
int j = right + 1;
|
||||
|
||||
while (true) {
|
||||
while (array[++i] < s);
|
||||
while (array[--j] > s);
|
||||
if (i >= j)
|
||||
break;
|
||||
swap(array, i, j);
|
||||
}
|
||||
|
||||
//System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right);
|
||||
|
||||
if (i > index) {
|
||||
// the target index in the left partition
|
||||
//System.out.println("left partition");
|
||||
locate(array, left, i - 1, index);
|
||||
} else {
|
||||
// the target index in the right partition
|
||||
//System.out.println("right partition");
|
||||
locate(array, j + 1, right, index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void swap(double[] array, int i, int j) {
|
||||
double t = array[i];
|
||||
array[i] = array[j];
|
||||
array[j] = t;
|
||||
}
|
||||
}
|
|
@ -1,179 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.fingerprint;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.Map.Entry;
|
||||
|
||||
public class MapRankInteger implements MapRank {
|
||||
|
||||
private Map map;
|
||||
private boolean acsending = true;
|
||||
|
||||
public MapRankInteger(Map<?, Integer> map, boolean acsending) {
|
||||
this.map = map;
|
||||
this.acsending = acsending;
|
||||
}
|
||||
|
||||
public List getOrderedKeyList(int numKeys, boolean sharpLimit) { // if sharp limited, will return sharp numKeys, otherwise will return until the values not equals the exact key's value
|
||||
|
||||
Set mapEntrySet = map.entrySet();
|
||||
List keyList = new LinkedList();
|
||||
|
||||
// if the numKeys is larger than map size, limit it
|
||||
if (numKeys > map.size()) {
|
||||
numKeys = map.size();
|
||||
}
|
||||
// end if the numKeys is larger than map size, limit it
|
||||
|
||||
if (map.size() > 0) {
|
||||
int[] array = new int[map.size()];
|
||||
int count = 0;
|
||||
|
||||
// get the pass values
|
||||
Iterator<Entry> mapIterator = mapEntrySet.iterator();
|
||||
while (mapIterator.hasNext()) {
|
||||
Entry entry = mapIterator.next();
|
||||
array[count++] = (Integer) entry.getValue();
|
||||
}
|
||||
// end get the pass values
|
||||
|
||||
int targetindex;
|
||||
if (acsending) {
|
||||
targetindex = numKeys;
|
||||
} else {
|
||||
targetindex = array.length - numKeys;
|
||||
}
|
||||
|
||||
int passValue = getOrderedValue(array, targetindex); // this value is the value of the numKey-th element
|
||||
// get the passed keys and values
|
||||
Map passedMap = new HashMap();
|
||||
List<Integer> valueList = new LinkedList<Integer>();
|
||||
mapIterator = mapEntrySet.iterator();
|
||||
|
||||
while (mapIterator.hasNext()) {
|
||||
Entry entry = mapIterator.next();
|
||||
int value = (Integer) entry.getValue();
|
||||
if ((acsending && value <= passValue) || (!acsending && value >= passValue)) {
|
||||
passedMap.put(entry.getKey(), value);
|
||||
valueList.add(value);
|
||||
}
|
||||
}
|
||||
// end get the passed keys and values
|
||||
|
||||
// sort the value list
|
||||
Integer[] listArr = new Integer[valueList.size()];
|
||||
valueList.toArray(listArr);
|
||||
Arrays.sort(listArr);
|
||||
// end sort the value list
|
||||
|
||||
// get the list of keys
|
||||
int resultCount = 0;
|
||||
int index;
|
||||
if (acsending) {
|
||||
index = 0;
|
||||
} else {
|
||||
index = listArr.length - 1;
|
||||
}
|
||||
|
||||
if (!sharpLimit) {
|
||||
numKeys = listArr.length;
|
||||
}
|
||||
|
||||
while (true) {
|
||||
int targetValue = (Integer) listArr[index];
|
||||
Iterator<Entry> passedMapIterator = passedMap.entrySet().iterator();
|
||||
while (passedMapIterator.hasNext()) {
|
||||
Entry entry = passedMapIterator.next();
|
||||
if ((Integer) entry.getValue() == targetValue) {
|
||||
keyList.add(entry.getKey());
|
||||
passedMapIterator.remove();
|
||||
resultCount++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (acsending) {
|
||||
index++;
|
||||
} else {
|
||||
index--;
|
||||
}
|
||||
|
||||
if (resultCount >= numKeys) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// end get the list of keys
|
||||
}
|
||||
|
||||
return keyList;
|
||||
}
|
||||
|
||||
private int getOrderedValue(int[] array, int index) {
|
||||
locate(array, 0, array.length - 1, index);
|
||||
return array[index];
|
||||
}
|
||||
|
||||
// sort the partitions by quick sort, and locate the target index
|
||||
private void locate(int[] array, int left, int right, int index) {
|
||||
|
||||
int mid = (left + right) / 2;
|
||||
//System.out.println(left+" to "+right+" ("+mid+")");
|
||||
|
||||
if (right == left) {
|
||||
//System.out.println("* "+array[targetIndex]);
|
||||
//result=array[targetIndex];
|
||||
return;
|
||||
}
|
||||
|
||||
if (left < right) {
|
||||
int s = array[mid];
|
||||
int i = left - 1;
|
||||
int j = right + 1;
|
||||
|
||||
while (true) {
|
||||
while (array[++i] < s);
|
||||
while (array[--j] > s);
|
||||
if (i >= j)
|
||||
break;
|
||||
swap(array, i, j);
|
||||
}
|
||||
|
||||
//System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right);
|
||||
|
||||
if (i > index) {
|
||||
// the target index in the left partition
|
||||
//System.out.println("left partition");
|
||||
locate(array, left, i - 1, index);
|
||||
} else {
|
||||
// the target index in the right partition
|
||||
//System.out.println("right partition");
|
||||
locate(array, j + 1, right, index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void swap(int[] array, int i, int j) {
|
||||
int t = array[i];
|
||||
array[i] = array[j];
|
||||
array[j] = t;
|
||||
}
|
||||
}
|
|
@ -1,230 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.fingerprint;
|
||||
|
||||
|
||||
|
||||
import org.datavec.audio.properties.FingerprintProperties;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
public class PairManager {

    // Shared fingerprint tuning parameters (singleton).
    FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
    private int numFilterBanks = fingerprintProperties.getNumFilterBanks();
    // Number of frequency units covered by one filter bank.
    private int bandwidthPerBank = fingerprintProperties.getNumFrequencyUnits() / numFilterBanks;
    private int anchorPointsIntervalLength = fingerprintProperties.getAnchorPointsIntervalLength();
    private int numAnchorPointsPerInterval = fingerprintProperties.getNumAnchorPointsPerInterval();
    private int maxTargetZoneDistance = fingerprintProperties.getMaxTargetZoneDistance();
    private int numFrequencyUnits = fingerprintProperties.getNumFrequencyUnits();

    // Maximum number of pairs generated per anchor point.
    private int maxPairs;
    // Reference pairing caps pairs per interval; sample pairing applies the stop list instead.
    private boolean isReferencePairing;
    // Pair hashcodes to skip during sample pairing (stop list).
    // NOTE(review): never populated in this class — confirm where entries are meant to come from.
    private HashMap<Integer, Boolean> stopPairTable = new HashMap<>();

    /**
     * Constructor: defaults to reference pairing, using the reference max-active-pairs limit.
     */
    public PairManager() {
        maxPairs = fingerprintProperties.getRefMaxActivePairs();
        isReferencePairing = true;
    }

    /**
     * Constructor, number of pairs of robust points depends on the parameter isReferencePairing
     * no. of pairs of reference and sample can be different due to environmental influence of source
     * @param isReferencePairing true for reference pairing, false for sample pairing
     */
    public PairManager(boolean isReferencePairing) {
        if (isReferencePairing) {
            maxPairs = fingerprintProperties.getRefMaxActivePairs();
        } else {
            maxPairs = fingerprintProperties.getSampleMaxActivePairs();
        }
        this.isReferencePairing = isReferencePairing;
    }

    /**
     * Get a pair-positionList table
     * It's a hash map which the key is the hashed pair, and the value is list of positions
     * That means the table stores the positions which have the same hashed pair
     *
     * @param fingerprint fingerprint bytes
     * @return pair-positionList HashMap
     */
    public HashMap<Integer, List<Integer>> getPair_PositionList_Table(byte[] fingerprint) {

        List<int[]> pairPositionList = getPairPositionList(fingerprint);

        // table to store pair:pos,pos,pos,...;pair2:pos,pos,pos,....
        HashMap<Integer, List<Integer>> pair_positionList_table = new HashMap<>();

        // group every pair_position by its pair hashcode, i.e.: <pair,List<position>>
        for (int[] pair_position : pairPositionList) {
            //System.out.println(pair_position[0]+","+pair_position[1]);

            if (pair_positionList_table.containsKey(pair_position[0])) {
                // hashcode already seen: append this position to its list
                pair_positionList_table.get(pair_position[0]).add(pair_position[1]);
            } else {
                // first occurrence of this hashcode: start a new position list
                List<Integer> positionList = new LinkedList<>();
                positionList.add(pair_position[1]);
                pair_positionList_table.put(pair_position[0], positionList);
            }
        }
        // end group by pair-hashcode

        return pair_positionList_table;
    }

    // Build the raw pair list: each element is int[0]=pair_hashcode, int[1]=anchor position (frame x)
    private List<int[]> getPairPositionList(byte[] fingerprint) {

        int numFrames = FingerprintManager.getNumFrames(fingerprint);

        // counts how many pairs were already anchored in each time interval
        byte[] pairedFrameTable = new byte[numFrames / anchorPointsIntervalLength + 1]; // each interval allows numAnchorPointsPerInterval pairs only
        // end table for paired frames

        List<int[]> pairList = new LinkedList<>();
        // coordinates ordered by descending intensity, so the strongest points anchor first
        List<int[]> sortedCoordinateList = getSortedCoordinateList(fingerprint);

        for (int[] anchorPoint : sortedCoordinateList) {
            int anchorX = anchorPoint[0];
            int anchorY = anchorPoint[1];
            int numPairs = 0;

            for (int[] aSortedCoordinateList : sortedCoordinateList) {

                // per-anchor cap on generated pairs
                if (numPairs >= maxPairs) {
                    break;
                }

                // reference pairing only: stop once this anchor's interval has enough pairs
                if (isReferencePairing && pairedFrameTable[anchorX
                                / anchorPointsIntervalLength] >= numAnchorPointsPerInterval) {
                    break;
                }

                int targetX = aSortedCoordinateList[0];
                int targetY = aSortedCoordinateList[1];

                // a point cannot pair with itself
                if (anchorX == targetX && anchorY == targetY) {
                    continue;
                }

                // pair up the points
                int x1, y1, x2, y2; // x2 always >= x1
                if (targetX >= anchorX) {
                    x2 = targetX;
                    y2 = targetY;
                    x1 = anchorX;
                    y1 = anchorY;
                } else {
                    x2 = anchorX;
                    y2 = anchorY;
                    x1 = targetX;
                    y1 = targetY;
                }

                // check target zone: the two points must be close enough in time
                if ((x2 - x1) > maxTargetZoneDistance) {
                    continue;
                }
                // end check target zone

                // check filter bank zone: both points must fall within the same filter bank
                if (!(y1 / bandwidthPerBank == y2 / bandwidthPerBank)) {
                    continue; // same filter bank should have equal value
                }
                // end check filter bank zone

                // hash combining the time delta and both frequencies into one int
                int pairHashcode = (x2 - x1) * numFrequencyUnits * numFrequencyUnits + y2 * numFrequencyUnits + y1;

                // stop list applied on sample pairing only
                if (!isReferencePairing && stopPairTable.containsKey(pairHashcode)) {
                    numPairs++; // no reservation
                    continue; // escape this point only
                }
                // end stop list applied on sample pairing only

                // pass all rules
                pairList.add(new int[] {pairHashcode, anchorX});
                pairedFrameTable[anchorX / anchorPointsIntervalLength]++;
                numPairs++;
                // end pair up the points
            }
        }

        return pairList;
    }

    // Decode the fingerprint's point records and return their (x, y) coordinates
    // ordered by descending intensity.
    private List<int[]> getSortedCoordinateList(byte[] fingerprint) {
        // each point data is 8 bytes
        // first 2 bytes is x
        // next 2 bytes is y
        // next 4 bytes is intensity

        // read all big-endian 32-bit intensities
        int numCoordinates = fingerprint.length / 8;
        int[] intensities = new int[numCoordinates];
        for (int i = 0; i < numCoordinates; i++) {
            int pointer = i * 8 + 4;
            int intensity = (fingerprint[pointer] & 0xff) << 24 | (fingerprint[pointer + 1] & 0xff) << 16
                            | (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff);
            intensities[i] = intensity;
        }

        // index-preserving sort (ascending by intensity)
        QuickSortIndexPreserved quicksort = new QuickSortIndexPreserved(intensities);
        int[] sortIndexes = quicksort.getSortIndexes();

        // walk the sort indexes backwards to emit coordinates strongest-first
        List<int[]> sortedCoordinateList = new LinkedList<>();
        for (int i = sortIndexes.length - 1; i >= 0; i--) {
            int pointer = sortIndexes[i] * 8;
            int x = (fingerprint[pointer] & 0xff) << 8 | (fingerprint[pointer + 1] & 0xff);
            int y = (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff);
            sortedCoordinateList.add(new int[] {x, y});
        }
        return sortedCoordinateList;
    }

    /**
     * Convert hashed pair to bytes.
     * NOTE(review): only the low 16 bits survive, but the pairHashcode formula above
     * can exceed 16 bits — confirm this truncation is intended.
     *
     * @param pairHashcode hashed pair
     * @return byte array (2 bytes, big-endian)
     */
    public static byte[] pairHashcodeToBytes(int pairHashcode) {
        return new byte[] {(byte) (pairHashcode >> 8), (byte) pairHashcode};
    }

    /**
     * Convert bytes to hashed pair (inverse of pairHashcodeToBytes for 16-bit values).
     *
     * @param pairBytes two bytes, big-endian
     * @return hashed pair
     */
    public static int pairBytesToHashcode(byte[] pairBytes) {
        return (pairBytes[0] & 0xFF) << 8 | (pairBytes[1] & 0xFF);
    }
}
|
|
@ -1,25 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.fingerprint;
|
||||
|
||||
public abstract class QuickSort {
    /**
     * Sort the backing array's index permutation and return it: element r of the
     * result is the original index of the r-th smallest value (ascending order).
     * The backing array itself is not reordered.
     *
     * @return indexes of the backing array's elements in ascending sorted order
     */
    public abstract int[] getSortIndexes();
}
|
|
@ -1,79 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.fingerprint;
|
||||
|
||||
public class QuickSortDouble extends QuickSort {
|
||||
|
||||
private int[] indexes;
|
||||
private double[] array;
|
||||
|
||||
public QuickSortDouble(double[] array) {
|
||||
this.array = array;
|
||||
indexes = new int[array.length];
|
||||
for (int i = 0; i < indexes.length; i++) {
|
||||
indexes[i] = i;
|
||||
}
|
||||
}
|
||||
|
||||
public int[] getSortIndexes() {
|
||||
sort();
|
||||
return indexes;
|
||||
}
|
||||
|
||||
private void sort() {
|
||||
quicksort(array, indexes, 0, indexes.length - 1);
|
||||
}
|
||||
|
||||
// quicksort a[left] to a[right]
|
||||
private void quicksort(double[] a, int[] indexes, int left, int right) {
|
||||
if (right <= left)
|
||||
return;
|
||||
int i = partition(a, indexes, left, right);
|
||||
quicksort(a, indexes, left, i - 1);
|
||||
quicksort(a, indexes, i + 1, right);
|
||||
}
|
||||
|
||||
// partition a[left] to a[right], assumes left < right
|
||||
private int partition(double[] a, int[] indexes, int left, int right) {
|
||||
int i = left - 1;
|
||||
int j = right;
|
||||
while (true) {
|
||||
while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
|
||||
while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
|
||||
if (j == left)
|
||||
break; // don't go out-of-bounds
|
||||
}
|
||||
if (i >= j)
|
||||
break; // check if pointers cross
|
||||
swap(a, indexes, i, j); // swap two elements into place
|
||||
}
|
||||
swap(a, indexes, i, right); // swap with partition element
|
||||
return i;
|
||||
}
|
||||
|
||||
// exchange a[i] and a[j]
|
||||
private void swap(double[] a, int[] indexes, int i, int j) {
|
||||
int swap = indexes[i];
|
||||
indexes[i] = indexes[j];
|
||||
indexes[j] = swap;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.fingerprint;
|
||||
|
||||
public class QuickSortIndexPreserved {
|
||||
|
||||
private QuickSort quickSort;
|
||||
|
||||
public QuickSortIndexPreserved(int[] array) {
|
||||
quickSort = new QuickSortInteger(array);
|
||||
}
|
||||
|
||||
public QuickSortIndexPreserved(double[] array) {
|
||||
quickSort = new QuickSortDouble(array);
|
||||
}
|
||||
|
||||
public QuickSortIndexPreserved(short[] array) {
|
||||
quickSort = new QuickSortShort(array);
|
||||
}
|
||||
|
||||
public int[] getSortIndexes() {
|
||||
return quickSort.getSortIndexes();
|
||||
}
|
||||
|
||||
}
|
|
@ -1,79 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.fingerprint;
|
||||
|
||||
public class QuickSortInteger extends QuickSort {
|
||||
|
||||
private int[] indexes;
|
||||
private int[] array;
|
||||
|
||||
public QuickSortInteger(int[] array) {
|
||||
this.array = array;
|
||||
indexes = new int[array.length];
|
||||
for (int i = 0; i < indexes.length; i++) {
|
||||
indexes[i] = i;
|
||||
}
|
||||
}
|
||||
|
||||
public int[] getSortIndexes() {
|
||||
sort();
|
||||
return indexes;
|
||||
}
|
||||
|
||||
private void sort() {
|
||||
quicksort(array, indexes, 0, indexes.length - 1);
|
||||
}
|
||||
|
||||
// quicksort a[left] to a[right]
|
||||
private void quicksort(int[] a, int[] indexes, int left, int right) {
|
||||
if (right <= left)
|
||||
return;
|
||||
int i = partition(a, indexes, left, right);
|
||||
quicksort(a, indexes, left, i - 1);
|
||||
quicksort(a, indexes, i + 1, right);
|
||||
}
|
||||
|
||||
// partition a[left] to a[right], assumes left < right
|
||||
private int partition(int[] a, int[] indexes, int left, int right) {
|
||||
int i = left - 1;
|
||||
int j = right;
|
||||
while (true) {
|
||||
while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
|
||||
while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
|
||||
if (j == left)
|
||||
break; // don't go out-of-bounds
|
||||
}
|
||||
if (i >= j)
|
||||
break; // check if pointers cross
|
||||
swap(a, indexes, i, j); // swap two elements into place
|
||||
}
|
||||
swap(a, indexes, i, right); // swap with partition element
|
||||
return i;
|
||||
}
|
||||
|
||||
// exchange a[i] and a[j]
|
||||
private void swap(int[] a, int[] indexes, int i, int j) {
|
||||
int swap = indexes[i];
|
||||
indexes[i] = indexes[j];
|
||||
indexes[j] = swap;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,79 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.fingerprint;
|
||||
|
||||
public class QuickSortShort extends QuickSort {

    // Positions into 'array'; permuted into ascending value order by sort().
    private int[] indexes;
    // The values being ranked; read through 'indexes', never modified.
    private short[] array;

    /**
     * Prepares an indirect (index-based) quicksort over the given array.
     * The array is not copied and is never modified by this class.
     *
     * @param array values to rank
     */
    public QuickSortShort(short[] array) {
        this.array = array;
        indexes = new int[array.length];
        for (int i = 0; i < indexes.length; i++) {
            indexes[i] = i;
        }
    }

    /**
     * Sorts and returns the index permutation: element k is the position in the
     * original array of the k-th smallest value (ascending order).
     *
     * @return the internal index array (not a copy)
     */
    public int[] getSortIndexes() {
        sort();
        return indexes;
    }

    private void sort() {
        quicksort(array, indexes, 0, indexes.length - 1);
    }

    // quicksort a[left] to a[right]
    private void quicksort(short[] a, int[] indexes, int left, int right) {
        if (right <= left)
            return;
        int i = partition(a, indexes, left, right);
        quicksort(a, indexes, left, i - 1);
        quicksort(a, indexes, i + 1, right);
    }

    // partition a[left] to a[right], assumes left < right; permutes only
    // 'indexes' and returns the pivot's final position
    private int partition(short[] a, int[] indexes, int left, int right) {
        int i = left - 1;
        int j = right;
        while (true) {
            while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
            while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
                if (j == left)
                    break; // don't go out-of-bounds
            }
            if (i >= j)
                break; // check if pointers cross
            swap(a, indexes, i, j); // swap two elements into place
        }
        swap(a, indexes, i, right); // swap with partition element
        return i;
    }

    // exchange indexes[i] and indexes[j] ('a' carried for symmetry with partition)
    private void swap(short[] a, int[] indexes, int i, int j) {
        int swap = indexes[i];
        indexes[i] = indexes[j];
        indexes[j] = swap;
    }
}
|
|
@ -1,51 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.formats.input;
|
||||
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.formats.input.BaseInputFormat;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.split.InputSplit;
|
||||
import org.datavec.audio.recordreader.WavFileRecordReader;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
*
|
||||
* Wave file input format
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class WavInputFormat extends BaseInputFormat {
|
||||
@Override
|
||||
public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
|
||||
return createReader(split);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RecordReader createReader(InputSplit split) throws IOException, InterruptedException {
|
||||
RecordReader waveRecordReader = new WavFileRecordReader();
|
||||
waveRecordReader.initialize(split);
|
||||
return waveRecordReader;
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.formats.output;
|
||||
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.exceptions.DataVecException;
|
||||
import org.datavec.api.formats.output.OutputFormat;
|
||||
import org.datavec.api.records.writer.RecordWriter;
|
||||
|
||||
/**
 * Output format for wave audio.
 *
 * <p>NOTE(review): writing is not implemented — {@link #createWriter(Configuration)}
 * always returns {@code null}, so callers must null-check the result.
 *
 * @author Adam Gibson
 */
public class WaveOutputFormat implements OutputFormat {
    @Override
    public RecordWriter createWriter(Configuration conf) throws DataVecException {
        // No wave record writer exists yet.
        return null;
    }
}
|
|
@ -1,139 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.processor;
|
||||
|
||||
public class ArrayRankDouble {

    /**
     * Get the index position of maximum value in the given array.
     * Ties resolve to the first occurrence; an empty array yields 0.
     *
     * @param array an array
     * @return index of the max value in array
     */
    public int getMaxValueIndex(double[] array) {

        int index = 0;
        // Bug fix: previously seeded with Integer.MIN_VALUE, which is *greater*
        // than doubles below -2^31 and produced a wrong index for such inputs.
        double max = Double.NEGATIVE_INFINITY;

        for (int i = 0; i < array.length; i++) {
            if (array[i] > max) {
                max = array[i];
                index = i;
            }
        }

        return index;
    }

    /**
     * Get the index position of minimum value in the given array.
     * Ties resolve to the first occurrence; an empty array yields 0.
     *
     * @param array an array
     * @return index of the min value in array
     */
    public int getMinValueIndex(double[] array) {

        int index = 0;
        // Bug fix: previously seeded with Integer.MAX_VALUE (mirror of the
        // getMaxValueIndex fix).
        double min = Double.POSITIVE_INFINITY;

        for (int i = 0; i < array.length; i++) {
            if (array[i] < min) {
                min = array[i];
                index = i;
            }
        }

        return index;
    }

    /**
     * Get the n-th value in the array after sorting.
     *
     * <p>NOTE: the array is partially sorted in place as a side effect — pass a
     * copy if the original order matters. For descending order this returns the
     * n-th largest value. For ascending order the historical selection of
     * 0-based index {@code n} is preserved for compatibility.
     *
     * @param array     an array (modified in place)
     * @param n         position in array
     * @param ascending is ascending order or not
     * @return value at nth position of array
     */
    public double getNthOrderedValue(double[] array, int n, boolean ascending) {

        if (n > array.length) {
            n = array.length;
        }

        int targetindex;
        if (ascending) {
            targetindex = n;
        } else {
            targetindex = array.length - n;
        }

        // Bug fix: clamp the target into bounds; previously n == array.length
        // (ascending) or n == 0 (descending) indexed one past the end.
        if (targetindex >= array.length) {
            targetindex = array.length - 1;
        }
        if (targetindex < 0) {
            targetindex = 0;
        }

        return getOrderedValue(array, targetindex);
    }

    // Selects array[index] as if the array were fully sorted (quickselect).
    private double getOrderedValue(double[] array, int index) {
        locate(array, 0, array.length - 1, index);
        return array[index];
    }

    // Quickselect partitioning: recursively partitions only the side containing
    // 'index' until array[index] holds its sorted-order value.
    private void locate(double[] array, int left, int right, int index) {

        int mid = left + (right - left) / 2; // overflow-safe midpoint

        if (right == left) {
            return;
        }

        if (left < right) {
            double s = array[mid]; // pivot value
            int i = left - 1;
            int j = right + 1;

            while (true) {
                while (array[++i] < s);
                while (array[--j] > s);
                if (i >= j)
                    break;
                swap(array, i, j);
            }

            if (i > index) {
                // the target index lies in the left partition
                locate(array, left, i - 1, index);
            } else {
                // the target index lies in the right partition
                locate(array, j + 1, right, index);
            }
        }
    }

    // exchange array[i] and array[j]
    private void swap(double[] array, int i, int j) {
        double t = array[i];
        array[i] = array[j];
        array[j] = t;
    }
}
|
|
@ -1,28 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.processor;
|
||||
|
||||
/**
 * A processing stage over a 2-D intensity matrix. Implementations transform
 * their data when {@link #execute()} is invoked and expose the result via
 * {@link #getIntensities()}.
 */
public interface IntensityProcessor {

    /** Runs this processing step. */
    void execute();

    /** @return the (possibly transformed) intensity matrix */
    double[][] getIntensities();
}
|
|
@ -1,48 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.processor;
|
||||
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
public class ProcessorChain {
|
||||
|
||||
private double[][] intensities;
|
||||
List<IntensityProcessor> processorList = new LinkedList<IntensityProcessor>();
|
||||
|
||||
public ProcessorChain(double[][] intensities) {
|
||||
this.intensities = intensities;
|
||||
RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, 1);
|
||||
processorList.add(robustProcessor);
|
||||
process();
|
||||
}
|
||||
|
||||
private void process() {
|
||||
for (IntensityProcessor processor : processorList) {
|
||||
processor.execute();
|
||||
intensities = processor.getIntensities();
|
||||
}
|
||||
}
|
||||
|
||||
public double[][] getIntensities() {
|
||||
return intensities;
|
||||
}
|
||||
}
|
|
@ -1,61 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.processor;
|
||||
|
||||
|
||||
public class RobustIntensityProcessor implements IntensityProcessor {
|
||||
|
||||
private double[][] intensities;
|
||||
private int numPointsPerFrame;
|
||||
|
||||
public RobustIntensityProcessor(double[][] intensities, int numPointsPerFrame) {
|
||||
this.intensities = intensities;
|
||||
this.numPointsPerFrame = numPointsPerFrame;
|
||||
}
|
||||
|
||||
public void execute() {
|
||||
|
||||
int numX = intensities.length;
|
||||
int numY = intensities[0].length;
|
||||
double[][] processedIntensities = new double[numX][numY];
|
||||
|
||||
for (int i = 0; i < numX; i++) {
|
||||
double[] tmpArray = new double[numY];
|
||||
System.arraycopy(intensities[i], 0, tmpArray, 0, numY);
|
||||
|
||||
// pass value is the last some elements in sorted array
|
||||
ArrayRankDouble arrayRankDouble = new ArrayRankDouble();
|
||||
double passValue = arrayRankDouble.getNthOrderedValue(tmpArray, numPointsPerFrame, false);
|
||||
|
||||
// only passed elements will be assigned a value
|
||||
for (int j = 0; j < numY; j++) {
|
||||
if (intensities[i][j] >= passValue) {
|
||||
processedIntensities[i][j] = intensities[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
intensities = processedIntensities;
|
||||
}
|
||||
|
||||
public double[][] getIntensities() {
|
||||
return intensities;
|
||||
}
|
||||
}
|
|
@ -1,49 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.processor;
|
||||
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
public class TopManyPointsProcessorChain {
|
||||
|
||||
private double[][] intensities;
|
||||
List<IntensityProcessor> processorList = new LinkedList<>();
|
||||
|
||||
public TopManyPointsProcessorChain(double[][] intensities, int numPoints) {
|
||||
this.intensities = intensities;
|
||||
RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, numPoints);
|
||||
processorList.add(robustProcessor);
|
||||
process();
|
||||
}
|
||||
|
||||
private void process() {
|
||||
for (IntensityProcessor processor : processorList) {
|
||||
processor.execute();
|
||||
intensities = processor.getIntensities();
|
||||
}
|
||||
}
|
||||
|
||||
public double[][] getIntensities() {
|
||||
return intensities;
|
||||
}
|
||||
}
|
|
@ -1,121 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.properties;
|
||||
|
||||
public class FingerprintProperties {

    // Singleton instance. Bug fix: declared volatile so the double-checked
    // locking in getInstance() is safe — without it, another thread could
    // observe a partially constructed instance under the Java memory model.
    protected static volatile FingerprintProperties instance = null;

    private int numRobustPointsPerFrame = 4; // number of points in each frame, i.e. top 4 intensities in fingerprint
    private int sampleSizePerFrame = 2048; // number of audio samples in a frame, it is suggested to be the FFT Size
    private int overlapFactor = 4; // 8 means each move 1/8 nSample length. 1 means no overlap, better 1,2,4,8 ... 32
    private int numFilterBanks = 4;

    private int upperBoundedFrequency = 1500; // low pass
    private int lowerBoundedFrequency = 400; // high pass
    private int fps = 5; // in order to have 5fps with 2048 sampleSizePerFrame, wave's sample rate need to be 10240 (sampleSizePerFrame*fps)
    private int sampleRate = sampleSizePerFrame * fps; // the audio's sample rate needed to resample to this in order to fit the sampleSizePerFrame and fps
    private int numFramesInOneSecond = overlapFactor * fps; // since the overlap factor affects the actual number of fps, so this value is used to evaluate how many frames in one second eventually

    private int refMaxActivePairs = 1; // max. active pairs per anchor point for reference songs
    private int sampleMaxActivePairs = 10; // max. active pairs per anchor point for sample clip
    private int numAnchorPointsPerInterval = 10;
    private int anchorPointsIntervalLength = 4; // in frames (5fps,4 overlap per second)
    private int maxTargetZoneDistance = 4; // in frame (5fps,4 overlap per second)

    private int numFrequencyUnits = (upperBoundedFrequency - lowerBoundedFrequency + 1) / fps + 1; // num frequency units

    /**
     * Returns the lazily created shared instance (thread-safe double-checked
     * locking over the volatile {@code instance} field).
     *
     * @return the singleton FingerprintProperties
     */
    public static FingerprintProperties getInstance() {
        if (instance == null) {
            synchronized (FingerprintProperties.class) {
                if (instance == null) {
                    instance = new FingerprintProperties();
                }
            }
        }
        return instance;
    }

    public int getNumRobustPointsPerFrame() {
        return numRobustPointsPerFrame;
    }

    public int getSampleSizePerFrame() {
        return sampleSizePerFrame;
    }

    public int getOverlapFactor() {
        return overlapFactor;
    }

    public int getNumFilterBanks() {
        return numFilterBanks;
    }

    public int getUpperBoundedFrequency() {
        return upperBoundedFrequency;
    }

    public int getLowerBoundedFrequency() {
        return lowerBoundedFrequency;
    }

    public int getFps() {
        return fps;
    }

    public int getRefMaxActivePairs() {
        return refMaxActivePairs;
    }

    public int getSampleMaxActivePairs() {
        return sampleMaxActivePairs;
    }

    public int getNumAnchorPointsPerInterval() {
        return numAnchorPointsPerInterval;
    }

    public int getAnchorPointsIntervalLength() {
        return anchorPointsIntervalLength;
    }

    public int getMaxTargetZoneDistance() {
        return maxTargetZoneDistance;
    }

    public int getNumFrequencyUnits() {
        return numFrequencyUnits;
    }

    /** @return upper bound (exclusive-ish) on the pair hash codes built from target-zone distance and two frequency units */
    public int getMaxPossiblePairHashcode() {
        return maxTargetZoneDistance * numFrequencyUnits * numFrequencyUnits + numFrequencyUnits * numFrequencyUnits
                        + numFrequencyUnits;
    }

    public int getSampleRate() {
        return sampleRate;
    }

    public int getNumFramesInOneSecond() {
        return numFramesInOneSecond;
    }
}
|
|
@ -1,225 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.recordreader;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.records.Record;
|
||||
import org.datavec.api.records.metadata.RecordMetaData;
|
||||
import org.datavec.api.records.reader.BaseRecordReader;
|
||||
import org.datavec.api.split.BaseInputSplit;
|
||||
import org.datavec.api.split.InputSplit;
|
||||
import org.datavec.api.split.InputStreamInputSplit;
|
||||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import java.io.DataInputStream;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.URI;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
|
||||
/**
 * Base audio file loader. Resolves the files (or stream) of an {@link InputSplit}
 * during {@link #initialize(InputSplit)} and produces one record per file via the
 * subclass-supplied {@link #loadData(File, InputStream)}.
 *
 * @author Adam Gibson
 */
public abstract class BaseAudioRecordReader extends BaseRecordReader {
    // Iterator over the files of a file-based split; null for stream splits.
    private Iterator<File> iter;
    // Single pre-built record used for InputStreamInputSplit; served once by next().
    private List<Writable> record;
    // Set once the stream-based record has been handed out. (Despite the name,
    // nothing image-related happens here.)
    private boolean hitImage = false;
    // When true, a label value (parent-directory name looked up in 'labels')
    // is appended to the stream-based record.
    private boolean appendLabel = false;
    // Known label names; a record's label value is an index into this list.
    private List<String> labels = new ArrayList<>();
    private Configuration conf;
    // Remembered so reset() can re-run initialize().
    protected InputSplit inputSplit;

    public BaseAudioRecordReader() {}

    public BaseAudioRecordReader(boolean appendLabel, List<String> labels) {
        this.appendLabel = appendLabel;
        this.labels = labels;
    }

    public BaseAudioRecordReader(List<String> labels) {
        this.labels = labels;
    }

    public BaseAudioRecordReader(boolean appendLabel) {
        this.appendLabel = appendLabel;
    }

    /**
     * Loads one audio source into a list of writables.
     *
     * @param file        file to read from (may be null when a stream is given)
     * @param inputStream stream to read from (may be null when a file is given)
     * @return the decoded record
     * @throws IOException on read/decode failure
     */
    protected abstract List<Writable> loadData(File file, InputStream inputStream) throws IOException;

    /**
     * Resolves the split into either a file iterator (for {@link BaseInputSplit})
     * or a single pre-built record (for {@link InputStreamInputSplit}). For any
     * other split type this only remembers the split for reset().
     */
    @Override
    public void initialize(InputSplit split) throws IOException, InterruptedException {
        inputSplit = split;
        if (split instanceof BaseInputSplit) {
            URI[] locations = split.locations();
            if (locations != null && locations.length >= 1) {
                if (locations.length > 1) {
                    // Multiple locations: expand directories recursively and collect
                    // everything into one flat list up front.
                    List<File> allFiles = new ArrayList<>();
                    for (URI location : locations) {
                        // NOTE(review): this local 'iter' shadows the iterator field
                        // of the same name; it is just the current location's File.
                        File iter = new File(location);
                        if (iter.isDirectory()) {
                            Iterator<File> allFiles2 = FileUtils.iterateFiles(iter, null, true);
                            while (allFiles2.hasNext())
                                allFiles.add(allFiles2.next());
                        }

                        else
                            allFiles.add(iter);
                    }

                    iter = allFiles.iterator();
                } else {
                    // Single location: iterate lazily if it is a directory.
                    File curr = new File(locations[0]);
                    if (curr.isDirectory())
                        iter = FileUtils.iterateFiles(curr, null, true);
                    else
                        iter = Collections.singletonList(curr).iterator();
                }
            }
        }

        else if (split instanceof InputStreamInputSplit) {
            // Stream split: build the single record up front.
            record = new ArrayList<>();
            InputStreamInputSplit split2 = (InputStreamInputSplit) split;
            InputStream is = split2.getIs();
            URI[] locations = split2.locations();
            if (appendLabel) {
                // Label = index of the parent directory's name within 'labels'.
                Path path = Paths.get(locations[0]);
                String parent = path.getParent().toString();
                record.add(new DoubleWritable(labels.indexOf(parent)));
            }

            // NOTE(review): the stream is closed without any audio being decoded,
            // so the stream-based record only ever contains the label (if any) —
            // confirm this is intended.
            is.close();
        }

    }

    @Override
    public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
        this.conf = conf;
        // Configuration overrides whatever the constructors set.
        this.appendLabel = conf.getBoolean(APPEND_LABEL, false);
        this.labels = new ArrayList<>(conf.getStringCollection(LABELS));
        initialize(split);
    }

    /**
     * Returns the next record: for file splits, loads the next file through
     * loadData; for stream splits, returns the pre-built record exactly once.
     *
     * @throws IllegalStateException if initialize() produced neither an
     *                               iterator nor a record
     */
    @Override
    public List<Writable> next() {
        if (iter != null) {
            File next = iter.next();
            invokeListeners(next);
            try {
                return loadData(next, null);
            } catch (Exception e) {
                // Wrap checked exceptions from loadData; cause is preserved.
                throw new RuntimeException(e);
            }
        } else if (record != null) {
            hitImage = true;
            return record;
        }

        throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
    }

    @Override
    public boolean hasNext() {
        if (iter != null) {
            return iter.hasNext();
        } else if (record != null) {
            // The single stream-based record is available until it has been served.
            return !hitImage;
        }
        throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
    }


    @Override
    public void close() throws IOException {
        // No long-lived resources are held by this base class.
    }

    @Override
    public void setConf(Configuration conf) {
        this.conf = conf;
    }

    @Override
    public Configuration getConf() {
        return conf;
    }

    // NOTE(review): always returns null even when 'labels' is populated —
    // confirm whether this should expose the configured labels instead.
    @Override
    public List<String> getLabels() {
        return null;
    }


    /** Restarts iteration by re-running initialize() on the remembered split. */
    @Override
    public void reset() {
        if (inputSplit == null)
            throw new UnsupportedOperationException("Cannot reset without first initializing");
        try {
            initialize(inputSplit);
        } catch (Exception e) {
            throw new RuntimeException("Error during LineRecordReader reset", e);
        }
    }

    @Override
    public boolean resetSupported(){
        if(inputSplit == null){
            return false;
        }
        return inputSplit.resetSupported();
    }

    /** Loads a record directly from an open stream; the URI is only passed to listeners. */
    @Override
    public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
        invokeListeners(uri);
        try {
            return loadData(null, dataInputStream);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public Record nextRecord() {
        // No record metadata is attached by this reader.
        return new org.datavec.api.records.impl.Record(next(), null);
    }

    @Override
    public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        throw new UnsupportedOperationException("Loading from metadata not yet implemented");
    }

    @Override
    public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
        throw new UnsupportedOperationException("Loading from metadata not yet implemented");
    }
}
|
|
@ -1,76 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.recordreader;
|
||||
|
||||
import org.bytedeco.javacv.FFmpegFrameGrabber;
|
||||
import org.bytedeco.javacv.Frame;
|
||||
import org.datavec.api.writable.FloatWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.bytedeco.ffmpeg.global.avutil.AV_SAMPLE_FMT_FLT;
|
||||
|
||||
/**
 * Native audio file loader using FFmpeg (via JavaCV's FFmpegFrameGrabber).
 *
 * @author saudet
 */
public class NativeAudioRecordReader extends BaseAudioRecordReader {

    public NativeAudioRecordReader() {}

    public NativeAudioRecordReader(boolean appendLabel, List<String> labels) {
        super(appendLabel, labels);
    }

    public NativeAudioRecordReader(List<String> labels) {
        super(labels);
    }

    public NativeAudioRecordReader(boolean appendLabel) {
        super(appendLabel);
    }

    /**
     * Decodes the given file (or stream) and returns every audio sample as a
     * {@link FloatWritable}. Samples are requested in float format
     * (AV_SAMPLE_FMT_FLT). For multi-channel audio, each pass of the inner loop
     * appends one value from every channel buffer in turn, so channels end up
     * interleaved in the returned list.
     *
     * @param file        audio file to read (used when inputStream is null)
     * @param inputStream stream to decode; takes precedence over file when non-null
     * @return all decoded samples
     * @throws IOException if FFmpeg fails to open or decode the input
     */
    protected List<Writable> loadData(File file, InputStream inputStream) throws IOException {
        List<Writable> ret = new ArrayList<>();
        // try-with-resources releases the grabber's native resources on exit.
        try (FFmpegFrameGrabber grabber = inputStream != null ? new FFmpegFrameGrabber(inputStream)
                        : new FFmpegFrameGrabber(file.getAbsolutePath())) {
            grabber.setSampleFormat(AV_SAMPLE_FMT_FLT);
            grabber.start();
            Frame frame;
            while ((frame = grabber.grab()) != null) {
                // Drain the frame: loop while the first channel buffer has samples
                // remaining, taking one sample per channel on each pass.
                while (frame.samples != null && frame.samples[0].hasRemaining()) {
                    for (int i = 0; i < frame.samples.length; i++) {
                        ret.add(new FloatWritable(((FloatBuffer) frame.samples[i]).get()));
                    }
                }
            }
        }
        return ret;
    }

}
|
|
@ -1,57 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio.recordreader;
|
||||
|
||||
import org.datavec.api.util.RecordUtils;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.audio.Wave;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Wav file loader
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class WavFileRecordReader extends BaseAudioRecordReader {
|
||||
|
||||
public WavFileRecordReader() {}
|
||||
|
||||
public WavFileRecordReader(boolean appendLabel, List<String> labels) {
|
||||
super(appendLabel, labels);
|
||||
}
|
||||
|
||||
public WavFileRecordReader(List<String> labels) {
|
||||
super(labels);
|
||||
}
|
||||
|
||||
public WavFileRecordReader(boolean appendLabel) {
|
||||
super(appendLabel);
|
||||
}
|
||||
|
||||
protected List<Writable> loadData(File file, InputStream inputStream) throws IOException {
|
||||
Wave wave = inputStream != null ? new Wave(inputStream) : new Wave(file.getAbsolutePath());
|
||||
return RecordUtils.toRecord(wave.getNormalizedAmplitudes());
|
||||
}
|
||||
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
package org.datavec.audio;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.common.tests.AbstractAssertTestsClass;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
@Slf4j
|
||||
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 60000;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Set<Class<?>> getExclusions() {
|
||||
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getPackageName() {
|
||||
return "org.datavec.audio";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Class<?> getBaseClass() {
|
||||
return BaseND4JTest.class;
|
||||
}
|
||||
}
|
|
@ -1,68 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio;
|
||||
|
||||
import org.bytedeco.javacv.FFmpegFrameRecorder;
|
||||
import org.bytedeco.javacv.Frame;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.audio.recordreader.NativeAudioRecordReader;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.ShortBuffer;
|
||||
import java.util.List;
|
||||
|
||||
import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_VORBIS;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* @author saudet
|
||||
*/
|
||||
public class AudioReaderTest extends BaseND4JTest {
|
||||
@Ignore
|
||||
@Test
|
||||
public void testNativeAudioReader() throws Exception {
|
||||
File tempFile = File.createTempFile("testNativeAudioReader", ".ogg");
|
||||
FFmpegFrameRecorder recorder = new FFmpegFrameRecorder(tempFile, 2);
|
||||
recorder.setAudioCodec(AV_CODEC_ID_VORBIS);
|
||||
recorder.setSampleRate(44100);
|
||||
recorder.start();
|
||||
Frame audioFrame = new Frame();
|
||||
ShortBuffer audioBuffer = ShortBuffer.allocate(64 * 1024);
|
||||
audioFrame.sampleRate = 44100;
|
||||
audioFrame.audioChannels = 2;
|
||||
audioFrame.samples = new ShortBuffer[] {audioBuffer};
|
||||
recorder.record(audioFrame);
|
||||
recorder.stop();
|
||||
recorder.release();
|
||||
|
||||
RecordReader reader = new NativeAudioRecordReader();
|
||||
reader.initialize(new FileSplit(tempFile));
|
||||
assertTrue(reader.hasNext());
|
||||
List<Writable> record = reader.next();
|
||||
assertEquals(audioBuffer.limit(), record.size());
|
||||
}
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.audio;
|
||||
|
||||
import org.datavec.audio.dsp.FastFourierTransform;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
public class TestFastFourierTransform extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testFastFourierTransformComplex() {
|
||||
FastFourierTransform fft = new FastFourierTransform();
|
||||
double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6};
|
||||
double[] frequencies = fft.getMagnitudes(amplitudes);
|
||||
|
||||
Assert.assertEquals(2, frequencies.length);
|
||||
Assert.assertArrayEquals(new double[] {21.335, 18.513}, frequencies, 0.005);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFastFourierTransformComplexLong() {
|
||||
FastFourierTransform fft = new FastFourierTransform();
|
||||
double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6};
|
||||
double[] frequencies = fft.getMagnitudes(amplitudes, true);
|
||||
|
||||
Assert.assertEquals(4, frequencies.length);
|
||||
Assert.assertArrayEquals(new double[] {21.335, 18.5132, 14.927, 7.527}, frequencies, 0.005);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFastFourierTransformReal() {
|
||||
FastFourierTransform fft = new FastFourierTransform();
|
||||
double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6};
|
||||
double[] frequencies = fft.getMagnitudes(amplitudes, false);
|
||||
|
||||
Assert.assertEquals(4, frequencies.length);
|
||||
Assert.assertArrayEquals(new double[] {28.8, 2.107, 14.927, 19.874}, frequencies, 0.005);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFastFourierTransformRealOddSize() {
|
||||
FastFourierTransform fft = new FastFourierTransform();
|
||||
double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5};
|
||||
double[] frequencies = fft.getMagnitudes(amplitudes, false);
|
||||
|
||||
Assert.assertEquals(3, frequencies.length);
|
||||
Assert.assertArrayEquals(new double[] {24.2, 3.861, 16.876}, frequencies, 0.005);
|
||||
}
|
||||
}
|
|
@ -1,71 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--
|
||||
~ /* ******************************************************************************
|
||||
~ *
|
||||
~ *
|
||||
~ * This program and the accompanying materials are made available under the
|
||||
~ * terms of the Apache License, Version 2.0 which is available at
|
||||
~ * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~ *
|
||||
~ * See the NOTICE file distributed with this work for additional
|
||||
~ * information regarding copyright ownership.
|
||||
~ * Unless required by applicable law or agreed to in writing, software
|
||||
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ * License for the specific language governing permissions and limitations
|
||||
~ * under the License.
|
||||
~ *
|
||||
~ * SPDX-License-Identifier: Apache-2.0
|
||||
~ ******************************************************************************/
|
||||
-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-data</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<artifactId>datavec-data-codec</artifactId>
|
||||
|
||||
<name>datavec-data-codec</name>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-api</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-data-image</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.jcodec</groupId>
|
||||
<artifactId>jcodec</artifactId>
|
||||
<version>0.1.5</version>
|
||||
</dependency>
|
||||
<!-- Do not depend on FFmpeg by default due to licensing concerns. -->
|
||||
<!--
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>ffmpeg-platform</artifactId>
|
||||
<version>${ffmpeg.version}-${javacpp-presets.version}</version>
|
||||
</dependency>
|
||||
-->
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -1,41 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.codec.format.input;
|
||||
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.formats.input.BaseInputFormat;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.split.InputSplit;
|
||||
import org.datavec.codec.reader.CodecRecordReader;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class CodecInputFormat extends BaseInputFormat {
|
||||
@Override
|
||||
public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
|
||||
RecordReader reader = new CodecRecordReader();
|
||||
reader.initialize(conf, split);
|
||||
return reader;
|
||||
}
|
||||
}
|
|
@ -1,144 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.codec.reader;
|
||||
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.records.SequenceRecord;
|
||||
import org.datavec.api.records.metadata.RecordMetaData;
|
||||
import org.datavec.api.records.metadata.RecordMetaDataURI;
|
||||
import org.datavec.api.records.reader.SequenceRecordReader;
|
||||
import org.datavec.api.records.reader.impl.FileRecordReader;
|
||||
import org.datavec.api.split.InputSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import java.io.DataInputStream;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class BaseCodecRecordReader extends FileRecordReader implements SequenceRecordReader {
|
||||
protected int startFrame = 0;
|
||||
protected int numFrames = -1;
|
||||
protected int totalFrames = -1;
|
||||
protected double framesPerSecond = -1;
|
||||
protected double videoLength = -1;
|
||||
protected int rows = 28, cols = 28;
|
||||
protected boolean ravel = false;
|
||||
|
||||
public final static String NAME_SPACE = "org.datavec.codec.reader";
|
||||
public final static String ROWS = NAME_SPACE + ".rows";
|
||||
public final static String COLUMNS = NAME_SPACE + ".columns";
|
||||
public final static String START_FRAME = NAME_SPACE + ".startframe";
|
||||
public final static String TOTAL_FRAMES = NAME_SPACE + ".frames";
|
||||
public final static String TIME_SLICE = NAME_SPACE + ".time";
|
||||
public final static String RAVEL = NAME_SPACE + ".ravel";
|
||||
public final static String VIDEO_DURATION = NAME_SPACE + ".duration";
|
||||
|
||||
|
||||
@Override
|
||||
public List<List<Writable>> sequenceRecord() {
|
||||
URI next = locationsIterator.next();
|
||||
|
||||
try (InputStream s = streamCreatorFn.apply(next)){
|
||||
return loadData(null, s);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<List<Writable>> sequenceRecord(URI uri, DataInputStream dataInputStream) throws IOException {
|
||||
return loadData(null, dataInputStream);
|
||||
}
|
||||
|
||||
protected abstract List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException;
|
||||
|
||||
|
||||
@Override
|
||||
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
|
||||
setConf(conf);
|
||||
initialize(split);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Writable> next() {
|
||||
throw new UnsupportedOperationException("next() not supported for CodecRecordReader (use: sequenceRecord)");
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
|
||||
throw new UnsupportedOperationException("record(URI,DataInputStream) not supported for CodecRecordReader");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setConf(Configuration conf) {
|
||||
super.setConf(conf);
|
||||
startFrame = conf.getInt(START_FRAME, 0);
|
||||
numFrames = conf.getInt(TOTAL_FRAMES, -1);
|
||||
rows = conf.getInt(ROWS, 28);
|
||||
cols = conf.getInt(COLUMNS, 28);
|
||||
framesPerSecond = conf.getFloat(TIME_SLICE, -1);
|
||||
videoLength = conf.getFloat(VIDEO_DURATION, -1);
|
||||
ravel = conf.getBoolean(RAVEL, false);
|
||||
totalFrames = conf.getInt(TOTAL_FRAMES, -1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Configuration getConf() {
|
||||
return super.getConf();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SequenceRecord nextSequence() {
|
||||
URI next = locationsIterator.next();
|
||||
|
||||
List<List<Writable>> list;
|
||||
try (InputStream s = streamCreatorFn.apply(next)){
|
||||
list = loadData(null, s);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return new org.datavec.api.records.impl.SequenceRecord(list,
|
||||
new RecordMetaDataURI(next, CodecRecordReader.class));
|
||||
}
|
||||
|
||||
@Override
|
||||
public SequenceRecord loadSequenceFromMetaData(RecordMetaData recordMetaData) throws IOException {
|
||||
return loadSequenceFromMetaData(Collections.singletonList(recordMetaData)).get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SequenceRecord> loadSequenceFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
|
||||
List<SequenceRecord> out = new ArrayList<>();
|
||||
for (RecordMetaData meta : recordMetaDatas) {
|
||||
try (InputStream s = streamCreatorFn.apply(meta.getURI())){
|
||||
List<List<Writable>> list = loadData(null, s);
|
||||
out.add(new org.datavec.api.records.impl.SequenceRecord(list, meta));
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
}
|
|
@ -1,138 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.codec.reader;
|
||||
|
||||
import org.apache.commons.compress.utils.IOUtils;
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.util.ndarray.RecordConverter;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.image.loader.ImageLoader;
|
||||
import org.jcodec.api.FrameGrab;
|
||||
import org.jcodec.api.JCodecException;
|
||||
import org.jcodec.common.ByteBufferSeekableByteChannel;
|
||||
import org.jcodec.common.NIOUtils;
|
||||
import org.jcodec.common.SeekableByteChannel;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class CodecRecordReader extends BaseCodecRecordReader {
|
||||
|
||||
private ImageLoader imageLoader;
|
||||
|
||||
@Override
|
||||
public void setConf(Configuration conf) {
|
||||
super.setConf(conf);
|
||||
imageLoader = new ImageLoader(rows, cols);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException {
|
||||
SeekableByteChannel seekableByteChannel;
|
||||
if (inputStream != null) {
|
||||
//Reading video from DataInputStream: Need data from this stream in a SeekableByteChannel
|
||||
//Approach used here: load entire video into memory -> ByteBufferSeekableByteChanel
|
||||
byte[] data = IOUtils.toByteArray(inputStream);
|
||||
ByteBuffer bb = ByteBuffer.wrap(data);
|
||||
seekableByteChannel = new FixedByteBufferSeekableByteChannel(bb);
|
||||
} else {
|
||||
seekableByteChannel = NIOUtils.readableFileChannel(file);
|
||||
}
|
||||
|
||||
List<List<Writable>> record = new ArrayList<>();
|
||||
|
||||
if (numFrames >= 1) {
|
||||
FrameGrab fg;
|
||||
try {
|
||||
fg = new FrameGrab(seekableByteChannel);
|
||||
if (startFrame != 0)
|
||||
fg.seekToFramePrecise(startFrame);
|
||||
} catch (JCodecException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
for (int i = startFrame; i < startFrame + numFrames; i++) {
|
||||
try {
|
||||
BufferedImage grab = fg.getFrame();
|
||||
if (ravel)
|
||||
record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab)));
|
||||
else
|
||||
record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab)));
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (framesPerSecond < 1)
|
||||
throw new IllegalStateException("No frames or frame time intervals specified");
|
||||
|
||||
|
||||
else {
|
||||
for (double i = 0; i < videoLength; i += framesPerSecond) {
|
||||
try {
|
||||
BufferedImage grab = FrameGrab.getFrame(seekableByteChannel, i);
|
||||
if (ravel)
|
||||
record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab)));
|
||||
else
|
||||
record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab)));
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return record;
|
||||
}
|
||||
|
||||
/** Ugly workaround to a bug in JCodec: https://github.com/jcodec/jcodec/issues/24 */
|
||||
private static class FixedByteBufferSeekableByteChannel extends ByteBufferSeekableByteChannel {
|
||||
private ByteBuffer backing;
|
||||
|
||||
public FixedByteBufferSeekableByteChannel(ByteBuffer backing) {
|
||||
super(backing);
|
||||
try {
|
||||
Field f = this.getClass().getSuperclass().getDeclaredField("maxPos");
|
||||
f.setAccessible(true);
|
||||
f.set(this, backing.limit());
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
this.backing = backing;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int read(ByteBuffer dst) throws IOException {
|
||||
if (!backing.hasRemaining())
|
||||
return -1;
|
||||
return super.read(dst);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -1,82 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.codec.reader;
|
||||
|
||||
import org.bytedeco.javacv.FFmpegFrameGrabber;
|
||||
import org.bytedeco.javacv.Frame;
|
||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.util.ndarray.RecordConverter;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.image.loader.NativeImageLoader;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class NativeCodecRecordReader extends BaseCodecRecordReader {
|
||||
|
||||
private OpenCVFrameConverter.ToMat converter;
|
||||
private NativeImageLoader imageLoader;
|
||||
|
||||
@Override
|
||||
public void setConf(Configuration conf) {
|
||||
super.setConf(conf);
|
||||
converter = new OpenCVFrameConverter.ToMat();
|
||||
imageLoader = new NativeImageLoader(rows, cols);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException {
|
||||
List<List<Writable>> record = new ArrayList<>();
|
||||
|
||||
try (FFmpegFrameGrabber fg =
|
||||
inputStream != null ? new FFmpegFrameGrabber(inputStream) : new FFmpegFrameGrabber(file)) {
|
||||
if (numFrames >= 1) {
|
||||
fg.start();
|
||||
if (startFrame != 0)
|
||||
fg.setFrameNumber(startFrame);
|
||||
|
||||
for (int i = startFrame; i < startFrame + numFrames; i++) {
|
||||
Frame grab = fg.grabImage();
|
||||
record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab))));
|
||||
}
|
||||
} else {
|
||||
if (framesPerSecond < 1)
|
||||
throw new IllegalStateException("No frames or frame time intervals specified");
|
||||
else {
|
||||
fg.start();
|
||||
|
||||
for (double i = 0; i < videoLength; i += framesPerSecond) {
|
||||
fg.setTimestamp(Math.round(i * 1000000L));
|
||||
Frame grab = fg.grabImage();
|
||||
record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return record;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,46 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
package org.datavec.codec.reader;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.common.tests.AbstractAssertTestsClass;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
@Slf4j
|
||||
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
|
||||
|
||||
@Override
|
||||
protected Set<Class<?>> getExclusions() {
|
||||
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getPackageName() {
|
||||
return "org.datavec.codec.reader";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Class<?> getBaseClass() {
|
||||
return BaseND4JTest.class;
|
||||
}
|
||||
}
|
|
@ -1,212 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.codec.reader;
|
||||
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.records.SequenceRecord;
|
||||
import org.datavec.api.records.metadata.RecordMetaData;
|
||||
import org.datavec.api.records.reader.SequenceRecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.ArrayWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.DataInputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class CodecReaderTest {
|
||||
@Test
|
||||
public void testCodecReader() throws Exception {
|
||||
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
|
||||
SequenceRecordReader reader = new CodecRecordReader();
|
||||
Configuration conf = new Configuration();
|
||||
conf.set(CodecRecordReader.RAVEL, "true");
|
||||
conf.set(CodecRecordReader.START_FRAME, "160");
|
||||
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
|
||||
conf.set(CodecRecordReader.ROWS, "80");
|
||||
conf.set(CodecRecordReader.COLUMNS, "46");
|
||||
reader.initialize(new FileSplit(file));
|
||||
reader.setConf(conf);
|
||||
assertTrue(reader.hasNext());
|
||||
List<List<Writable>> record = reader.sequenceRecord();
|
||||
// System.out.println(record.size());
|
||||
|
||||
Iterator<List<Writable>> it = record.iterator();
|
||||
List<Writable> first = it.next();
|
||||
// System.out.println(first);
|
||||
|
||||
//Expected size: 80x46x3
|
||||
assertEquals(1, first.size());
|
||||
assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCodecReaderMeta() throws Exception {
|
||||
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
|
||||
SequenceRecordReader reader = new CodecRecordReader();
|
||||
Configuration conf = new Configuration();
|
||||
conf.set(CodecRecordReader.RAVEL, "true");
|
||||
conf.set(CodecRecordReader.START_FRAME, "160");
|
||||
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
|
||||
conf.set(CodecRecordReader.ROWS, "80");
|
||||
conf.set(CodecRecordReader.COLUMNS, "46");
|
||||
reader.initialize(new FileSplit(file));
|
||||
reader.setConf(conf);
|
||||
assertTrue(reader.hasNext());
|
||||
List<List<Writable>> record = reader.sequenceRecord();
|
||||
assertEquals(500, record.size()); //500 frames
|
||||
|
||||
reader.reset();
|
||||
SequenceRecord seqR = reader.nextSequence();
|
||||
assertEquals(record, seqR.getSequenceRecord());
|
||||
RecordMetaData meta = seqR.getMetaData();
|
||||
// System.out.println(meta);
|
||||
assertTrue(meta.getURI().toString().endsWith(file.getName()));
|
||||
|
||||
SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta);
|
||||
assertEquals(seqR, fromMeta);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testViaDataInputStream() throws Exception {
|
||||
|
||||
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
|
||||
SequenceRecordReader reader = new CodecRecordReader();
|
||||
Configuration conf = new Configuration();
|
||||
conf.set(CodecRecordReader.RAVEL, "true");
|
||||
conf.set(CodecRecordReader.START_FRAME, "160");
|
||||
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
|
||||
conf.set(CodecRecordReader.ROWS, "80");
|
||||
conf.set(CodecRecordReader.COLUMNS, "46");
|
||||
|
||||
Configuration conf2 = new Configuration(conf);
|
||||
|
||||
reader.initialize(new FileSplit(file));
|
||||
reader.setConf(conf);
|
||||
assertTrue(reader.hasNext());
|
||||
List<List<Writable>> expected = reader.sequenceRecord();
|
||||
|
||||
|
||||
SequenceRecordReader reader2 = new CodecRecordReader();
|
||||
reader2.setConf(conf2);
|
||||
|
||||
DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file));
|
||||
List<List<Writable>> actual = reader2.sequenceRecord(null, dataInputStream);
|
||||
|
||||
assertEquals(expected, actual);
|
||||
}
|
||||
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void testNativeCodecReader() throws Exception {
|
||||
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
|
||||
SequenceRecordReader reader = new NativeCodecRecordReader();
|
||||
Configuration conf = new Configuration();
|
||||
conf.set(CodecRecordReader.RAVEL, "true");
|
||||
conf.set(CodecRecordReader.START_FRAME, "160");
|
||||
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
|
||||
conf.set(CodecRecordReader.ROWS, "80");
|
||||
conf.set(CodecRecordReader.COLUMNS, "46");
|
||||
reader.initialize(new FileSplit(file));
|
||||
reader.setConf(conf);
|
||||
assertTrue(reader.hasNext());
|
||||
List<List<Writable>> record = reader.sequenceRecord();
|
||||
// System.out.println(record.size());
|
||||
|
||||
Iterator<List<Writable>> it = record.iterator();
|
||||
List<Writable> first = it.next();
|
||||
// System.out.println(first);
|
||||
|
||||
//Expected size: 80x46x3
|
||||
assertEquals(1, first.size());
|
||||
assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length());
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void testNativeCodecReaderMeta() throws Exception {
|
||||
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
|
||||
SequenceRecordReader reader = new NativeCodecRecordReader();
|
||||
Configuration conf = new Configuration();
|
||||
conf.set(CodecRecordReader.RAVEL, "true");
|
||||
conf.set(CodecRecordReader.START_FRAME, "160");
|
||||
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
|
||||
conf.set(CodecRecordReader.ROWS, "80");
|
||||
conf.set(CodecRecordReader.COLUMNS, "46");
|
||||
reader.initialize(new FileSplit(file));
|
||||
reader.setConf(conf);
|
||||
assertTrue(reader.hasNext());
|
||||
List<List<Writable>> record = reader.sequenceRecord();
|
||||
assertEquals(500, record.size()); //500 frames
|
||||
|
||||
reader.reset();
|
||||
SequenceRecord seqR = reader.nextSequence();
|
||||
assertEquals(record, seqR.getSequenceRecord());
|
||||
RecordMetaData meta = seqR.getMetaData();
|
||||
// System.out.println(meta);
|
||||
assertTrue(meta.getURI().toString().endsWith("fire_lowres.mp4"));
|
||||
|
||||
SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta);
|
||||
assertEquals(seqR, fromMeta);
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void testNativeViaDataInputStream() throws Exception {
|
||||
|
||||
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
|
||||
SequenceRecordReader reader = new NativeCodecRecordReader();
|
||||
Configuration conf = new Configuration();
|
||||
conf.set(CodecRecordReader.RAVEL, "true");
|
||||
conf.set(CodecRecordReader.START_FRAME, "160");
|
||||
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
|
||||
conf.set(CodecRecordReader.ROWS, "80");
|
||||
conf.set(CodecRecordReader.COLUMNS, "46");
|
||||
|
||||
Configuration conf2 = new Configuration(conf);
|
||||
|
||||
reader.initialize(new FileSplit(file));
|
||||
reader.setConf(conf);
|
||||
assertTrue(reader.hasNext());
|
||||
List<List<Writable>> expected = reader.sequenceRecord();
|
||||
|
||||
|
||||
SequenceRecordReader reader2 = new NativeCodecRecordReader();
|
||||
reader2.setConf(conf2);
|
||||
|
||||
DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file));
|
||||
List<List<Writable>> actual = reader2.sequenceRecord(null, dataInputStream);
|
||||
|
||||
assertEquals(expected, actual);
|
||||
}
|
||||
}
|
|
@ -1,77 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--
|
||||
~ /* ******************************************************************************
|
||||
~ *
|
||||
~ *
|
||||
~ * This program and the accompanying materials are made available under the
|
||||
~ * terms of the Apache License, Version 2.0 which is available at
|
||||
~ * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~ *
|
||||
~ * See the NOTICE file distributed with this work for additional
|
||||
~ * information regarding copyright ownership.
|
||||
~ * Unless required by applicable law or agreed to in writing, software
|
||||
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ * License for the specific language governing permissions and limitations
|
||||
~ * under the License.
|
||||
~ *
|
||||
~ * SPDX-License-Identifier: Apache-2.0
|
||||
~ ******************************************************************************/
|
||||
-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-data</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<artifactId>datavec-data-nlp</artifactId>
|
||||
|
||||
<name>datavec-data-nlp</name>
|
||||
|
||||
<properties>
|
||||
<cleartk.version>2.0.0</cleartk.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-api</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-local</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.cleartk</groupId>
|
||||
<artifactId>cleartk-snowball</artifactId>
|
||||
<version>${cleartk.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.cleartk</groupId>
|
||||
<artifactId>cleartk-opennlp-tools</artifactId>
|
||||
<version>${cleartk.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -1,237 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.annotator;
|
||||
|
||||
import opennlp.tools.postag.POSModel;
|
||||
import opennlp.tools.postag.POSTaggerME;
|
||||
import opennlp.uima.postag.POSModelResource;
|
||||
import opennlp.uima.postag.POSModelResourceImpl;
|
||||
import opennlp.uima.util.AnnotationComboIterator;
|
||||
import opennlp.uima.util.AnnotationIteratorPair;
|
||||
import opennlp.uima.util.AnnotatorUtil;
|
||||
import opennlp.uima.util.UimaUtil;
|
||||
import org.apache.uima.UimaContext;
|
||||
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
|
||||
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
|
||||
import org.apache.uima.cas.CAS;
|
||||
import org.apache.uima.cas.Feature;
|
||||
import org.apache.uima.cas.Type;
|
||||
import org.apache.uima.cas.TypeSystem;
|
||||
import org.apache.uima.cas.text.AnnotationFS;
|
||||
import org.apache.uima.fit.component.CasAnnotator_ImplBase;
|
||||
import org.apache.uima.fit.factory.AnalysisEngineFactory;
|
||||
import org.apache.uima.fit.factory.ExternalResourceFactory;
|
||||
import org.apache.uima.resource.ResourceAccessException;
|
||||
import org.apache.uima.resource.ResourceInitializationException;
|
||||
import org.apache.uima.util.Level;
|
||||
import org.apache.uima.util.Logger;
|
||||
import org.cleartk.token.type.Sentence;
|
||||
import org.cleartk.token.type.Token;
|
||||
import org.datavec.nlp.movingwindow.Util;
|
||||
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
public class PoStagger extends CasAnnotator_ImplBase {
|
||||
|
||||
static {
|
||||
//UIMA logging
|
||||
Util.disableLogging();
|
||||
}
|
||||
|
||||
private POSTaggerME posTagger;
|
||||
|
||||
private Type sentenceType;
|
||||
|
||||
private Type tokenType;
|
||||
|
||||
private Feature posFeature;
|
||||
|
||||
private Feature probabilityFeature;
|
||||
|
||||
private UimaContext context;
|
||||
|
||||
private Logger logger;
|
||||
|
||||
/**
|
||||
* Initializes a new instance.
|
||||
*
|
||||
* Note: Use {@link #initialize(org.apache.uima.UimaContext) } to initialize this instance. Not use the
|
||||
* constructor.
|
||||
*/
|
||||
public PoStagger() {
|
||||
// must not be implemented !
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the current instance with the given context.
|
||||
*
|
||||
* Note: Do all initialization in this method, do not use the constructor.
|
||||
*/
|
||||
@Override
|
||||
public void initialize(UimaContext context) throws ResourceInitializationException {
|
||||
|
||||
super.initialize(context);
|
||||
|
||||
this.context = context;
|
||||
|
||||
this.logger = context.getLogger();
|
||||
|
||||
if (this.logger.isLoggable(Level.INFO)) {
|
||||
this.logger.log(Level.INFO, "Initializing the OpenNLP " + "Part of Speech annotator.");
|
||||
}
|
||||
|
||||
POSModel model;
|
||||
|
||||
try {
|
||||
POSModelResource modelResource = (POSModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER);
|
||||
|
||||
model = modelResource.getModel();
|
||||
} catch (ResourceAccessException e) {
|
||||
throw new ResourceInitializationException(e);
|
||||
}
|
||||
|
||||
Integer beamSize = AnnotatorUtil.getOptionalIntegerParameter(context, UimaUtil.BEAM_SIZE_PARAMETER);
|
||||
|
||||
if (beamSize == null)
|
||||
beamSize = POSTaggerME.DEFAULT_BEAM_SIZE;
|
||||
|
||||
this.posTagger = new POSTaggerME(model, beamSize, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the type system.
|
||||
*/
|
||||
@Override
|
||||
public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException {
|
||||
|
||||
// sentence type
|
||||
this.sentenceType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem,
|
||||
UimaUtil.SENTENCE_TYPE_PARAMETER);
|
||||
|
||||
// token type
|
||||
this.tokenType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem,
|
||||
UimaUtil.TOKEN_TYPE_PARAMETER);
|
||||
|
||||
// pos feature
|
||||
this.posFeature = AnnotatorUtil.getRequiredFeatureParameter(this.context, this.tokenType,
|
||||
UimaUtil.POS_FEATURE_PARAMETER, CAS.TYPE_NAME_STRING);
|
||||
|
||||
this.probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(this.context, this.tokenType,
|
||||
UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs pos-tagging on the given tcas object.
|
||||
*/
|
||||
@Override
|
||||
public synchronized void process(CAS tcas) {
|
||||
|
||||
final AnnotationComboIterator comboIterator =
|
||||
new AnnotationComboIterator(tcas, this.sentenceType, this.tokenType);
|
||||
|
||||
for (AnnotationIteratorPair annotationIteratorPair : comboIterator) {
|
||||
|
||||
final List<AnnotationFS> sentenceTokenAnnotationList = new LinkedList<AnnotationFS>();
|
||||
|
||||
final List<String> sentenceTokenList = new LinkedList<String>();
|
||||
|
||||
for (AnnotationFS tokenAnnotation : annotationIteratorPair.getSubIterator()) {
|
||||
|
||||
sentenceTokenAnnotationList.add(tokenAnnotation);
|
||||
|
||||
sentenceTokenList.add(tokenAnnotation.getCoveredText());
|
||||
}
|
||||
|
||||
final List<String> posTags = this.posTagger.tag(sentenceTokenList);
|
||||
|
||||
double posProbabilities[] = null;
|
||||
|
||||
if (this.probabilityFeature != null) {
|
||||
posProbabilities = this.posTagger.probs();
|
||||
}
|
||||
|
||||
final Iterator<String> posTagIterator = posTags.iterator();
|
||||
final Iterator<AnnotationFS> sentenceTokenIterator = sentenceTokenAnnotationList.iterator();
|
||||
|
||||
int index = 0;
|
||||
while (posTagIterator.hasNext() && sentenceTokenIterator.hasNext()) {
|
||||
final String posTag = posTagIterator.next();
|
||||
final AnnotationFS tokenAnnotation = sentenceTokenIterator.next();
|
||||
|
||||
tokenAnnotation.setStringValue(this.posFeature, posTag);
|
||||
|
||||
if (posProbabilities != null) {
|
||||
tokenAnnotation.setDoubleValue(this.posFeature, posProbabilities[index]);
|
||||
}
|
||||
|
||||
index++;
|
||||
}
|
||||
|
||||
// log tokens with pos
|
||||
if (this.logger.isLoggable(Level.FINER)) {
|
||||
|
||||
final StringBuilder sentenceWithPos = new StringBuilder();
|
||||
|
||||
sentenceWithPos.append("\"");
|
||||
|
||||
for (final Iterator<AnnotationFS> it = sentenceTokenAnnotationList.iterator(); it.hasNext();) {
|
||||
final AnnotationFS token = it.next();
|
||||
sentenceWithPos.append(token.getCoveredText());
|
||||
sentenceWithPos.append('\\');
|
||||
sentenceWithPos.append(token.getStringValue(this.posFeature));
|
||||
sentenceWithPos.append(' ');
|
||||
}
|
||||
|
||||
// delete last whitespace
|
||||
if (sentenceWithPos.length() > 1) // not 0 because it contains already the " char
|
||||
sentenceWithPos.setLength(sentenceWithPos.length() - 1);
|
||||
|
||||
sentenceWithPos.append("\"");
|
||||
|
||||
this.logger.log(Level.FINER, sentenceWithPos.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Releases allocated resources.
|
||||
*/
|
||||
@Override
|
||||
public void destroy() {
|
||||
this.posTagger = null;
|
||||
}
|
||||
|
||||
|
||||
public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException {
|
||||
String modelPath = String.format("/models/%s-pos-maxent.bin", languageCode);
|
||||
return AnalysisEngineFactory.createEngineDescription(PoStagger.class, UimaUtil.MODEL_PARAMETER,
|
||||
ExternalResourceFactory.createExternalResourceDescription(POSModelResourceImpl.class,
|
||||
PoStagger.class.getResource(modelPath).toString()),
|
||||
UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), UimaUtil.TOKEN_TYPE_PARAMETER,
|
||||
Token.class.getName(), UimaUtil.POS_FEATURE_PARAMETER, "pos");
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.annotator;
|
||||
|
||||
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
|
||||
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
|
||||
import org.apache.uima.fit.factory.AnalysisEngineFactory;
|
||||
import org.apache.uima.jcas.JCas;
|
||||
import org.apache.uima.resource.ResourceInitializationException;
|
||||
import org.cleartk.util.ParamUtil;
|
||||
import org.datavec.nlp.movingwindow.Util;
|
||||
|
||||
public class SentenceAnnotator extends org.cleartk.opennlp.tools.SentenceAnnotator {
|
||||
|
||||
static {
|
||||
//UIMA logging
|
||||
Util.disableLogging();
|
||||
}
|
||||
|
||||
public static AnalysisEngineDescription getDescription() throws ResourceInitializationException {
|
||||
return AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.class, PARAM_SENTENCE_MODEL_PATH,
|
||||
ParamUtil.getParameterValue(PARAM_SENTENCE_MODEL_PATH, "/models/en-sent.bin"),
|
||||
PARAM_WINDOW_CLASS_NAMES, ParamUtil.getParameterValue(PARAM_WINDOW_CLASS_NAMES, null));
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public synchronized void process(JCas jCas) throws AnalysisEngineProcessException {
|
||||
super.process(jCas);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,58 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.annotator;
|
||||
|
||||
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
|
||||
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
|
||||
import org.apache.uima.fit.factory.AnalysisEngineFactory;
|
||||
import org.apache.uima.jcas.JCas;
|
||||
import org.apache.uima.resource.ResourceInitializationException;
|
||||
import org.cleartk.snowball.SnowballStemmer;
|
||||
import org.cleartk.token.type.Token;
|
||||
|
||||
|
||||
public class StemmerAnnotator extends SnowballStemmer<Token> {
|
||||
|
||||
public static AnalysisEngineDescription getDescription() throws ResourceInitializationException {
|
||||
return getDescription("English");
|
||||
}
|
||||
|
||||
|
||||
public static AnalysisEngineDescription getDescription(String language) throws ResourceInitializationException {
|
||||
return AnalysisEngineFactory.createEngineDescription(StemmerAnnotator.class, SnowballStemmer.PARAM_STEMMER_NAME,
|
||||
language);
|
||||
}
|
||||
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public synchronized void process(JCas jCas) throws AnalysisEngineProcessException {
|
||||
super.process(jCas);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public void setStem(Token token, String stem) {
|
||||
token.setStem(stem);
|
||||
}
|
||||
|
||||
}
|
|
@ -1,70 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.annotator;
|
||||
|
||||
|
||||
import opennlp.uima.tokenize.TokenizerModelResourceImpl;
|
||||
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
|
||||
import org.apache.uima.fit.factory.AnalysisEngineFactory;
|
||||
import org.apache.uima.fit.factory.ExternalResourceFactory;
|
||||
import org.apache.uima.resource.ResourceInitializationException;
|
||||
import org.cleartk.opennlp.tools.Tokenizer;
|
||||
import org.cleartk.token.type.Sentence;
|
||||
import org.cleartk.token.type.Token;
|
||||
import org.datavec.nlp.movingwindow.Util;
|
||||
import org.datavec.nlp.tokenization.tokenizer.ConcurrentTokenizer;
|
||||
|
||||
|
||||
/**
|
||||
* Overrides OpenNLP tokenizer to be thread safe
|
||||
*/
|
||||
public class TokenizerAnnotator extends Tokenizer {
|
||||
|
||||
static {
|
||||
//UIMA logging
|
||||
Util.disableLogging();
|
||||
}
|
||||
|
||||
public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException {
|
||||
String modelPath = String.format("/models/%s-token.bin", languageCode);
|
||||
return AnalysisEngineFactory.createEngineDescription(ConcurrentTokenizer.class,
|
||||
opennlp.uima.util.UimaUtil.MODEL_PARAMETER,
|
||||
ExternalResourceFactory.createExternalResourceDescription(TokenizerModelResourceImpl.class,
|
||||
ConcurrentTokenizer.class.getResource(modelPath).toString()),
|
||||
opennlp.uima.util.UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(),
|
||||
opennlp.uima.util.UimaUtil.TOKEN_TYPE_PARAMETER, Token.class.getName());
|
||||
}
|
||||
|
||||
|
||||
|
||||
public static AnalysisEngineDescription getDescription() throws ResourceInitializationException {
|
||||
String modelPath = String.format("/models/%s-token.bin", "en");
|
||||
return AnalysisEngineFactory.createEngineDescription(ConcurrentTokenizer.class,
|
||||
opennlp.uima.util.UimaUtil.MODEL_PARAMETER,
|
||||
ExternalResourceFactory.createExternalResourceDescription(TokenizerModelResourceImpl.class,
|
||||
ConcurrentTokenizer.class.getResource(modelPath).toString()),
|
||||
opennlp.uima.util.UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(),
|
||||
opennlp.uima.util.UimaUtil.TOKEN_TYPE_PARAMETER, Token.class.getName());
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.input;
|
||||
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.formats.input.BaseInputFormat;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.split.InputSplit;
|
||||
import org.datavec.nlp.reader.TfidfRecordReader;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class TextInputFormat extends BaseInputFormat {
|
||||
@Override
|
||||
public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
|
||||
RecordReader reader = new TfidfRecordReader();
|
||||
reader.initialize(conf, split);
|
||||
return reader;
|
||||
}
|
||||
}
|
|
@ -1,148 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.metadata;
|
||||
|
||||
import org.nd4j.common.primitives.Counter;
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.nlp.vectorizer.TextVectorizer;
|
||||
import org.nd4j.common.util.MathUtils;
|
||||
import org.nd4j.common.util.Index;
|
||||
|
||||
public class DefaultVocabCache implements VocabCache {
|
||||
|
||||
private Counter<String> wordFrequencies = new Counter<>();
|
||||
private Counter<String> docFrequencies = new Counter<>();
|
||||
private int minWordFrequency;
|
||||
private Index vocabWords = new Index();
|
||||
private double numDocs = 0;
|
||||
|
||||
/**
|
||||
* Instantiate with a given min word frequency
|
||||
* @param minWordFrequency
|
||||
*/
|
||||
public DefaultVocabCache(int minWordFrequency) {
|
||||
this.minWordFrequency = minWordFrequency;
|
||||
}
|
||||
|
||||
/*
|
||||
* Constructor for use with initialize()
|
||||
*/
|
||||
public DefaultVocabCache() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void incrementNumDocs(double by) {
|
||||
numDocs += by;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double numDocs() {
|
||||
return numDocs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String wordAt(int i) {
|
||||
return vocabWords.get(i).toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int wordIndex(String word) {
|
||||
return vocabWords.indexOf(word);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initialize(Configuration conf) {
|
||||
minWordFrequency = conf.getInt(TextVectorizer.MIN_WORD_FREQUENCY, 5);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double wordFrequency(String word) {
|
||||
return wordFrequencies.getCount(word);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int minWordFrequency() {
|
||||
return minWordFrequency;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Index vocabWords() {
|
||||
return vocabWords;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void incrementDocCount(String word) {
|
||||
incrementDocCount(word, 1.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void incrementDocCount(String word, double by) {
|
||||
docFrequencies.incrementCount(word, by);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void incrementCount(String word) {
|
||||
incrementCount(word, 1.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void incrementCount(String word, double by) {
|
||||
wordFrequencies.incrementCount(word, by);
|
||||
if (wordFrequencies.getCount(word) >= minWordFrequency && vocabWords.indexOf(word) < 0)
|
||||
vocabWords.add(word);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double idf(String word) {
|
||||
return docFrequencies.getCount(word);
|
||||
}
|
||||
|
||||
@Override
public double tfidf(String word, double frequency, boolean smoothIdf) {
    // Term frequency is the raw count (see tf(): identity, no scaling).
    double tf = tf((int) frequency);
    double docFreq = docFrequencies.getCount(word);

    // Combine tf with the (optionally smoothed) inverse document frequency.
    double idf = idf(numDocs, docFreq, smoothIdf);
    double tfidf = MathUtils.tfidf(tf, idf);
    return tfidf;
}
|
||||
|
||||
public double idf(double totalDocs, double numTimesWordAppearedInADocument, boolean smooth) {
|
||||
if(smooth){
|
||||
return Math.log((1 + totalDocs) / (1 + numTimesWordAppearedInADocument)) + 1.0;
|
||||
} else {
|
||||
return Math.log(totalDocs / numTimesWordAppearedInADocument) + 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
public static double tf(int count) {
    // Raw term frequency: the count itself (no log scaling or normalization).
    return count;
}
|
||||
|
||||
// Bean-style accessor for the vocab-admission threshold (same value as minWordFrequency()).
public int getMinWordFrequency() {
    return minWordFrequency;
}
|
||||
|
||||
// Bean-style mutator; only affects words counted after the change.
public void setMinWordFrequency(int minWordFrequency) {
    this.minWordFrequency = minWordFrequency;
}
|
||||
}
|
|
@ -1,121 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.metadata;
|
||||
|
||||
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.nd4j.common.util.Index;
|
||||
|
||||
public interface VocabCache {


    /**
     * Increment the number of documents
     * @param by the amount to increment the document count by
     */
    void incrementNumDocs(double by);

    /**
     * Number of documents
     * @return the number of documents
     */
    double numDocs();

    /**
     * Returns a word in the vocab at a particular index
     * @param i the index to get
     * @return the word at that index in the vocab
     */
    String wordAt(int i);

    /**
     * Returns the index of a word within the vocab.
     * @param word the word to look up
     * @return the word's index, or a negative value when the word is not in the vocab
     */
    int wordIndex(String word);

    /**
     * Configuration for initializing
     * @param conf the configuration to initialize with
     */
    void initialize(Configuration conf);

    /**
     * Get the word frequency for a word
     * @param word the word to get frequency for
     * @return the frequency for a given word
     */
    double wordFrequency(String word);

    /**
     * The min word frequency
     * needed to be included in the vocab
     * (default 5)
     * @return the min word frequency to
     * be included in the vocab
     */
    int minWordFrequency();

    /**
     * All of the vocab words (ordered)
     * note that these are not all the possible tokens
     * @return the list of vocab words
     */
    Index vocabWords();


    /**
     * Increment the doc count for a word by 1
     * @param word the word to increment the count for
     */
    void incrementDocCount(String word);

    /**
     * Increment the document count for a particular word
     * @param word the word to increment the count for
     * @param by the amount to increment by
     */
    void incrementDocCount(String word, double by);

    /**
     * Increment a word count by 1
     * @param word the word to increment the count for
     */
    void incrementCount(String word);

    /**
     * Increment count for a word
     * @param word the word to increment the count for
     * @param by the amount to increment by
     */
    void incrementCount(String word, double by);

    /**
     * Number of documents word has occurred in
     * @param word the word to get the idf for
     * @return the document frequency of the word
     */
    double idf(String word);

    /**
     * Calculate the tfidf of the word given the document frequency
     * @param word the word to get frequency for
     * @param frequency the frequency
     * @param smoothIdf whether to apply smoothing when computing the inverse document frequency
     * @return the tfidf for a word
     */
    double tfidf(String word, double frequency, boolean smoothIdf);

}
|
|
@ -1,125 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.movingwindow;
|
||||
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.common.collection.MultiDimensionalMap;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
|
||||
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class ContextLabelRetriever {

    // Markup-style label markers, e.g. <PER> ... </PER>; label names are letters or digits.
    private static String BEGIN_LABEL = "<([A-Za-z]+|\\d+)>";
    private static String END_LABEL = "</([A-Za-z]+|\\d+)>";

    // Utility class: not instantiable.
    private ContextLabelRetriever() {}

    /**
     * Returns a stripped sentence with the indices of words
     * with certain kinds of labels.
     *
     * @param sentence the sentence to process
     * @param tokenizerFactory tokenizer factory used to split the sentence into tokens
     * @return a pair of a post processed sentence
     * with labels stripped and the spans of
     * the labels
     */
    public static Pair<String, MultiDimensionalMap<Integer, Integer, String>> stringWithLabels(String sentence,
                    TokenizerFactory tokenizerFactory) {
        MultiDimensionalMap<Integer, Integer, String> map = MultiDimensionalMap.newHashBackedMap();
        Tokenizer t = tokenizerFactory.create(sentence);
        // Tokens accumulated since the last label boundary.
        List<String> currTokens = new ArrayList<>();
        String currLabel = null;
        String endLabel = null;
        // Each entry pairs a label with the run of tokens it covers; "NONE" for unlabeled runs.
        List<Pair<String, List<String>>> tokensWithSameLabel = new ArrayList<>();
        while (t.hasMoreTokens()) {
            String token = t.nextToken();
            if (token.matches(BEGIN_LABEL)) {
                currLabel = token;

                //no labels; add these as NONE and begin the new label
                if (!currTokens.isEmpty()) {
                    tokensWithSameLabel.add(new Pair<>("NONE", (List<String>) new ArrayList<>(currTokens)));
                    currTokens.clear();

                }

            } else if (token.matches(END_LABEL)) {
                if (currLabel == null)
                    throw new IllegalStateException("Found an ending label with no matching begin label");
                endLabel = token;
            } else
                currTokens.add(token);

            // A begin/end marker pair has been closed: record the labeled token run.
            if (currLabel != null && endLabel != null) {
                // Strip the <, > and / marker characters to get the bare label name.
                currLabel = currLabel.replaceAll("[<>/]", "");
                endLabel = endLabel.replaceAll("[<>/]", "");
                Preconditions.checkState(!currLabel.isEmpty(), "Current label is empty!");
                Preconditions.checkState(!endLabel.isEmpty(), "End label is empty!");
                Preconditions.checkState(currLabel.equals(endLabel), "Current label begin and end did not match for the parse. Was: %s ending with %s",
                                currLabel, endLabel);

                tokensWithSameLabel.add(new Pair<>(currLabel, (List<String>) new ArrayList<>(currTokens)));
                currTokens.clear();


                //clear out the tokens
                currLabel = null;
                endLabel = null;
            }


        }

        //no labels; add these as NONE and begin the new label
        // NOTE(review): this trailing run uses lower-case "none" while the loop
        // above uses "NONE" - looks like an accidental inconsistency; confirm
        // with callers before relying on the label value's case.
        if (!currTokens.isEmpty()) {
            tokensWithSameLabel.add(new Pair<>("none", (List<String>) new ArrayList<>(currTokens)));
            currTokens.clear();

        }

        //now join the output
        StringBuilder strippedSentence = new StringBuilder();
        for (Pair<String, List<String>> tokensWithLabel : tokensWithSameLabel) {
            String joinedSentence = StringUtils.join(tokensWithLabel.getSecond(), " ");
            //spaces between separate parts of the sentence
            if (!(strippedSentence.length() < 1))
                strippedSentence.append(" ");
            strippedSentence.append(joinedSentence);
            // Character span of this run within the stripped sentence.
            int begin = strippedSentence.toString().indexOf(joinedSentence);
            int end = begin + joinedSentence.length();
            map.put(begin, end, tokensWithLabel.getFirst());
        }


        return new Pair<>(strippedSentence.toString(), map);
    }


}
|
|
@ -1,60 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.movingwindow;
|
||||
|
||||
|
||||
import org.nd4j.common.primitives.Counter;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
|
||||
public class Util {

    /**
     * Returns a counter.
     * <p>
     * NOTE(review): the original doc claimed "thread safe", but this simply
     * returns a plain {@link Counter}; verify Counter's own thread-safety
     * guarantees before sharing the result across threads.
     *
     * @return a new, empty counter
     */
    public static Counter<String> parallelCounter() {
        return new Counter<>();
    }

    /**
     * Case-insensitive membership test against a stop word list.
     *
     * @param stopWords the stop words to check against
     * @param word the word to test
     * @return true if the word matches any stop word, ignoring case
     */
    public static boolean matchesAnyStopWord(List<String> stopWords, String word) {
        for (String s : stopWords)
            if (s.equalsIgnoreCase(word))
                return true;
        return false;
    }

    /**
     * Silences the "org.apache.uima" JUL logger.
     *
     * @return the previous effective level, so callers can restore it later
     */
    public static Level disableLogging() {
        Logger logger = Logger.getLogger("org.apache.uima");
        // Walk up to the nearest ancestor with an explicitly configured level.
        while (logger.getLevel() == null) {
            logger = logger.getParent();
        }
        Level level = logger.getLevel();
        logger.setLevel(Level.OFF);
        return level;
    }


}
|
|
@ -1,177 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.movingwindow;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
public class Window implements Serializable {

    private static final long serialVersionUID = 6359906393699230579L;

    /** Tokens inside this window, including any "<s>" / "</s>" padding. */
    private List<String> words;
    /** Label parsed from <LABEL>/</LABEL> markers in the window; "NONE" when unlabeled. */
    private String label = "NONE";
    private boolean beginLabel;
    private boolean endLabel;
    /** Index of the focus (center) token within {@link #words}. */
    private int median;
    // Label markers: upper-case letters or digits, e.g. <PER> ... </PER>.
    // Made final: these are constants and were never reassigned.
    private static final String BEGIN_LABEL = "<([A-Z]+|\\d+)>";
    private static final String END_LABEL = "</([A-Z]+|\\d+)>";
    /** Character span of this window within the source sentence. */
    private int begin, end;

    /**
     * Creates a window with a context of size 3
     * @param words a collection of strings of size 3
     * @param begin the begin character index for the window
     * @param end the end character index for the window
     */
    public Window(Collection<String> words, int begin, int end) {
        this(words, 5, begin, end);
    }

    /**
     * Joins this window's tokens with single spaces.
     * @return the window rendered as a whitespace-separated string
     */
    public String asTokens() {
        // String.join replaces the commons-lang StringUtils.join call; identical output.
        return String.join(" ", words);
    }

    /**
     * Initialize a window with the given size
     * @param words the words to use
     * @param windowSize the size of the window.
     *                   NOTE(review): currently unused - the effective window size is
     *                   simply words.size(); confirm before relying on this parameter.
     *                   (The old dead local {@code windowSize1} was removed.)
     * @param begin the begin index for the window
     * @param end the end index for the window
     * @throws IllegalArgumentException if words is null
     */
    public Window(Collection<String> words, int windowSize, int begin, int end) {
        if (words == null)
            throw new IllegalArgumentException("Words must be a list of size 3");

        this.words = new ArrayList<>(words);
        this.begin = begin;
        this.end = end;
        initContext();
    }

    /**
     * Scans the tokens on either side of the median for begin/end label markers
     * and records the bare label name plus whether this window opens/closes a label.
     */
    private void initContext() {
        // Integer division already floors; the old (int) Math.floor(...) was redundant.
        int median = words.size() / 2;
        List<String> begin = words.subList(0, median);
        List<String> after = words.subList(median + 1, words.size());

        for (String s : begin) {
            if (s.matches(BEGIN_LABEL)) {
                // Strip the marker characters to get the bare label name.
                this.label = s.replaceAll("(<|>)", "").replace("/", "");
                beginLabel = true;
            } else if (s.matches(END_LABEL)) {
                endLabel = true;
                this.label = s.replaceAll("(<|>|/)", "").replace("/", "");
            }
        }

        for (String s1 : after) {

            if (s1.matches(BEGIN_LABEL)) {
                this.label = s1.replaceAll("(<|>)", "").replace("/", "");
                beginLabel = true;
            }

            if (s1.matches(END_LABEL)) {
                endLabel = true;
                // Keeps the leading "/" here; getLabel() strips it on read.
                this.label = s1.replaceAll("(<|>)", "");
            }
        }
        this.median = median;
    }

    @Override
    public String toString() {
        return words.toString();
    }

    public List<String> getWords() {
        return words;
    }

    public void setWords(List<String> words) {
        this.words = words;
    }

    public String getWord(int i) {
        return words.get(i);
    }

    /** @return the center token of the window */
    public String getFocusWord() {
        return words.get(median);
    }

    /** @return true when a begin marker was seen and a label was assigned */
    public boolean isBeginLabel() {
        return !label.equals("NONE") && beginLabel;
    }

    /** @return true when an end marker was seen and a label was assigned */
    public boolean isEndLabel() {
        return !label.equals("NONE") && endLabel;
    }

    /** @return the label name with any residual "/" removed */
    public String getLabel() {
        return label.replace("/", "");
    }

    /** @return the number of tokens in the window (padding included) */
    public int getWindowSize() {
        return words.size();
    }

    /** @return the index of the focus word within the window */
    public int getMedian() {
        return median;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    public int getBegin() {
        return begin;
    }

    public void setBegin(int begin) {
        this.begin = begin;
    }

    public int getEnd() {
        return end;
    }

    public void setEnd(int end) {
        this.end = end;
    }


}
|
|
@ -1,188 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.movingwindow;
|
||||
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer;
|
||||
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
|
||||
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.StringTokenizer;
|
||||
|
||||
public class Windows {
|
||||
|
||||
|
||||
/**
|
||||
* Constructs a list of window of size windowSize.
|
||||
* Note that padding for each window is created as well.
|
||||
* @param words the words to tokenize and construct windows from
|
||||
* @param windowSize the window size to generate
|
||||
* @return the list of windows for the tokenized string
|
||||
*/
|
||||
public static List<Window> windows(InputStream words, int windowSize) {
|
||||
Tokenizer tokenizer = new DefaultStreamTokenizer(words);
|
||||
List<String> list = new ArrayList<>();
|
||||
while (tokenizer.hasMoreTokens())
|
||||
list.add(tokenizer.nextToken());
|
||||
return windows(list, windowSize);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a list of window of size windowSize.
|
||||
* Note that padding for each window is created as well.
|
||||
* @param words the words to tokenize and construct windows from
|
||||
* @param tokenizerFactory tokenizer factory to use
|
||||
* @param windowSize the window size to generate
|
||||
* @return the list of windows for the tokenized string
|
||||
*/
|
||||
public static List<Window> windows(InputStream words, TokenizerFactory tokenizerFactory, int windowSize) {
|
||||
Tokenizer tokenizer = tokenizerFactory.create(words);
|
||||
List<String> list = new ArrayList<>();
|
||||
while (tokenizer.hasMoreTokens())
|
||||
list.add(tokenizer.nextToken());
|
||||
|
||||
if (list.isEmpty())
|
||||
throw new IllegalStateException("No tokens found for windows");
|
||||
|
||||
return windows(list, windowSize);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Constructs a list of window of size windowSize.
|
||||
* Note that padding for each window is created as well.
|
||||
* @param words the words to tokenize and construct windows from
|
||||
* @param windowSize the window size to generate
|
||||
* @return the list of windows for the tokenized string
|
||||
*/
|
||||
public static List<Window> windows(String words, int windowSize) {
|
||||
StringTokenizer tokenizer = new StringTokenizer(words);
|
||||
List<String> list = new ArrayList<String>();
|
||||
while (tokenizer.hasMoreTokens())
|
||||
list.add(tokenizer.nextToken());
|
||||
return windows(list, windowSize);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a list of window of size windowSize.
|
||||
* Note that padding for each window is created as well.
|
||||
* @param words the words to tokenize and construct windows from
|
||||
* @param tokenizerFactory tokenizer factory to use
|
||||
* @param windowSize the window size to generate
|
||||
* @return the list of windows for the tokenized string
|
||||
*/
|
||||
public static List<Window> windows(String words, TokenizerFactory tokenizerFactory, int windowSize) {
|
||||
Tokenizer tokenizer = tokenizerFactory.create(words);
|
||||
List<String> list = new ArrayList<>();
|
||||
while (tokenizer.hasMoreTokens())
|
||||
list.add(tokenizer.nextToken());
|
||||
|
||||
if (list.isEmpty())
|
||||
throw new IllegalStateException("No tokens found for windows");
|
||||
|
||||
return windows(list, windowSize);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Constructs a list of window of size windowSize.
|
||||
* Note that padding for each window is created as well.
|
||||
* @param words the words to tokenize and construct windows from
|
||||
* @return the list of windows for the tokenized string
|
||||
*/
|
||||
public static List<Window> windows(String words) {
|
||||
StringTokenizer tokenizer = new StringTokenizer(words);
|
||||
List<String> list = new ArrayList<String>();
|
||||
while (tokenizer.hasMoreTokens())
|
||||
list.add(tokenizer.nextToken());
|
||||
return windows(list, 5);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a list of window of size windowSize.
|
||||
* Note that padding for each window is created as well.
|
||||
* @param words the words to tokenize and construct windows from
|
||||
* @param tokenizerFactory tokenizer factory to use
|
||||
* @return the list of windows for the tokenized string
|
||||
*/
|
||||
public static List<Window> windows(String words, TokenizerFactory tokenizerFactory) {
|
||||
Tokenizer tokenizer = tokenizerFactory.create(words);
|
||||
List<String> list = new ArrayList<>();
|
||||
while (tokenizer.hasMoreTokens())
|
||||
list.add(tokenizer.nextToken());
|
||||
return windows(list, 5);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Creates a sliding window from text
|
||||
* @param windowSize the window size to use
|
||||
* @param wordPos the position of the word to center
|
||||
* @param sentence the sentence to createComplex a window for
|
||||
* @return a window based on the given sentence
|
||||
*/
|
||||
public static Window windowForWordInPosition(int windowSize, int wordPos, List<String> sentence) {
|
||||
List<String> window = new ArrayList<>();
|
||||
List<String> onlyTokens = new ArrayList<>();
|
||||
int contextSize = (int) Math.floor((windowSize - 1) / 2);
|
||||
|
||||
for (int i = wordPos - contextSize; i <= wordPos + contextSize; i++) {
|
||||
if (i < 0)
|
||||
window.add("<s>");
|
||||
else if (i >= sentence.size())
|
||||
window.add("</s>");
|
||||
else {
|
||||
onlyTokens.add(sentence.get(i));
|
||||
window.add(sentence.get(i));
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
String wholeSentence = StringUtils.join(sentence);
|
||||
String window2 = StringUtils.join(onlyTokens);
|
||||
int begin = wholeSentence.indexOf(window2);
|
||||
int end = begin + window2.length();
|
||||
return new Window(window, begin, end);
|
||||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Constructs a list of window of size windowSize
|
||||
* @param words the words to construct windows from
|
||||
* @return the list of windows for the tokenized string
|
||||
*/
|
||||
public static List<Window> windows(List<String> words, int windowSize) {
|
||||
|
||||
List<Window> ret = new ArrayList<>();
|
||||
|
||||
for (int i = 0; i < words.size(); i++)
|
||||
ret.add(windowForWordInPosition(windowSize, i, words));
|
||||
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,189 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.reader;
|
||||
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.records.Record;
|
||||
import org.datavec.api.records.metadata.RecordMetaData;
|
||||
import org.datavec.api.records.metadata.RecordMetaDataURI;
|
||||
import org.datavec.api.records.reader.impl.FileRecordReader;
|
||||
import org.datavec.api.split.InputSplit;
|
||||
import org.datavec.api.vector.Vectorizer;
|
||||
import org.datavec.api.writable.NDArrayWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.nlp.vectorizer.TfidfVectorizer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
|
||||
/**
 * Record reader that vectorizes text files into tf-idf feature rows.
 * During initialize() it either fits a new TfidfVectorizer over the whole
 * split, or (when a vectorizer was injected beforehand) transforms each file
 * with the already-fitted model.
 */
public class TfidfRecordReader extends FileRecordReader {
    // Vectorizer that builds the tf-idf model and transforms documents.
    private TfidfVectorizer tfidfVectorizer;
    // Fully vectorized records, populated during initialize().
    private List<Record> records = new ArrayList<>();
    private Iterator<Record> recordIter;
    // Number of columns in the vectorized output, cached from the fitted model.
    private int numFeatures;
    private boolean initialized = false;


    @Override
    public void initialize(InputSplit split) throws IOException, InterruptedException {
        initialize(new Configuration(), split);
    }

    @Override
    public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
        super.initialize(conf, split);
        //train a new one since it hasn't been specified
        if (tfidfVectorizer == null) {
            tfidfVectorizer = new TfidfVectorizer();
            tfidfVectorizer.initialize(conf);

            //clear out old strings
            records.clear();

            // fitTransform both builds the vocabulary/doc frequencies and
            // collects the resulting records via the callback.
            INDArray ret = tfidfVectorizer.fitTransform(this, new Vectorizer.RecordCallBack() {
                @Override
                public void onRecord(Record fullRecord) {
                    records.add(fullRecord);
                }
            });

            //cache the number of features used for each document
            numFeatures = ret.columns();
            recordIter = records.iterator();
        } else {
            records = new ArrayList<>();

            //the record reader has 2 phases, we are skipping the
            //document frequency phase and just using the super() to get the file contents
            //and pass it to the already existing vectorizer.
            while (super.hasNext()) {
                Record fileContents = super.nextRecord();
                INDArray transform = tfidfVectorizer.transform(fileContents);

                org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record(
                                new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))),
                                new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class));

                // Preserve the label (last writable of the raw record) when requested.
                if (appendLabel)
                    record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1));

                records.add(record);
            }

            recordIter = records.iterator();
        }

        this.initialized = true;
    }

    @Override
    public void reset() {
        if (inputSplit == null)
            throw new UnsupportedOperationException("Cannot reset without first initializing");
        // Records are cached, so reset simply restarts the iterator.
        recordIter = records.iterator();
    }

    @Override
    public Record nextRecord() {
        // Before initialization completes, fall back to raw file reading
        // (this path is used internally by fitTransform()).
        if (recordIter == null)
            return super.nextRecord();
        return recordIter.next();
    }

    @Override
    public List<Writable> next() {
        return nextRecord().getRecord();
    }

    @Override
    public boolean hasNext() {
        //we aren't done vectorizing yet
        if (recordIter == null)
            return super.hasNext();
        return recordIter.hasNext();
    }

    @Override
    public void close() throws IOException {
        // No resources to release: records are held in memory.
    }

    @Override
    public void setConf(Configuration conf) {
        this.conf = conf;
    }

    @Override
    public Configuration getConf() {
        return conf;
    }

    public TfidfVectorizer getTfidfVectorizer() {
        return tfidfVectorizer;
    }

    /**
     * Inject a pre-fitted vectorizer; must be called before initialize().
     * @throws IllegalArgumentException if the reader was already initialized
     */
    public void setTfidfVectorizer(TfidfVectorizer tfidfVectorizer) {
        if (initialized) {
            throw new IllegalArgumentException(
                            "Setting TfidfVectorizer after TfidfRecordReader initialization doesn't have an effect");
        }
        this.tfidfVectorizer = tfidfVectorizer;
    }

    // Number of tf-idf features (columns) per record, valid after initialize().
    public int getNumFeatures() {
        return numFeatures;
    }

    public void shuffle() {
        this.shuffle(new Random());
    }

    // Shuffles the cached records in place and restarts iteration.
    public void shuffle(Random random) {
        Collections.shuffle(this.records, random);
        this.reset();
    }

    @Override
    public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return loadFromMetaData(Collections.singletonList(recordMetaData)).get(0);
    }

    @Override
    public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
        List<Record> out = new ArrayList<>();

        // Re-read the raw files and transform them with the fitted vectorizer.
        for (Record fileContents : super.loadFromMetaData(recordMetaDatas)) {
            INDArray transform = tfidfVectorizer.transform(fileContents);

            org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record(
                            new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))),
                            new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class));

            if (appendLabel)
                record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1));
            out.add(record);
        }

        return out;
    }
}
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.stopwords;
|
||||
|
||||
import org.apache.commons.io.IOUtils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;
|
||||
|
||||
public class StopWords {
|
||||
|
||||
private static List<String> stopWords;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static List<String> getStopWords() {
|
||||
|
||||
try {
|
||||
if (stopWords == null)
|
||||
stopWords = IOUtils.readLines(StopWords.class.getResourceAsStream("/stopwords"));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return stopWords;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,125 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.tokenization.tokenizer;
|
||||
|
||||
import opennlp.tools.tokenize.TokenizerME;
|
||||
import opennlp.tools.tokenize.TokenizerModel;
|
||||
import opennlp.tools.util.Span;
|
||||
import opennlp.uima.tokenize.AbstractTokenizer;
|
||||
import opennlp.uima.tokenize.TokenizerModelResource;
|
||||
import opennlp.uima.util.AnnotatorUtil;
|
||||
import opennlp.uima.util.UimaUtil;
|
||||
import org.apache.uima.UimaContext;
|
||||
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
|
||||
import org.apache.uima.cas.CAS;
|
||||
import org.apache.uima.cas.Feature;
|
||||
import org.apache.uima.cas.TypeSystem;
|
||||
import org.apache.uima.cas.text.AnnotationFS;
|
||||
import org.apache.uima.resource.ResourceAccessException;
|
||||
import org.apache.uima.resource.ResourceInitializationException;
|
||||
|
||||
public class ConcurrentTokenizer extends AbstractTokenizer {
|
||||
|
||||
/**
|
||||
* The OpenNLP tokenizer.
|
||||
*/
|
||||
private TokenizerME tokenizer;
|
||||
|
||||
private Feature probabilityFeature;
|
||||
|
||||
@Override
|
||||
public synchronized void process(CAS cas) throws AnalysisEngineProcessException {
|
||||
super.process(cas);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance.
|
||||
*
|
||||
* Note: Use {@link #initialize(UimaContext) } to initialize
|
||||
* this instance. Not use the constructor.
|
||||
*/
|
||||
public ConcurrentTokenizer() {
|
||||
super("OpenNLP Tokenizer");
|
||||
|
||||
// must not be implemented !
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the current instance with the given context.
|
||||
*
|
||||
* Note: Do all initialization in this method, do not use the constructor.
|
||||
*/
|
||||
public void initialize(UimaContext context) throws ResourceInitializationException {
|
||||
|
||||
super.initialize(context);
|
||||
|
||||
TokenizerModel model;
|
||||
|
||||
try {
|
||||
TokenizerModelResource modelResource =
|
||||
(TokenizerModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER);
|
||||
|
||||
model = modelResource.getModel();
|
||||
} catch (ResourceAccessException e) {
|
||||
throw new ResourceInitializationException(e);
|
||||
}
|
||||
|
||||
tokenizer = new TokenizerME(model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the type system.
|
||||
*/
|
||||
public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException {
|
||||
|
||||
super.typeSystemInit(typeSystem);
|
||||
|
||||
probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(context, tokenType,
|
||||
UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected Span[] tokenize(CAS cas, AnnotationFS sentence) {
|
||||
return tokenizer.tokenizePos(sentence.getCoveredText());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void postProcessAnnotations(Span[] tokens, AnnotationFS[] tokenAnnotations) {
|
||||
// if interest
|
||||
if (probabilityFeature != null) {
|
||||
double tokenProbabilties[] = tokenizer.getTokenProbabilities();
|
||||
|
||||
for (int i = 0; i < tokenAnnotations.length; i++) {
|
||||
tokenAnnotations[i].setDoubleValue(probabilityFeature, tokenProbabilties[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Releases allocated resources.
|
||||
*/
|
||||
public void destroy() {
|
||||
// dereference model to allow garbage collection
|
||||
tokenizer = null;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.tokenization.tokenizer;
|
||||
|
||||
|
||||
import java.io.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Tokenizer based on the {@link java.io.StreamTokenizer}
|
||||
* @author Adam Gibson
|
||||
*
|
||||
*/
|
||||
public class DefaultStreamTokenizer implements Tokenizer {
|
||||
|
||||
private StreamTokenizer streamTokenizer;
|
||||
private TokenPreProcess tokenPreProcess;
|
||||
|
||||
|
||||
public DefaultStreamTokenizer(InputStream is) {
|
||||
Reader r = new BufferedReader(new InputStreamReader(is));
|
||||
streamTokenizer = new StreamTokenizer(r);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasMoreTokens() {
|
||||
if (streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
|
||||
try {
|
||||
streamTokenizer.nextToken();
|
||||
} catch (IOException e1) {
|
||||
throw new RuntimeException(e1);
|
||||
}
|
||||
}
|
||||
return streamTokenizer.ttype != StreamTokenizer.TT_EOF && streamTokenizer.ttype != -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int countTokens() {
|
||||
return getTokens().size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String nextToken() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
|
||||
|
||||
if (streamTokenizer.ttype == StreamTokenizer.TT_WORD) {
|
||||
sb.append(streamTokenizer.sval);
|
||||
} else if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
|
||||
sb.append(streamTokenizer.nval);
|
||||
} else if (streamTokenizer.ttype == StreamTokenizer.TT_EOL) {
|
||||
try {
|
||||
while (streamTokenizer.ttype == StreamTokenizer.TT_EOL)
|
||||
streamTokenizer.nextToken();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
else if (hasMoreTokens())
|
||||
return nextToken();
|
||||
|
||||
|
||||
String ret = sb.toString();
|
||||
|
||||
if (tokenPreProcess != null)
|
||||
ret = tokenPreProcess.preProcess(ret);
|
||||
return ret;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getTokens() {
|
||||
List<String> tokens = new ArrayList<>();
|
||||
while (hasMoreTokens()) {
|
||||
tokens.add(nextToken());
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
|
||||
this.tokenPreProcess = tokenPreProcessor;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,75 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.tokenization.tokenizer;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.StringTokenizer;
|
||||
|
||||
/**
|
||||
* Default tokenizer
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class DefaultTokenizer implements Tokenizer {
|
||||
|
||||
public DefaultTokenizer(String tokens) {
|
||||
tokenizer = new StringTokenizer(tokens);
|
||||
}
|
||||
|
||||
private StringTokenizer tokenizer;
|
||||
private TokenPreProcess tokenPreProcess;
|
||||
|
||||
@Override
|
||||
public boolean hasMoreTokens() {
|
||||
return tokenizer.hasMoreTokens();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int countTokens() {
|
||||
return tokenizer.countTokens();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String nextToken() {
|
||||
String base = tokenizer.nextToken();
|
||||
if (tokenPreProcess != null)
|
||||
base = tokenPreProcess.preProcess(base);
|
||||
return base;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getTokens() {
|
||||
List<String> tokens = new ArrayList<>();
|
||||
while (hasMoreTokens()) {
|
||||
tokens.add(nextToken());
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
|
||||
this.tokenPreProcess = tokenPreProcessor;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,136 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.tokenization.tokenizer;
|
||||
|
||||
import org.apache.uima.analysis_engine.AnalysisEngine;
|
||||
import org.apache.uima.cas.CAS;
|
||||
import org.apache.uima.fit.factory.AnalysisEngineFactory;
|
||||
import org.apache.uima.fit.util.JCasUtil;
|
||||
import org.cleartk.token.type.Sentence;
|
||||
import org.cleartk.token.type.Token;
|
||||
import org.datavec.nlp.annotator.PoStagger;
|
||||
import org.datavec.nlp.annotator.SentenceAnnotator;
|
||||
import org.datavec.nlp.annotator.StemmerAnnotator;
|
||||
import org.datavec.nlp.annotator.TokenizerAnnotator;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
public class PosUimaTokenizer implements Tokenizer {
|
||||
|
||||
private static AnalysisEngine engine;
|
||||
private List<String> tokens;
|
||||
private Collection<String> allowedPosTags;
|
||||
private int index;
|
||||
private static CAS cas;
|
||||
|
||||
public PosUimaTokenizer(String tokens, AnalysisEngine engine, Collection<String> allowedPosTags) {
|
||||
if (engine == null)
|
||||
PosUimaTokenizer.engine = engine;
|
||||
this.allowedPosTags = allowedPosTags;
|
||||
this.tokens = new ArrayList<>();
|
||||
try {
|
||||
if (cas == null)
|
||||
cas = engine.newCAS();
|
||||
|
||||
cas.reset();
|
||||
cas.setDocumentText(tokens);
|
||||
PosUimaTokenizer.engine.process(cas);
|
||||
for (Sentence s : JCasUtil.select(cas.getJCas(), Sentence.class)) {
|
||||
for (Token t : JCasUtil.selectCovered(Token.class, s)) {
|
||||
//add NONE for each invalid token
|
||||
if (valid(t))
|
||||
if (t.getLemma() != null)
|
||||
this.tokens.add(t.getLemma());
|
||||
else if (t.getStem() != null)
|
||||
this.tokens.add(t.getStem());
|
||||
else
|
||||
this.tokens.add(t.getCoveredText());
|
||||
else
|
||||
this.tokens.add("NONE");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private boolean valid(Token token) {
|
||||
String check = token.getCoveredText();
|
||||
if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>"))
|
||||
return false;
|
||||
else if (token.getPos() != null && !this.allowedPosTags.contains(token.getPos()))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public boolean hasMoreTokens() {
|
||||
return index < tokens.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int countTokens() {
|
||||
return tokens.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String nextToken() {
|
||||
String ret = tokens.get(index);
|
||||
index++;
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getTokens() {
|
||||
List<String> tokens = new ArrayList<String>();
|
||||
while (hasMoreTokens()) {
|
||||
tokens.add(nextToken());
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
public static AnalysisEngine defaultAnalysisEngine() {
|
||||
try {
|
||||
return AnalysisEngineFactory.createEngine(AnalysisEngineFactory.createEngineDescription(
|
||||
SentenceAnnotator.getDescription(), TokenizerAnnotator.getDescription(),
|
||||
PoStagger.getDescription("en"), StemmerAnnotator.getDescription("English")));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
|
||||
// TODO Auto-generated method stub
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,34 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.tokenization.tokenizer;
|
||||
|
||||
|
||||
public interface TokenPreProcess {
|
||||
|
||||
/**
|
||||
* Pre process a token
|
||||
* @param token the token to pre process
|
||||
* @return the preprocessed token
|
||||
*/
|
||||
String preProcess(String token);
|
||||
|
||||
|
||||
}
|
|
@ -1,61 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.tokenization.tokenizer;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface Tokenizer {
|
||||
|
||||
/**
|
||||
* An iterator for tracking whether
|
||||
* more tokens are left in the iterator not
|
||||
* @return whether there is anymore tokens
|
||||
* to iterate over
|
||||
*/
|
||||
boolean hasMoreTokens();
|
||||
|
||||
/**
|
||||
* The number of tokens in the tokenizer
|
||||
* @return the number of tokens
|
||||
*/
|
||||
int countTokens();
|
||||
|
||||
/**
|
||||
* The next token (word usually) in the string
|
||||
* @return the next token in the string if any
|
||||
*/
|
||||
String nextToken();
|
||||
|
||||
/**
|
||||
* Returns a list of all the tokens
|
||||
* @return a list of all the tokens
|
||||
*/
|
||||
List<String> getTokens();
|
||||
|
||||
/**
|
||||
* Set the token pre process
|
||||
* @param tokenPreProcessor the token pre processor to set
|
||||
*/
|
||||
void setTokenPreProcessor(TokenPreProcess tokenPreProcessor);
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,123 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.tokenization.tokenizer;
|
||||
|
||||
import org.apache.uima.cas.CAS;
|
||||
import org.apache.uima.fit.util.JCasUtil;
|
||||
import org.cleartk.token.type.Token;
|
||||
import org.datavec.nlp.uima.UimaResource;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Tokenizer based on the passed in analysis engine
|
||||
* @author Adam Gibson
|
||||
*
|
||||
*/
|
||||
public class UimaTokenizer implements Tokenizer {
|
||||
|
||||
private List<String> tokens;
|
||||
private int index;
|
||||
private static Logger log = LoggerFactory.getLogger(UimaTokenizer.class);
|
||||
private boolean checkForLabel;
|
||||
private TokenPreProcess tokenPreProcessor;
|
||||
|
||||
|
||||
public UimaTokenizer(String tokens, UimaResource resource, boolean checkForLabel) {
|
||||
|
||||
this.checkForLabel = checkForLabel;
|
||||
this.tokens = new ArrayList<>();
|
||||
try {
|
||||
CAS cas = resource.process(tokens);
|
||||
|
||||
Collection<Token> tokenList = JCasUtil.select(cas.getJCas(), Token.class);
|
||||
|
||||
for (Token t : tokenList) {
|
||||
|
||||
if (!checkForLabel || valid(t.getCoveredText()))
|
||||
if (t.getLemma() != null)
|
||||
this.tokens.add(t.getLemma());
|
||||
else if (t.getStem() != null)
|
||||
this.tokens.add(t.getStem());
|
||||
else
|
||||
this.tokens.add(t.getCoveredText());
|
||||
}
|
||||
|
||||
|
||||
resource.release(cas);
|
||||
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("",e);
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private boolean valid(String check) {
|
||||
if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>"))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public boolean hasMoreTokens() {
|
||||
return index < tokens.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int countTokens() {
|
||||
return tokens.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String nextToken() {
|
||||
String ret = tokens.get(index);
|
||||
index++;
|
||||
if (tokenPreProcessor != null) {
|
||||
ret = tokenPreProcessor.preProcess(ret);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getTokens() {
|
||||
List<String> tokens = new ArrayList<>();
|
||||
while (hasMoreTokens()) {
|
||||
tokens.add(nextToken());
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
|
||||
this.tokenPreProcessor = tokenPreProcessor;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,47 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.tokenization.tokenizer.preprocessor;
|
||||
|
||||
|
||||
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
|
||||
|
||||
/**
|
||||
* Gets rid of endings:
|
||||
*
|
||||
* ed,ing, ly, s, .
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class EndingPreProcessor implements TokenPreProcess {
|
||||
@Override
|
||||
public String preProcess(String token) {
|
||||
if (token.endsWith("s") && !token.endsWith("ss"))
|
||||
token = token.substring(0, token.length() - 1);
|
||||
if (token.endsWith("."))
|
||||
token = token.substring(0, token.length() - 1);
|
||||
if (token.endsWith("ed"))
|
||||
token = token.substring(0, token.length() - 2);
|
||||
if (token.endsWith("ing"))
|
||||
token = token.substring(0, token.length() - 3);
|
||||
if (token.endsWith("ly"))
|
||||
token = token.substring(0, token.length() - 2);
|
||||
return token;
|
||||
}
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.tokenization.tokenizer.preprocessor;
|
||||
|
||||
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
|
||||
|
||||
public class LowerCasePreProcessor implements TokenPreProcess {
|
||||
@Override
|
||||
public String preProcess(String token) {
|
||||
return token.toLowerCase();
|
||||
}
|
||||
}
|
|
@ -1,60 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.tokenization.tokenizerfactory;
|
||||
|
||||
|
||||
|
||||
import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer;
|
||||
import org.datavec.nlp.tokenization.tokenizer.DefaultTokenizer;
|
||||
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
|
||||
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
|
||||
|
||||
import java.io.InputStream;
|
||||
|
||||
/**
|
||||
* Default tokenizer based on string tokenizer or stream tokenizer
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class DefaultTokenizerFactory implements TokenizerFactory {
|
||||
|
||||
private TokenPreProcess tokenPreProcess;
|
||||
|
||||
@Override
|
||||
public Tokenizer create(String toTokenize) {
|
||||
DefaultTokenizer t = new DefaultTokenizer(toTokenize);
|
||||
t.setTokenPreProcessor(tokenPreProcess);
|
||||
return t;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Tokenizer create(InputStream toTokenize) {
|
||||
Tokenizer t = new DefaultStreamTokenizer(toTokenize);
|
||||
t.setTokenPreProcessor(tokenPreProcess);
|
||||
return t;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTokenPreProcessor(TokenPreProcess preProcessor) {
|
||||
this.tokenPreProcess = preProcessor;
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,85 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.tokenization.tokenizerfactory;
|
||||
|
||||
|
||||
import org.apache.uima.analysis_engine.AnalysisEngine;
|
||||
import org.datavec.nlp.annotator.PoStagger;
|
||||
import org.datavec.nlp.annotator.SentenceAnnotator;
|
||||
import org.datavec.nlp.annotator.StemmerAnnotator;
|
||||
import org.datavec.nlp.annotator.TokenizerAnnotator;
|
||||
import org.datavec.nlp.tokenization.tokenizer.PosUimaTokenizer;
|
||||
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
|
||||
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.Collection;
|
||||
|
||||
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngine;
|
||||
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription;
|
||||
|
||||
public class PosUimaTokenizerFactory implements TokenizerFactory {
|
||||
|
||||
private AnalysisEngine tokenizer;
|
||||
private Collection<String> allowedPoSTags;
|
||||
private TokenPreProcess tokenPreProcess;
|
||||
|
||||
|
||||
public PosUimaTokenizerFactory(Collection<String> allowedPoSTags) {
|
||||
this(defaultAnalysisEngine(), allowedPoSTags);
|
||||
}
|
||||
|
||||
public PosUimaTokenizerFactory(AnalysisEngine tokenizer, Collection<String> allowedPosTags) {
|
||||
this.tokenizer = tokenizer;
|
||||
this.allowedPoSTags = allowedPosTags;
|
||||
}
|
||||
|
||||
|
||||
public static AnalysisEngine defaultAnalysisEngine() {
|
||||
try {
|
||||
return createEngine(createEngineDescription(SentenceAnnotator.getDescription(),
|
||||
TokenizerAnnotator.getDescription(), PoStagger.getDescription("en"),
|
||||
StemmerAnnotator.getDescription("English")));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Tokenizer create(String toTokenize) {
|
||||
PosUimaTokenizer t = new PosUimaTokenizer(toTokenize, tokenizer, allowedPoSTags);
|
||||
t.setTokenPreProcessor(tokenPreProcess);
|
||||
return t;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Tokenizer create(InputStream toTokenize) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTokenPreProcessor(TokenPreProcess preProcessor) {
|
||||
this.tokenPreProcess = preProcessor;
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,62 +0,0 @@
|
|||
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.nlp.tokenization.tokenizerfactory;


import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

import java.io.InputStream;

/**
 * Factory that produces a {@link Tokenizer} for a given input.
 * <p>
 * The {@code @JsonTypeInfo} annotation embeds the concrete class name under the
 * {@code @class} property so implementations can be round-tripped through JSON.
 *
 * @author Adam Gibson
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface TokenizerFactory {

    /**
     * Create a tokenizer over the given string.
     *
     * @param toTokenize the string to create the tokenizer with
     * @return the new tokenizer
     */
    Tokenizer create(String toTokenize);

    /**
     * Create a tokenizer over the given input stream.
     *
     * @param toTokenize the stream whose contents should be tokenized
     * @return the new tokenizer
     */
    Tokenizer create(InputStream toTokenize);

    /**
     * Sets a token pre processor to be used
     * with every tokenizer this factory creates.
     *
     * @param preProcessor the token pre processor to use
     */
    void setTokenPreProcessor(TokenPreProcess preProcessor);



}
|
|
@ -1,138 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.tokenization.tokenizerfactory;
|
||||
|
||||
import org.apache.uima.analysis_engine.AnalysisEngine;
|
||||
import org.apache.uima.fit.factory.AnalysisEngineFactory;
|
||||
import org.apache.uima.resource.ResourceInitializationException;
|
||||
import org.datavec.nlp.annotator.SentenceAnnotator;
|
||||
import org.datavec.nlp.annotator.TokenizerAnnotator;
|
||||
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
|
||||
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
|
||||
import org.datavec.nlp.tokenization.tokenizer.UimaTokenizer;
|
||||
import org.datavec.nlp.uima.UimaResource;
|
||||
|
||||
import java.io.InputStream;
|
||||
|
||||
|
||||
/**
|
||||
* Uses a uima {@link AnalysisEngine} to
|
||||
* tokenize text.
|
||||
*
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*
|
||||
*/
|
||||
public class UimaTokenizerFactory implements TokenizerFactory {
|
||||
|
||||
|
||||
private UimaResource uimaResource;
|
||||
private boolean checkForLabel;
|
||||
private static AnalysisEngine defaultAnalysisEngine;
|
||||
private TokenPreProcess preProcess;
|
||||
|
||||
public UimaTokenizerFactory() throws ResourceInitializationException {
|
||||
this(defaultAnalysisEngine(), true);
|
||||
}
|
||||
|
||||
|
||||
public UimaTokenizerFactory(UimaResource resource) {
|
||||
this(resource, true);
|
||||
}
|
||||
|
||||
|
||||
public UimaTokenizerFactory(AnalysisEngine tokenizer) {
|
||||
this(tokenizer, true);
|
||||
}
|
||||
|
||||
|
||||
|
||||
public UimaTokenizerFactory(UimaResource resource, boolean checkForLabel) {
|
||||
this.uimaResource = resource;
|
||||
this.checkForLabel = checkForLabel;
|
||||
}
|
||||
|
||||
public UimaTokenizerFactory(boolean checkForLabel) throws ResourceInitializationException {
|
||||
this(defaultAnalysisEngine(), checkForLabel);
|
||||
}
|
||||
|
||||
|
||||
|
||||
public UimaTokenizerFactory(AnalysisEngine tokenizer, boolean checkForLabel) {
|
||||
super();
|
||||
this.checkForLabel = checkForLabel;
|
||||
try {
|
||||
this.uimaResource = new UimaResource(tokenizer);
|
||||
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public Tokenizer create(String toTokenize) {
|
||||
if (toTokenize == null || toTokenize.isEmpty())
|
||||
throw new IllegalArgumentException("Unable to proceed; on sentence to tokenize");
|
||||
Tokenizer ret = new UimaTokenizer(toTokenize, uimaResource, checkForLabel);
|
||||
ret.setTokenPreProcessor(preProcess);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
public UimaResource getUimaResource() {
|
||||
return uimaResource;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Creates a tokenization,/stemming pipeline
|
||||
* @return a tokenization/stemming pipeline
|
||||
*/
|
||||
public static AnalysisEngine defaultAnalysisEngine() {
|
||||
try {
|
||||
if (defaultAnalysisEngine == null)
|
||||
|
||||
defaultAnalysisEngine = AnalysisEngineFactory.createEngine(
|
||||
AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.getDescription(),
|
||||
TokenizerAnnotator.getDescription()));
|
||||
|
||||
return defaultAnalysisEngine;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Tokenizer create(InputStream toTokenize) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTokenPreProcessor(TokenPreProcess preProcessor) {
|
||||
this.preProcess = preProcessor;
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.nlp.transforms;


import org.datavec.api.transform.Transform;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

import java.util.List;

/**
 * A {@link Transform} that converts sequences of tokens into bag-of-words
 * {@link INDArray} vectors over a fixed vocabulary.
 * The {@code @JsonTypeInfo} annotation embeds the concrete class name so
 * implementations survive JSON round-trips.
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface BagOfWordsTransform extends Transform {


    /**
     * The output shape of the transform (usually 1 x number of words)
     * @return the output shape
     */
    long[] outputShape();

    /**
     * The vocab words in the transform.
     * This is the words that were accumulated
     * when building a vocabulary.
     * (This is generally associated with some form of
     * minimum word frequency scanning to build a vocab
     * you then map on to a list of vocab words as a list)
     * @return the vocab words for the transform
     */
    List<String> vocabWords();

    /**
     * Transform for a list of tokens
     * that are objects. This is to allow loose
     * typing for tokens that are unique (non string)
     * @param tokens the token objects to transform
     * @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array)
     */
    INDArray transformFromObject(List<List<Object>> tokens);


    /**
     * Transform for a list of tokens
     * that are {@link Writable} (Generally {@link org.datavec.api.writable.Text}
     * @param tokens the token objects to transform
     * @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array)
     */
    INDArray transformFrom(List<List<Writable>> tokens);

}
|
|
@ -1,24 +0,0 @@
|
|||
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.nlp.transforms;

// NOTE(review): empty placeholder class with no members or supertypes — presumably
// reserved as a future base for word-map transforms; confirm it is still needed
// before removing.
public class BaseWordMapTransform {
}
|
|
@ -1,146 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.transforms;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
||||
import org.datavec.api.transform.transform.BaseColumnTransform;
|
||||
import org.datavec.api.writable.NDArrayWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonIgnoreProperties({"gazeteer"})
|
||||
public class GazeteerTransform extends BaseColumnTransform implements BagOfWordsTransform {
|
||||
|
||||
private String newColumnName;
|
||||
private List<String> wordList;
|
||||
private Set<String> gazeteer;
|
||||
|
||||
@JsonCreator
|
||||
public GazeteerTransform(@JsonProperty("columnName") String columnName,
|
||||
@JsonProperty("newColumnName")String newColumnName,
|
||||
@JsonProperty("wordList") List<String> wordList) {
|
||||
super(columnName);
|
||||
this.newColumnName = newColumnName;
|
||||
this.wordList = wordList;
|
||||
this.gazeteer = new HashSet<>(wordList);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
|
||||
return new NDArrayMetaData(newName,new long[]{wordList.size()});
|
||||
}
|
||||
|
||||
@Override
|
||||
public Writable map(Writable columnWritable) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object mapSequence(Object sequence) {
|
||||
List<List<Object>> sequenceInput = (List<List<Object>>) sequence;
|
||||
INDArray ret = Nd4j.create(DataType.FLOAT, wordList.size());
|
||||
|
||||
for(List<Object> list : sequenceInput) {
|
||||
for(Object token : list) {
|
||||
String s = token.toString();
|
||||
if(gazeteer.contains(s)) {
|
||||
ret.putScalar(wordList.indexOf(s),1);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
|
||||
INDArray arr = (INDArray) mapSequence((Object) sequence);
|
||||
return Collections.singletonList(Collections.<Writable>singletonList(new NDArrayWritable(arr)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return newColumnName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object map(Object input) {
|
||||
return gazeteer.contains(input.toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String outputColumnName() {
|
||||
return newColumnName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] outputColumnNames() {
|
||||
return new String[]{newColumnName};
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] columnNames() {
|
||||
return new String[]{columnName()};
|
||||
}
|
||||
|
||||
@Override
|
||||
public String columnName() {
|
||||
return columnName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long[] outputShape() {
|
||||
return new long[]{wordList.size()};
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> vocabWords() {
|
||||
return wordList;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray transformFromObject(List<List<Object>> tokens) {
|
||||
return (INDArray) mapSequence(tokens);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray transformFrom(List<List<Writable>> tokens) {
|
||||
return (INDArray) mapSequence((Object) tokens);
|
||||
}
|
||||
}
|
|
@ -1,150 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
|
||||
package org.datavec.nlp.transforms;
|
||||
|
||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
||||
import org.datavec.api.transform.transform.BaseColumnTransform;
|
||||
import org.datavec.api.writable.NDArrayWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.list.NDArrayList;
|
||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class MultiNlpTransform extends BaseColumnTransform implements BagOfWordsTransform {
|
||||
|
||||
private BagOfWordsTransform[] transforms;
|
||||
private String newColumnName;
|
||||
private List<String> vocabWords;
|
||||
|
||||
/**
|
||||
*
|
||||
* @param columnName
|
||||
* @param transforms
|
||||
* @param newColumnName
|
||||
*/
|
||||
@JsonCreator
|
||||
public MultiNlpTransform(@JsonProperty("columnName") String columnName,
|
||||
@JsonProperty("transforms") BagOfWordsTransform[] transforms,
|
||||
@JsonProperty("newColumnName") String newColumnName) {
|
||||
super(columnName);
|
||||
this.transforms = transforms;
|
||||
this.vocabWords = transforms[0].vocabWords();
|
||||
if(transforms.length > 1) {
|
||||
for(int i = 1; i < transforms.length; i++) {
|
||||
if(!transforms[i].vocabWords().equals(vocabWords)) {
|
||||
throw new IllegalArgumentException("Vocab words not consistent across transforms!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.newColumnName = newColumnName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object mapSequence(Object sequence) {
|
||||
NDArrayList ndArrayList = new NDArrayList();
|
||||
for(BagOfWordsTransform bagofWordsTransform : transforms) {
|
||||
ndArrayList.addAll(new NDArrayList(bagofWordsTransform.transformFromObject((List<List<Object>>) sequence)));
|
||||
}
|
||||
|
||||
return ndArrayList.array();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
|
||||
return Collections.singletonList(Collections.<Writable>singletonList(new NDArrayWritable(transformFrom(sequence))));
|
||||
}
|
||||
|
||||
@Override
|
||||
public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
|
||||
return new NDArrayMetaData(newName,outputShape());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Writable map(Writable columnWritable) {
|
||||
throw new UnsupportedOperationException("Only able to add for time series");
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return newColumnName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object map(Object input) {
|
||||
throw new UnsupportedOperationException("Only able to add for time series");
|
||||
}
|
||||
|
||||
@Override
|
||||
public long[] outputShape() {
|
||||
long[] ret = new long[transforms[0].outputShape().length];
|
||||
int validatedRank = transforms[0].outputShape().length;
|
||||
for(int i = 1; i < transforms.length; i++) {
|
||||
if(transforms[i].outputShape().length != validatedRank) {
|
||||
throw new IllegalArgumentException("Inconsistent shape length at transform " + i + " , should have been: " + validatedRank);
|
||||
}
|
||||
}
|
||||
for(int i = 0; i < transforms.length; i++) {
|
||||
for(int j = 0; j < validatedRank; j++)
|
||||
ret[j] += transforms[i].outputShape()[j];
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> vocabWords() {
|
||||
return vocabWords;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray transformFromObject(List<List<Object>> tokens) {
|
||||
NDArrayList ndArrayList = new NDArrayList();
|
||||
for(BagOfWordsTransform bagofWordsTransform : transforms) {
|
||||
INDArray arr2 = bagofWordsTransform.transformFromObject(tokens);
|
||||
arr2 = arr2.reshape(arr2.length());
|
||||
NDArrayList newList = new NDArrayList(arr2,(int) arr2.length());
|
||||
ndArrayList.addAll(newList); }
|
||||
|
||||
return ndArrayList.array();
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray transformFrom(List<List<Writable>> tokens) {
|
||||
NDArrayList ndArrayList = new NDArrayList();
|
||||
for(BagOfWordsTransform bagofWordsTransform : transforms) {
|
||||
INDArray arr2 = bagofWordsTransform.transformFrom(tokens);
|
||||
arr2 = arr2.reshape(arr2.length());
|
||||
NDArrayList newList = new NDArrayList(arr2,(int) arr2.length());
|
||||
ndArrayList.addAll(newList);
|
||||
}
|
||||
|
||||
return ndArrayList.array();
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,226 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.nlp.transforms;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
||||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.transform.transform.BaseColumnTransform;
|
||||
import org.datavec.api.writable.NDArrayWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
|
||||
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
|
||||
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.primitives.Counter;
|
||||
import org.nd4j.common.util.MathUtils;
|
||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true, exclude = {"tokenizerFactory"})
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public class TokenizerBagOfWordsTermSequenceIndexTransform extends BaseColumnTransform {
|
||||
|
||||
private String newColumName;
|
||||
private Map<String,Integer> wordIndexMap;
|
||||
private Map<String,Double> weightMap;
|
||||
private boolean exceptionOnUnknown;
|
||||
private String tokenizerFactoryClass;
|
||||
private String preprocessorClass;
|
||||
private TokenizerFactory tokenizerFactory;
|
||||
|
||||
@JsonCreator
|
||||
public TokenizerBagOfWordsTermSequenceIndexTransform(@JsonProperty("columnName") String columnName,
|
||||
@JsonProperty("newColumnName") String newColumnName,
|
||||
@JsonProperty("wordIndexMap") Map<String,Integer> wordIndexMap,
|
||||
@JsonProperty("idfMap") Map<String,Double> idfMap,
|
||||
@JsonProperty("exceptionOnUnknown") boolean exceptionOnUnknown,
|
||||
@JsonProperty("tokenizerFactoryClass") String tokenizerFactoryClass,
|
||||
@JsonProperty("preprocessorClass") String preprocessorClass) {
|
||||
super(columnName);
|
||||
this.newColumName = newColumnName;
|
||||
this.wordIndexMap = wordIndexMap;
|
||||
this.exceptionOnUnknown = exceptionOnUnknown;
|
||||
this.weightMap = idfMap;
|
||||
this.tokenizerFactoryClass = tokenizerFactoryClass;
|
||||
this.preprocessorClass = preprocessorClass;
|
||||
if(this.tokenizerFactoryClass == null) {
|
||||
this.tokenizerFactoryClass = DefaultTokenizerFactory.class.getName();
|
||||
}
|
||||
try {
|
||||
tokenizerFactory = (TokenizerFactory) Class.forName(this.tokenizerFactoryClass).newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new IllegalStateException("Unable to instantiate tokenizer factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?");
|
||||
}
|
||||
|
||||
if(preprocessorClass != null){
|
||||
try {
|
||||
TokenPreProcess tpp = (TokenPreProcess) Class.forName(this.preprocessorClass).newInstance();
|
||||
tokenizerFactory.setTokenPreProcessor(tpp);
|
||||
} catch (Exception e){
|
||||
throw new IllegalStateException("Unable to instantiate preprocessor factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public List<Writable> map(List<Writable> writables) {
|
||||
Text text = (Text) writables.get(inputSchema.getIndexOfColumn(columnName));
|
||||
List<Writable> ret = new ArrayList<>(writables);
|
||||
ret.set(inputSchema.getIndexOfColumn(columnName),new NDArrayWritable(convert(text.toString())));
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object map(Object input) {
|
||||
return convert(input.toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object mapSequence(Object sequence) {
|
||||
return convert(sequence.toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Schema transform(Schema inputSchema) {
|
||||
Schema.Builder newSchema = new Schema.Builder();
|
||||
for(int i = 0; i < inputSchema.numColumns(); i++) {
|
||||
if(inputSchema.getName(i).equals(this.columnName)) {
|
||||
newSchema.addColumnNDArray(newColumName,new long[]{1,wordIndexMap.size()});
|
||||
}
|
||||
else {
|
||||
newSchema.addColumn(inputSchema.getMetaData(i));
|
||||
}
|
||||
}
|
||||
|
||||
return newSchema.build();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Convert the given text
|
||||
* in to an {@link INDArray}
|
||||
* using the {@link TokenizerFactory}
|
||||
* specified in the constructor.
|
||||
* @param text the text to transform
|
||||
* @return the created {@link INDArray}
|
||||
* based on the {@link #wordIndexMap} for the column indices
|
||||
* of the word.
|
||||
*/
|
||||
public INDArray convert(String text) {
|
||||
Tokenizer tokenizer = tokenizerFactory.create(text);
|
||||
List<String> tokens = tokenizer.getTokens();
|
||||
INDArray create = Nd4j.create(1,wordIndexMap.size());
|
||||
Counter<String> tokenizedCounter = new Counter<>();
|
||||
|
||||
for(int i = 0; i < tokens.size(); i++) {
|
||||
tokenizedCounter.incrementCount(tokens.get(i),1.0);
|
||||
}
|
||||
|
||||
for(int i = 0; i < tokens.size(); i++) {
|
||||
if(wordIndexMap.containsKey(tokens.get(i))) {
|
||||
int idx = wordIndexMap.get(tokens.get(i));
|
||||
int count = (int) tokenizedCounter.getCount(tokens.get(i));
|
||||
double weight = tfidfWord(tokens.get(i),count,tokens.size());
|
||||
create.putScalar(idx,weight);
|
||||
}
|
||||
}
|
||||
|
||||
return create;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Calculate the tifdf for a word
|
||||
* given the word, word count, and document length
|
||||
* @param word the word to calculate
|
||||
* @param wordCount the word frequency
|
||||
* @param documentLength the number of words in the document
|
||||
* @return the tfidf weight for a given word
|
||||
*/
|
||||
public double tfidfWord(String word, long wordCount, long documentLength) {
|
||||
double tf = tfForWord(wordCount, documentLength);
|
||||
double idf = idfForWord(word);
|
||||
return MathUtils.tfidf(tf, idf);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the weight term frequency for a given
|
||||
* word normalized by the dcoument length
|
||||
* @param wordCount the word frequency
|
||||
* @param documentLength the number of words in the edocument
|
||||
* @return
|
||||
*/
|
||||
private double tfForWord(long wordCount, long documentLength) {
|
||||
return wordCount;
|
||||
}
|
||||
|
||||
private double idfForWord(String word) {
|
||||
if(weightMap.containsKey(word))
|
||||
return weightMap.get(word);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
|
||||
return new NDArrayMetaData(outputColumnName(),new long[]{1,wordIndexMap.size()});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String outputColumnName() {
|
||||
return newColumName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] outputColumnNames() {
|
||||
return new String[]{newColumName};
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] columnNames() {
|
||||
return new String[]{columnName()};
|
||||
}
|
||||
|
||||
@Override
|
||||
public String columnName() {
|
||||
return columnName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Writable map(Writable columnWritable) {
|
||||
return new NDArrayWritable(convert(columnWritable.toString()));
|
||||
}
|
||||
}
|
|
@ -1,107 +0,0 @@
|
|||
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.nlp.uima;

import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.util.CasPool;

/**
 * Wraps a UIMA {@link AnalysisEngine} together with a {@link CasPool}
 * so CAS instances can be reused across calls to {@link #process(String)}.
 */
public class UimaResource {

    private AnalysisEngine analysisEngine;
    private CasPool casPool;

    /**
     * Wraps the given engine with a CAS pool sized at
     * 10x the number of available processors.
     */
    public UimaResource(AnalysisEngine analysisEngine) throws ResourceInitializationException {
        this.analysisEngine = analysisEngine;
        this.casPool = new CasPool(Runtime.getRuntime().availableProcessors() * 10, analysisEngine);

    }

    /** Wraps the given engine with a caller-supplied CAS pool. */
    public UimaResource(AnalysisEngine analysisEngine, CasPool casPool) {
        this.analysisEngine = analysisEngine;
        this.casPool = casPool;

    }


    public AnalysisEngine getAnalysisEngine() {
        return analysisEngine;
    }


    public void setAnalysisEngine(AnalysisEngine analysisEngine) {
        this.analysisEngine = analysisEngine;
    }


    public CasPool getCasPool() {
        return casPool;
    }


    public void setCasPool(CasPool casPool) {
        this.casPool = casPool;
    }


    /**
     * Use the given analysis engine and process the given text
     * You must release the return cas yourself
     * @param text the text to process
     * @return the processed cas
     */
    public CAS process(String text) {
        CAS cas = retrieve();

        cas.setDocumentText(text);
        try {
            analysisEngine.process(cas);
        } catch (AnalysisEngineProcessException e) {
            // NOTE(review): this retries by unbounded recursion whenever text is
            // non-empty — a persistently failing engine will cause a
            // StackOverflowError. Also, the failed `cas` (and each one acquired
            // by the recursive calls) is never released back to the pool,
            // leaking pool capacity. Consider a bounded retry + finally-release.
            if (text != null && !text.isEmpty())
                return process(text);
            throw new RuntimeException(e);
        }

        return cas;


    }


    /**
     * Fetch a CAS from the pool, or create a fresh one if the pool is exhausted.
     * NOTE(review): freshly created CASes are not pool-managed; releasing them
     * via {@link #release(CAS)} hands a foreign CAS to the pool — confirm
     * CasPool tolerates this.
     */
    public CAS retrieve() {
        CAS ret = casPool.getCas();
        try {
            return ret == null ? analysisEngine.newCAS() : ret;
        } catch (ResourceInitializationException e) {
            throw new RuntimeException(e);
        }
    }


    /** Return a CAS obtained from {@link #retrieve()} back to the pool. */
    public void release(CAS cas) {
        casPool.releaseCas(cas);
    }



}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue