Remove datavec audio/nlp

parent ee06fdd16f
commit 8e8a5ec369

@@ -1,77 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~  *
  ~  *
  ~  * This program and the accompanying materials are made available under the
  ~  * terms of the Apache License, Version 2.0 which is available at
  ~  * https://www.apache.org/licenses/LICENSE-2.0.
  ~  *
  ~  * See the NOTICE file distributed with this work for additional
  ~  * information regarding copyright ownership.
  ~  * Unless required by applicable law or agreed to in writing, software
  ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~  * License for the specific language governing permissions and limitations
  ~  * under the License.
  ~  *
  ~  * SPDX-License-Identifier: Apache-2.0
  ~  ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-data</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-data-audio</artifactId>

    <name>datavec-data-audio</name>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacpp</artifactId>
            <version>${javacpp.version}</version>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv</artifactId>
            <version>${javacv.version}</version>
        </dependency>
        <dependency>
            <groupId>com.github.wendykierp</groupId>
            <artifactId>JTransforms</artifactId>
            <version>${jtransforms.version}</version>
            <classifier>with-dependencies</classifier>
        </dependency>
        <!-- Do not depend on FFmpeg by default due to licensing concerns. -->
        <!--
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>ffmpeg-platform</artifactId>
            <version>${ffmpeg.version}-${javacpp-presets.version}</version>
        </dependency>
        -->
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>

@@ -1,329 +0,0 @@

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.audio;

import org.datavec.audio.extension.NormalizedSampleAmplitudes;
import org.datavec.audio.extension.Spectrogram;
import org.datavec.audio.fingerprint.FingerprintManager;
import org.datavec.audio.fingerprint.FingerprintSimilarity;
import org.datavec.audio.fingerprint.FingerprintSimilarityComputer;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;

public class Wave implements Serializable {

    private static final long serialVersionUID = 1L;
    private WaveHeader waveHeader;
    private byte[] data; // little endian
    private byte[] fingerprint;

    /**
     * Constructor
     */
    public Wave() {
        this.waveHeader = new WaveHeader();
        this.data = new byte[0];
    }

    /**
     * Constructor
     *
     * @param filename Wave file
     */
    public Wave(String filename) {
        try {
            InputStream inputStream = new FileInputStream(filename);
            initWaveWithInputStream(inputStream);
            inputStream.close();
        } catch (IOException e) {
            System.out.println(e.toString());
        }
    }

    /**
     * Constructor
     *
     * @param inputStream Wave file input stream
     */
    public Wave(InputStream inputStream) {
        initWaveWithInputStream(inputStream);
    }

    /**
     * Constructor
     *
     * @param waveHeader
     * @param data
     */
    public Wave(WaveHeader waveHeader, byte[] data) {
        this.waveHeader = waveHeader;
        this.data = data;
    }

    private void initWaveWithInputStream(InputStream inputStream) {
        // reads the first 44 bytes for the header
        waveHeader = new WaveHeader(inputStream);

        if (waveHeader.isValid()) {
            // load data
            try {
                data = new byte[inputStream.available()];
                inputStream.read(data);
            } catch (IOException e) {
                System.err.println(e.toString());
            }
            // end load data
        } else {
            System.err.println("Invalid Wave Header");
        }
    }

    /**
     * Trim the wave data
     *
     * @param leftTrimNumberOfSample  Number of samples trimmed from the beginning
     * @param rightTrimNumberOfSample Number of samples trimmed from the end
     */
    public void trim(int leftTrimNumberOfSample, int rightTrimNumberOfSample) {

        long chunkSize = waveHeader.getChunkSize();
        long subChunk2Size = waveHeader.getSubChunk2Size();

        long totalTrimmed = leftTrimNumberOfSample + rightTrimNumberOfSample;

        if (totalTrimmed > subChunk2Size) {
            leftTrimNumberOfSample = (int) subChunk2Size;
        }

        // update wav info
        chunkSize -= totalTrimmed;
        subChunk2Size -= totalTrimmed;

        if (chunkSize >= 0 && subChunk2Size >= 0) {
            waveHeader.setChunkSize(chunkSize);
            waveHeader.setSubChunk2Size(subChunk2Size);

            byte[] trimmedData = new byte[(int) subChunk2Size];
            System.arraycopy(data, (int) leftTrimNumberOfSample, trimmedData, 0, (int) subChunk2Size);
            data = trimmedData;
        } else {
            System.err.println("Trim error: Negative length");
        }
    }

    /**
     * Trim the wave data from the beginning
     *
     * @param numberOfSample Number of samples trimmed from the beginning
     */
    public void leftTrim(int numberOfSample) {
        trim(numberOfSample, 0);
    }

    /**
     * Trim the wave data from the end
     *
     * @param numberOfSample Number of samples trimmed from the end
     */
    public void rightTrim(int numberOfSample) {
        trim(0, numberOfSample);
    }

    /**
     * Trim the wave data
     *
     * @param leftTrimSecond  Seconds trimmed from the beginning
     * @param rightTrimSecond Seconds trimmed from the end
     */
    public void trim(double leftTrimSecond, double rightTrimSecond) {

        int sampleRate = waveHeader.getSampleRate();
        int bitsPerSample = waveHeader.getBitsPerSample();
        int channels = waveHeader.getChannels();

        int leftTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * leftTrimSecond);
        int rightTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * rightTrimSecond);

        trim(leftTrimNumberOfSample, rightTrimNumberOfSample);
    }

    /**
     * Trim the wave data from the beginning
     *
     * @param second Seconds trimmed from the beginning
     */
    public void leftTrim(double second) {
        trim(second, 0);
    }

    /**
     * Trim the wave data from the end
     *
     * @param second Seconds trimmed from the end
     */
    public void rightTrim(double second) {
        trim(0, second);
    }

    /**
     * Get the wave header
     *
     * @return waveHeader
     */
    public WaveHeader getWaveHeader() {
        return waveHeader;
    }

    /**
     * Get the wave spectrogram
     *
     * @return spectrogram
     */
    public Spectrogram getSpectrogram() {
        return new Spectrogram(this);
    }

    /**
     * Get the wave spectrogram
     *
     * @param fftSampleSize number of samples in each FFT; must be a power of 2
     * @param overlapFactor 1/overlapFactor overlapping, e.g. 1/4 = 25% overlapping, 0 for no overlapping
     *
     * @return spectrogram
     */
    public Spectrogram getSpectrogram(int fftSampleSize, int overlapFactor) {
        return new Spectrogram(this, fftSampleSize, overlapFactor);
    }

    /**
     * Get the wave data in bytes
     *
     * @return wave data
     */
    public byte[] getBytes() {
        return data;
    }

    /**
     * Data byte size of the wave, excluding the header size
     *
     * @return byte size of the wave
     */
    public int size() {
        return data.length;
    }

    /**
     * Length of the wave in seconds
     *
     * @return length in seconds
     */
    public float length() {
        return (float) waveHeader.getSubChunk2Size() / waveHeader.getByteRate();
    }

    /**
     * Timestamp of the wave length
     *
     * @return timestamp
     */
    public String timestamp() {
        float totalSeconds = this.length();
        float second = totalSeconds % 60;
        int minute = (int) totalSeconds / 60 % 60;
        int hour = (int) (totalSeconds / 3600);

        StringBuilder sb = new StringBuilder();
        if (hour > 0) {
            sb.append(hour + ":");
        }
        if (minute > 0) {
            sb.append(minute + ":");
        }
        sb.append(second);

        return sb.toString();
    }

    /**
     * Get the amplitudes of the wave samples (depends on the header)
     *
     * @return amplitudes array (signed 16-bit)
     */
    public short[] getSampleAmplitudes() {
        int bytePerSample = waveHeader.getBitsPerSample() / 8;
        int numSamples = data.length / bytePerSample;
        short[] amplitudes = new short[numSamples];

        int pointer = 0;
        for (int i = 0; i < numSamples; i++) {
            short amplitude = 0;
            for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) {
                // little endian
                amplitude |= (short) ((data[pointer++] & 0xFF) << (byteNumber * 8));
            }
            amplitudes[i] = amplitude;
        }

        return amplitudes;
    }

    public String toString() {
        StringBuilder sb = new StringBuilder(waveHeader.toString());
        sb.append("\n");
        sb.append("length: " + timestamp());
        return sb.toString();
    }

    public double[] getNormalizedAmplitudes() {
        NormalizedSampleAmplitudes amplitudes = new NormalizedSampleAmplitudes(this);
        return amplitudes.getNormalizedAmplitudes();
    }

    public byte[] getFingerprint() {
        if (fingerprint == null) {
            FingerprintManager fingerprintManager = new FingerprintManager();
            fingerprint = fingerprintManager.extractFingerprint(this);
        }
        return fingerprint;
    }

    public FingerprintSimilarity getFingerprintSimilarity(Wave wave) {
        FingerprintSimilarityComputer fingerprintSimilarityComputer =
                        new FingerprintSimilarityComputer(this.getFingerprint(), wave.getFingerprint());
        return fingerprintSimilarityComputer.getFingerprintsSimilarity();
    }
}
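
For orientation, a minimal usage sketch of the removed Wave class; the file path is a hypothetical placeholder, and any PCM (audioFormat == 1) 8- or 16-bit WAV would do:

    import org.datavec.audio.Wave;

    public class WaveSketch {
        public static void main(String[] args) {
            Wave wave = new Wave("sample.wav");              // hypothetical input file
            System.out.println(wave);                        // header fields plus "length: <timestamp>"
            wave.leftTrim(0.5);                              // drop the first half second; trim() also updates the header chunk sizes
            short[] amplitudes = wave.getSampleAmplitudes(); // little-endian bytes decoded to signed shorts
        }
    }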

@@ -1,98 +0,0 @@

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.audio;

import lombok.extern.slf4j.Slf4j;

import java.io.FileOutputStream;
import java.io.IOException;

@Slf4j
public class WaveFileManager {

    private Wave wave;

    public WaveFileManager() {
        wave = new Wave();
    }

    public WaveFileManager(Wave wave) {
        setWave(wave);
    }

    /**
     * Save the wave file
     *
     * @param filename filename to save to
     */
    public void saveWaveAsFile(String filename) {

        WaveHeader waveHeader = wave.getWaveHeader();

        int byteRate = waveHeader.getByteRate();
        int audioFormat = waveHeader.getAudioFormat();
        int sampleRate = waveHeader.getSampleRate();
        int bitsPerSample = waveHeader.getBitsPerSample();
        int channels = waveHeader.getChannels();
        long chunkSize = waveHeader.getChunkSize();
        long subChunk1Size = waveHeader.getSubChunk1Size();
        long subChunk2Size = waveHeader.getSubChunk2Size();
        int blockAlign = waveHeader.getBlockAlign();

        try {
            FileOutputStream fos = new FileOutputStream(filename);
            fos.write(WaveHeader.RIFF_HEADER.getBytes());
            // little endian
            fos.write(new byte[] {(byte) (chunkSize), (byte) (chunkSize >> 8), (byte) (chunkSize >> 16),
                            (byte) (chunkSize >> 24)});
            fos.write(WaveHeader.WAVE_HEADER.getBytes());
            fos.write(WaveHeader.FMT_HEADER.getBytes());
            fos.write(new byte[] {(byte) (subChunk1Size), (byte) (subChunk1Size >> 8), (byte) (subChunk1Size >> 16),
                            (byte) (subChunk1Size >> 24)});
            fos.write(new byte[] {(byte) (audioFormat), (byte) (audioFormat >> 8)});
            fos.write(new byte[] {(byte) (channels), (byte) (channels >> 8)});
            fos.write(new byte[] {(byte) (sampleRate), (byte) (sampleRate >> 8), (byte) (sampleRate >> 16),
                            (byte) (sampleRate >> 24)});
            fos.write(new byte[] {(byte) (byteRate), (byte) (byteRate >> 8), (byte) (byteRate >> 16),
                            (byte) (byteRate >> 24)});
            fos.write(new byte[] {(byte) (blockAlign), (byte) (blockAlign >> 8)});
            fos.write(new byte[] {(byte) (bitsPerSample), (byte) (bitsPerSample >> 8)});
            fos.write(WaveHeader.DATA_HEADER.getBytes());
            fos.write(new byte[] {(byte) (subChunk2Size), (byte) (subChunk2Size >> 8), (byte) (subChunk2Size >> 16),
                            (byte) (subChunk2Size >> 24)});
            fos.write(wave.getBytes());
            fos.close();
        } catch (IOException e) {
            log.error("", e);
        }
    }

    public Wave getWave() {
        return wave;
    }

    public void setWave(Wave wave) {
        this.wave = wave;
    }
}
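
A sketch of writing a Wave back to disk with the removed WaveFileManager; both file paths are hypothetical placeholders:

    import org.datavec.audio.Wave;
    import org.datavec.audio.WaveFileManager;

    public class SaveSketch {
        public static void main(String[] args) {
            Wave wave = new Wave("in.wav");                      // hypothetical input
            wave.rightTrim(1.0);                                 // trim() keeps the header chunk sizes consistent
            new WaveFileManager(wave).saveWaveAsFile("out.wav"); // re-emits the 44-byte header, then the sample data
        }
    }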

@@ -1,281 +0,0 @@

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.audio;

import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.io.InputStream;

@Slf4j
public class WaveHeader {

    public static final String RIFF_HEADER = "RIFF";
    public static final String WAVE_HEADER = "WAVE";
    public static final String FMT_HEADER = "fmt ";
    public static final String DATA_HEADER = "data";
    public static final int HEADER_BYTE_LENGTH = 44; // 44 bytes for header

    private boolean valid;
    private String chunkId; // 4 bytes
    private long chunkSize; // unsigned 4 bytes, little endian
    private String format; // 4 bytes
    private String subChunk1Id; // 4 bytes
    private long subChunk1Size; // unsigned 4 bytes, little endian
    private int audioFormat; // unsigned 2 bytes, little endian
    private int channels; // unsigned 2 bytes, little endian
    private long sampleRate; // unsigned 4 bytes, little endian
    private long byteRate; // unsigned 4 bytes, little endian
    private int blockAlign; // unsigned 2 bytes, little endian
    private int bitsPerSample; // unsigned 2 bytes, little endian
    private String subChunk2Id; // 4 bytes
    private long subChunk2Size; // unsigned 4 bytes, little endian

    public WaveHeader() {
        // init a 8k 16bit mono wav
        chunkSize = 36;
        subChunk1Size = 16;
        audioFormat = 1;
        channels = 1;
        sampleRate = 8000;
        byteRate = 16000;
        blockAlign = 2;
        bitsPerSample = 16;
        subChunk2Size = 0;
        valid = true;
    }

    public WaveHeader(InputStream inputStream) {
        valid = loadHeader(inputStream);
    }

    private boolean loadHeader(InputStream inputStream) {

        byte[] headerBuffer = new byte[HEADER_BYTE_LENGTH];
        try {
            inputStream.read(headerBuffer);

            // read header
            int pointer = 0;
            chunkId = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++],
                            headerBuffer[pointer++]});
            // little endian; note that each (headerBuffer[pointer++] & 0xff) must be parenthesized
            // so the mask applies before the shift
            chunkSize = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
                            | (long) (headerBuffer[pointer++] & 0xff) << 16
                            | (long) (headerBuffer[pointer++] & 0xff) << 24;
            format = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++],
                            headerBuffer[pointer++]});
            subChunk1Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++],
                            headerBuffer[pointer++], headerBuffer[pointer++]});
            subChunk1Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
                            | (long) (headerBuffer[pointer++] & 0xff) << 16
                            | (long) (headerBuffer[pointer++] & 0xff) << 24;
            audioFormat = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
            channels = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
            sampleRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
                            | (long) (headerBuffer[pointer++] & 0xff) << 16
                            | (long) (headerBuffer[pointer++] & 0xff) << 24;
            byteRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
                            | (long) (headerBuffer[pointer++] & 0xff) << 16
                            | (long) (headerBuffer[pointer++] & 0xff) << 24;
            blockAlign = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
            bitsPerSample = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
            subChunk2Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++],
                            headerBuffer[pointer++], headerBuffer[pointer++]});
            subChunk2Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
                            | (long) (headerBuffer[pointer++] & 0xff) << 16
                            | (long) (headerBuffer[pointer++] & 0xff) << 24;
            // end read header

            // the inputStream should be closed outside this method

        } catch (IOException e) {
            log.error("", e);
            return false;
        }

        if (bitsPerSample != 8 && bitsPerSample != 16) {
            System.err.println("WaveHeader: only supports bitsPerSample 8 or 16");
            return false;
        }

        // check whether the format is supported
        if (chunkId.toUpperCase().equals(RIFF_HEADER) && format.toUpperCase().equals(WAVE_HEADER) && audioFormat == 1) {
            return true;
        } else {
            System.err.println("WaveHeader: Unsupported header format");
        }

        return false;
    }

    public boolean isValid() {
        return valid;
    }

    public String getChunkId() {
        return chunkId;
    }

    public long getChunkSize() {
        return chunkSize;
    }

    public String getFormat() {
        return format;
    }

    public String getSubChunk1Id() {
        return subChunk1Id;
    }

    public long getSubChunk1Size() {
        return subChunk1Size;
    }

    public int getAudioFormat() {
        return audioFormat;
    }

    public int getChannels() {
        return channels;
    }

    public int getSampleRate() {
        return (int) sampleRate;
    }

    public int getByteRate() {
        return (int) byteRate;
    }

    public int getBlockAlign() {
        return blockAlign;
    }

    public int getBitsPerSample() {
        return bitsPerSample;
    }

    public String getSubChunk2Id() {
        return subChunk2Id;
    }

    public long getSubChunk2Size() {
        return subChunk2Size;
    }

    public void setSampleRate(int sampleRate) {
        int newSubChunk2Size = (int) (this.subChunk2Size * sampleRate / this.sampleRate);
        // if the number of bytes per sample is even, newSubChunk2Size also needs to be even
        if ((bitsPerSample / 8) % 2 == 0) {
            if (newSubChunk2Size % 2 != 0) {
                newSubChunk2Size++;
            }
        }

        this.sampleRate = sampleRate;
        this.byteRate = sampleRate * bitsPerSample / 8;
        this.chunkSize = newSubChunk2Size + 36;
        this.subChunk2Size = newSubChunk2Size;
    }

    public void setChunkId(String chunkId) {
        this.chunkId = chunkId;
    }

    public void setChunkSize(long chunkSize) {
        this.chunkSize = chunkSize;
    }

    public void setFormat(String format) {
        this.format = format;
    }

    public void setSubChunk1Id(String subChunk1Id) {
        this.subChunk1Id = subChunk1Id;
    }

    public void setSubChunk1Size(long subChunk1Size) {
        this.subChunk1Size = subChunk1Size;
    }

    public void setAudioFormat(int audioFormat) {
        this.audioFormat = audioFormat;
    }

    public void setChannels(int channels) {
        this.channels = channels;
    }

    public void setByteRate(long byteRate) {
        this.byteRate = byteRate;
    }

    public void setBlockAlign(int blockAlign) {
        this.blockAlign = blockAlign;
    }

    public void setBitsPerSample(int bitsPerSample) {
        this.bitsPerSample = bitsPerSample;
    }

    public void setSubChunk2Id(String subChunk2Id) {
        this.subChunk2Id = subChunk2Id;
    }

    public void setSubChunk2Size(long subChunk2Size) {
        this.subChunk2Size = subChunk2Size;
    }

    public String toString() {

        StringBuilder sb = new StringBuilder();
        sb.append("chunkId: " + chunkId);
        sb.append("\n");
        sb.append("chunkSize: " + chunkSize);
        sb.append("\n");
        sb.append("format: " + format);
        sb.append("\n");
        sb.append("subChunk1Id: " + subChunk1Id);
        sb.append("\n");
        sb.append("subChunk1Size: " + subChunk1Size);
        sb.append("\n");
        sb.append("audioFormat: " + audioFormat);
        sb.append("\n");
        sb.append("channels: " + channels);
        sb.append("\n");
        sb.append("sampleRate: " + sampleRate);
        sb.append("\n");
        sb.append("byteRate: " + byteRate);
        sb.append("\n");
        sb.append("blockAlign: " + blockAlign);
        sb.append("\n");
        sb.append("bitsPerSample: " + bitsPerSample);
        sb.append("\n");
        sb.append("subChunk2Id: " + subChunk2Id);
        sb.append("\n");
        sb.append("subChunk2Size: " + subChunk2Size);
        return sb.toString();
    }
}
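
A sketch of inspecting a parsed header; isValid() is true only for RIFF/WAVE, PCM (audioFormat == 1), 8- or 16-bit input, matching the checks in loadHeader(). The file path is a hypothetical placeholder:

    import org.datavec.audio.Wave;
    import org.datavec.audio.WaveHeader;

    public class HeaderSketch {
        public static void main(String[] args) {
            WaveHeader header = new Wave("in.wav").getWaveHeader(); // hypothetical file
            if (header.isValid()) {
                System.out.println(header.getSampleRate() + " Hz, " + header.getChannels()
                                + " channel(s), " + header.getBitsPerSample() + " bit");
            }
        }
    }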

@@ -1,81 +0,0 @@

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.audio.dsp;

import org.jtransforms.fft.DoubleFFT_1D;

public class FastFourierTransform {

    /**
     * Get the frequency intensities
     *
     * @param amplitudes amplitudes of the signal; format depends on the value of complex
     * @param complex    if true, amplitudes is assumed to be complex interlaced (re = even, im = odd); if false,
     *                   amplitudes are assumed to be real valued
     * @return intensities of each frequency unit: mag[frequency_unit] = intensity
     */
    public double[] getMagnitudes(double[] amplitudes, boolean complex) {

        final int sampleSize = amplitudes.length;
        final int nrofFrequencyBins = sampleSize / 2;

        // call the fft and transform the complex numbers
        if (complex) {
            DoubleFFT_1D fft = new DoubleFFT_1D(nrofFrequencyBins);
            fft.complexForward(amplitudes);
        } else {
            DoubleFFT_1D fft = new DoubleFFT_1D(sampleSize);
            fft.realForward(amplitudes);
            // amplitudes[1] contains re[sampleSize/2] or im[(sampleSize-1)/2] (depending on whether sampleSize is odd or even).
            // Discard it as it is useless without the other part;
            // the im part of the dc bin is always 0 for real input.
            amplitudes[1] = 0;
        }
        // end call the fft and transform the complex numbers

        // even indexes (0,2,4,6,...) are real parts
        // odd indexes (1,3,5,7,...) are img parts
        double[] mag = new double[nrofFrequencyBins];
        for (int i = 0; i < nrofFrequencyBins; i++) {
            final int f = 2 * i;
            mag[i] = Math.sqrt(amplitudes[f] * amplitudes[f] + amplitudes[f + 1] * amplitudes[f + 1]);
        }

        return mag;
    }

    /**
     * Get the frequency intensities. Backwards compatible with previous versions w.r.t. the number of frequency bins.
     * Use getMagnitudes(amplitudes, true) to get all bins.
     *
     * @param amplitudes complex-valued signal to transform; even indexes are real and odd indexes are img
     * @return intensities of each frequency unit: mag[frequency_unit] = intensity
     */
    public double[] getMagnitudes(double[] amplitudes) {
        double[] magnitudes = getMagnitudes(amplitudes, true);

        double[] halfOfMagnitudes = new double[magnitudes.length / 2];
        System.arraycopy(magnitudes, 0, halfOfMagnitudes, 0, halfOfMagnitudes.length);
        return halfOfMagnitudes;
    }

}
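
A sketch exercising the real-input path on a synthetic tone. With real input, each of the sampleSize/2 returned bins spans sampleRate / sampleSize Hz, so a 1 kHz tone sampled at 8 kHz over 1024 samples should peak near bin 128. Note that getMagnitudes transforms its argument in place:

    import org.datavec.audio.dsp.FastFourierTransform;

    public class FftSketch {
        public static void main(String[] args) {
            int n = 1024, sampleRate = 8000;
            double[] signal = new double[n];
            for (int i = 0; i < n; i++)
                signal[i] = Math.sin(2 * Math.PI * 1000 * i / (double) sampleRate); // 1 kHz tone
            double[] mag = new FastFourierTransform().getMagnitudes(signal, false); // 512 bins
            int peak = 0;
            for (int i = 1; i < mag.length; i++)
                if (mag[i] > mag[peak]) peak = i;
            System.out.println("peak at bin " + peak + " = " + peak * sampleRate / (double) n + " Hz");
        }
    }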

@@ -1,66 +0,0 @@

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.audio.dsp;

public class LinearInterpolation {

    public LinearInterpolation() {

    }

    /**
     * Do interpolation on the samples according to the original and destination sample rates
     *
     * @param oldSampleRate sample rate of the original samples
     * @param newSampleRate sample rate of the interpolated samples
     * @param samples       original samples
     * @return interpolated samples
     */
    public short[] interpolate(int oldSampleRate, int newSampleRate, short[] samples) {

        if (oldSampleRate == newSampleRate) {
            return samples;
        }

        int newLength = Math.round(((float) samples.length / oldSampleRate * newSampleRate));
        float lengthMultiplier = (float) newLength / samples.length;
        short[] interpolatedSamples = new short[newLength];

        // interpolate the value by the linear equation y=mx+c
        for (int i = 0; i < newLength; i++) {

            // get the nearest positions for the interpolated point
            float currentPosition = i / lengthMultiplier;
            int nearestLeftPosition = (int) currentPosition;
            int nearestRightPosition = nearestLeftPosition + 1;
            if (nearestRightPosition >= samples.length) {
                nearestRightPosition = samples.length - 1;
            }

            float slope = samples[nearestRightPosition] - samples[nearestLeftPosition]; // delta x is 1
            float positionFromLeft = currentPosition - nearestLeftPosition;

            interpolatedSamples[i] = (short) (slope * positionFromLeft + samples[nearestLeftPosition]); // y=mx+c
        }

        return interpolatedSamples;
    }
}
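
A toy sketch of the interpolation: doubling the sample rate doubles the length, and each new point lies on the straight line between its two neighbours:

    import org.datavec.audio.dsp.LinearInterpolation;

    public class InterpolationSketch {
        public static void main(String[] args) {
            short[] samples = {0, 100, 200, 300};
            short[] up = new LinearInterpolation().interpolate(8000, 16000, samples);
            System.out.println(up.length); // 8
            System.out.println(up[1]);     // 50, halfway between samples[0] and samples[1]
        }
    }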

@@ -1,84 +0,0 @@

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.audio.dsp;

public class Resampler {

    public Resampler() {}

    /**
     * Do resampling. Currently the amplitude is stored as a short, so the maximum bitsPerSample is 16 (bytePerSample is 2)
     *
     * @param sourceData    The source data in bytes
     * @param bitsPerSample How many bits represent one sample (currently supports max. bitsPerSample = 16)
     * @param sourceRate    Sample rate of the source data
     * @param targetRate    Sample rate of the target data
     * @return re-sampled data
     */
    public byte[] reSample(byte[] sourceData, int bitsPerSample, int sourceRate, int targetRate) {

        // convert the bytes to amplitudes first
        int bytePerSample = bitsPerSample / 8;
        int numSamples = sourceData.length / bytePerSample;
        short[] amplitudes = new short[numSamples]; // 16 bit, use a short to store

        int pointer = 0;
        for (int i = 0; i < numSamples; i++) {
            short amplitude = 0;
            for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) {
                // little endian
                amplitude |= (short) ((sourceData[pointer++] & 0xFF) << (byteNumber * 8));
            }
            amplitudes[i] = amplitude;
        }
        // end convert the bytes to amplitudes

        // do interpolation
        LinearInterpolation reSample = new LinearInterpolation();
        short[] targetSample = reSample.interpolate(sourceRate, targetRate, amplitudes);
        int targetLength = targetSample.length;
        // end do interpolation

        // TODO: Remove the high frequency signals with a digital filter, leaving a signal containing only
        // half-sample-rated frequency information, but still sampled at the target sample rate. Usually FIR is used.

        // convert the amplitudes to bytes
        byte[] bytes;
        if (bytePerSample == 1) {
            bytes = new byte[targetLength];
            for (int i = 0; i < targetLength; i++) {
                bytes[i] = (byte) targetSample[i];
            }
        } else {
            // suppose bytePerSample == 2
            bytes = new byte[targetLength * 2];
            for (int i = 0; i < targetSample.length; i++) {
                // little endian
                bytes[i * 2] = (byte) (targetSample[i] & 0xff);
                bytes[i * 2 + 1] = (byte) ((targetSample[i] >> 8) & 0xff);
            }
        }
        // end convert the amplitudes to bytes

        return bytes;
    }
}
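
A sketch of byte-level resampling, in the same way FingerprintManager below uses it; the input file is a hypothetical placeholder, and (per the TODO above) no low-pass filter is applied, so downsampling can alias:

    import org.datavec.audio.Wave;
    import org.datavec.audio.dsp.Resampler;

    public class ResampleSketch {
        public static void main(String[] args) {
            Wave wave = new Wave("in.wav"); // hypothetical input
            byte[] at8k = new Resampler().reSample(wave.getBytes(),
                            wave.getWaveHeader().getBitsPerSample(),
                            wave.getWaveHeader().getSampleRate(), 8000);
            System.out.println(at8k.length + " bytes at 8 kHz");
        }
    }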

@@ -1,95 +0,0 @@

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.audio.dsp;

public class WindowFunction {

    public static final int RECTANGULAR = 0;
    public static final int BARTLETT = 1;
    public static final int HANNING = 2;
    public static final int HAMMING = 3;
    public static final int BLACKMAN = 4;

    int windowType = 0; // defaults to rectangular window

    public WindowFunction() {}

    public void setWindowType(int wt) {
        windowType = wt;
    }

    public void setWindowType(String w) {
        if (w.toUpperCase().equals("RECTANGULAR"))
            windowType = RECTANGULAR;
        if (w.toUpperCase().equals("BARTLETT"))
            windowType = BARTLETT;
        if (w.toUpperCase().equals("HANNING"))
            windowType = HANNING;
        if (w.toUpperCase().equals("HAMMING"))
            windowType = HAMMING;
        if (w.toUpperCase().equals("BLACKMAN"))
            windowType = BLACKMAN;
    }

    public int getWindowType() {
        return windowType;
    }

    /**
     * Generate a window
     *
     * @param nSamples size of the window
     * @return window as an array
     */
    public double[] generate(int nSamples) {
        // generate nSamples window function values
        // for index values 0 .. nSamples - 1
        int m = nSamples / 2;
        double r;
        double pi = Math.PI;
        double[] w = new double[nSamples];
        switch (windowType) {
            case BARTLETT: // Bartlett (triangular) window
                for (int n = 0; n < nSamples; n++)
                    w[n] = 1.0f - (double) Math.abs(n - m) / m; // cast avoids integer division
                break;
            case HANNING: // Hanning window
                r = pi / (m + 1);
                for (int n = -m; n < m; n++)
                    w[m + n] = 0.5f + 0.5f * Math.cos(n * r);
                break;
            case HAMMING: // Hamming window
                r = pi / m;
                for (int n = -m; n < m; n++)
                    w[m + n] = 0.54f + 0.46f * Math.cos(n * r);
                break;
            case BLACKMAN: // Blackman window
                r = pi / m;
                for (int n = -m; n < m; n++)
                    w[m + n] = 0.42f + 0.5f * Math.cos(n * r) + 0.08f * Math.cos(2 * n * r);
                break;
            default: // Rectangular window function
                for (int n = 0; n < nSamples; n++)
                    w[n] = 1.0f;
        }
        return w;
    }
}
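
A sketch generating the Hamming window that Spectrogram applies to each frame; for nSamples = 8 (m = 4) the values taper from 0.08 at the left edge to 1.0 at the centre:

    import org.datavec.audio.dsp.WindowFunction;

    public class WindowSketch {
        public static void main(String[] args) {
            WindowFunction window = new WindowFunction();
            window.setWindowType("Hamming"); // name matching is case-insensitive
            for (double v : window.generate(8))
                System.out.printf("%.3f ", v); // 0.080 0.215 0.540 0.865 1.000 0.865 0.540 0.215
        }
    }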

@@ -1,21 +0,0 @@

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.audio.dsp;

@@ -1,67 +0,0 @@

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.audio.extension;

import org.datavec.audio.Wave;

public class NormalizedSampleAmplitudes {

    private Wave wave;
    private double[] normalizedAmplitudes; // normalizedAmplitudes[sampleNumber] = normalizedAmplitudeInTheFrame

    public NormalizedSampleAmplitudes(Wave wave) {
        this.wave = wave;
    }

    /**
     * Get the normalized amplitude of each frame
     *
     * @return array of normalized amplitudes (signed 16-bit): normalizedAmplitudes[frame] = amplitude
     */
    public double[] getNormalizedAmplitudes() {

        if (normalizedAmplitudes == null) {

            boolean signed = true;

            // usually 8-bit is unsigned
            if (wave.getWaveHeader().getBitsPerSample() == 8) {
                signed = false;
            }

            short[] amplitudes = wave.getSampleAmplitudes();
            int numSamples = amplitudes.length;
            int maxAmplitude = 1 << (wave.getWaveHeader().getBitsPerSample() - 1);

            if (!signed) { // one more bit for unsigned values
                maxAmplitude <<= 1;
            }

            normalizedAmplitudes = new double[numSamples];
            for (int i = 0; i < numSamples; i++) {
                normalizedAmplitudes[i] = (double) amplitudes[i] / maxAmplitude;
            }
        }
        return normalizedAmplitudes;
    }
}
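
In practice this class is reached through Wave rather than used directly; a one-line sketch (hypothetical file path):

    import org.datavec.audio.Wave;

    public class NormalizedSketch {
        public static void main(String[] args) {
            double[] normalized = new Wave("in.wav").getNormalizedAmplitudes(); // values roughly in [-1, 1]
            System.out.println("first sample: " + normalized[0]);
        }
    }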

@@ -1,214 +0,0 @@

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.audio.extension;

import org.datavec.audio.Wave;
import org.datavec.audio.dsp.FastFourierTransform;
import org.datavec.audio.dsp.WindowFunction;

public class Spectrogram {

    public static final int SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE = 1024;
    public static final int SPECTROGRAM_DEFAULT_OVERLAP_FACTOR = 0; // 0 for no overlapping

    private Wave wave;
    private double[][] spectrogram; // relative spectrogram
    private double[][] absoluteSpectrogram; // absolute spectrogram
    private int fftSampleSize; // number of samples in each FFT; must be a power of 2
    private int overlapFactor; // 1/overlapFactor overlapping, e.g. 1/4 = 25% overlapping
    private int numFrames; // number of frames of the spectrogram
    private int framesPerSecond; // frames per second of the spectrogram
    private int numFrequencyUnit; // number of y-axis units
    private double unitFrequency; // frequency per y-axis unit

    /**
     * Constructor
     *
     * @param wave
     */
    public Spectrogram(Wave wave) {
        this.wave = wave;
        // default
        this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE;
        this.overlapFactor = SPECTROGRAM_DEFAULT_OVERLAP_FACTOR;
        buildSpectrogram();
    }

    /**
     * Constructor
     *
     * @param wave
     * @param fftSampleSize number of samples in each FFT; must be a power of 2
     * @param overlapFactor 1/overlapFactor overlapping, e.g. 1/4 = 25% overlapping, 0 for no overlapping
     */
    public Spectrogram(Wave wave, int fftSampleSize, int overlapFactor) {
        this.wave = wave;

        if (Integer.bitCount(fftSampleSize) == 1) {
            this.fftSampleSize = fftSampleSize;
        } else {
            System.err.print("The input number must be a power of 2");
            this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE;
        }

        this.overlapFactor = overlapFactor;

        buildSpectrogram();
    }

    /**
     * Build the spectrogram
     */
    private void buildSpectrogram() {

        short[] amplitudes = wave.getSampleAmplitudes();
        int numSamples = amplitudes.length;

        int pointer = 0;
        // overlapping
        if (overlapFactor > 1) {
            int numOverlappedSamples = numSamples * overlapFactor;
            int backSamples = fftSampleSize * (overlapFactor - 1) / overlapFactor;
            short[] overlapAmp = new short[numOverlappedSamples];
            pointer = 0;
            for (int i = 0; i < amplitudes.length; i++) {
                overlapAmp[pointer++] = amplitudes[i];
                if (pointer % fftSampleSize == 0) {
                    // overlap
                    i -= backSamples;
                }
            }
            numSamples = numOverlappedSamples;
            amplitudes = overlapAmp;
        }
        // end overlapping

        numFrames = numSamples / fftSampleSize;
        framesPerSecond = (int) (numFrames / wave.length());

        // set signals for fft
        WindowFunction window = new WindowFunction();
        window.setWindowType("Hamming");
        double[] win = window.generate(fftSampleSize);

        double[][] signals = new double[numFrames][];
        for (int f = 0; f < numFrames; f++) {
            signals[f] = new double[fftSampleSize];
            int startSample = f * fftSampleSize;
            for (int n = 0; n < fftSampleSize; n++) {
                signals[f][n] = amplitudes[startSample + n] * win[n];
            }
        }
        // end set signals for fft

        absoluteSpectrogram = new double[numFrames][];
        // for each frame in signals, do fft on it
        FastFourierTransform fft = new FastFourierTransform();
        for (int i = 0; i < numFrames; i++) {
            absoluteSpectrogram[i] = fft.getMagnitudes(signals[i], false);
        }

        if (absoluteSpectrogram.length > 0) {

            numFrequencyUnit = absoluteSpectrogram[0].length;
            unitFrequency = (double) wave.getWaveHeader().getSampleRate() / 2 / numFrequencyUnit; // frequencies can only be captured up to half the sample rate, per the Nyquist theorem

            // normalization of absoluteSpectrogram
            spectrogram = new double[numFrames][numFrequencyUnit];

            // set max and min amplitudes
            double maxAmp = Double.MIN_VALUE;
            double minAmp = Double.MAX_VALUE;
            for (int i = 0; i < numFrames; i++) {
                for (int j = 0; j < numFrequencyUnit; j++) {
                    if (absoluteSpectrogram[i][j] > maxAmp) {
                        maxAmp = absoluteSpectrogram[i][j];
                    } else if (absoluteSpectrogram[i][j] < minAmp) {
                        minAmp = absoluteSpectrogram[i][j];
                    }
                }
            }
            // end set max and min amplitudes

            // normalization
            // avoiding division by zero
            double minValidAmp = 0.00000000001F;
            if (minAmp == 0) {
                minAmp = minValidAmp;
            }

            double diff = Math.log10(maxAmp / minAmp); // perceptual difference
            for (int i = 0; i < numFrames; i++) {
                for (int j = 0; j < numFrequencyUnit; j++) {
                    if (absoluteSpectrogram[i][j] < minValidAmp) {
                        spectrogram[i][j] = 0;
                    } else {
                        spectrogram[i][j] = (Math.log10(absoluteSpectrogram[i][j] / minAmp)) / diff;
                    }
                }
            }
            // end normalization
        }
    }

    /**
     * Get the spectrogram: spectrogram[time][frequency] = intensity
     *
     * @return logarithmically normalized spectrogram
     */
    public double[][] getNormalizedSpectrogramData() {
        return spectrogram;
    }

    /**
     * Get the spectrogram: spectrogram[time][frequency] = intensity
     *
     * @return absolute spectrogram
     */
    public double[][] getAbsoluteSpectrogramData() {
        return absoluteSpectrogram;
    }

    public int getNumFrames() {
        return numFrames;
    }

    public int getFramesPerSecond() {
        return framesPerSecond;
    }

    public int getNumFrequencyUnit() {
        return numFrequencyUnit;
    }

    public double getUnitFrequency() {
        return unitFrequency;
    }

    public int getFftSampleSize() {
        return fftSampleSize;
    }

    public int getOverlapFactor() {
        return overlapFactor;
    }
}
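
A sketch building a spectrogram with explicit parameters; fftSampleSize must be a power of 2, and overlapFactor = 4 corresponds to the 1/4 overlap described in the javadoc. The file path is a hypothetical placeholder:

    import org.datavec.audio.Wave;
    import org.datavec.audio.extension.Spectrogram;

    public class SpectrogramSketch {
        public static void main(String[] args) {
            Spectrogram spec = new Wave("in.wav").getSpectrogram(1024, 4); // hypothetical file
            double[][] data = spec.getNormalizedSpectrogramData();         // [frame][bin], values in [0, 1]
            System.out.println(data.length + " frames x " + spec.getNumFrequencyUnit()
                            + " bins, " + spec.getUnitFrequency() + " Hz per bin");
        }
    }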

@@ -1,272 +0,0 @@

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.audio.fingerprint;

import lombok.extern.slf4j.Slf4j;
import org.datavec.audio.Wave;
import org.datavec.audio.WaveHeader;
import org.datavec.audio.dsp.Resampler;
import org.datavec.audio.extension.Spectrogram;
import org.datavec.audio.processor.TopManyPointsProcessorChain;
import org.datavec.audio.properties.FingerprintProperties;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

@Slf4j
public class FingerprintManager {

    private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
    private int sampleSizePerFrame = fingerprintProperties.getSampleSizePerFrame();
    private int overlapFactor = fingerprintProperties.getOverlapFactor();
    private int numRobustPointsPerFrame = fingerprintProperties.getNumRobustPointsPerFrame();
    private int numFilterBanks = fingerprintProperties.getNumFilterBanks();

    /**
     * Constructor
     */
    public FingerprintManager() {

    }

    /**
     * Extract a fingerprint from a Wave object
     *
     * @param wave Wave object from which the fingerprint is extracted
     * @return fingerprint in bytes
     */
    public byte[] extractFingerprint(Wave wave) {

        int[][] coordinates; // coordinates[x][0..3]=y0..y3
        byte[] fingerprint = new byte[0];

        // resample to the target rate
        Resampler resampler = new Resampler();
        int sourceRate = wave.getWaveHeader().getSampleRate();
        int targetRate = fingerprintProperties.getSampleRate();

        byte[] resampledWaveData = resampler.reSample(wave.getBytes(), wave.getWaveHeader().getBitsPerSample(),
                        sourceRate, targetRate);

        // update the wave header
        WaveHeader resampledWaveHeader = wave.getWaveHeader();
        resampledWaveHeader.setSampleRate(targetRate);

        // make the resampled wave
        Wave resampledWave = new Wave(resampledWaveHeader, resampledWaveData);
        // end resample to the target rate

        // get the spectrogram's data
        Spectrogram spectrogram = resampledWave.getSpectrogram(sampleSizePerFrame, overlapFactor);
        double[][] spectrogramData = spectrogram.getNormalizedSpectrogramData();

        List<Integer>[] pointsLists = getRobustPointList(spectrogramData);
        int numFrames = pointsLists.length;

        // prepare fingerprint bytes
        coordinates = new int[numFrames][numRobustPointsPerFrame];

        for (int x = 0; x < numFrames; x++) {
            if (pointsLists[x].size() == numRobustPointsPerFrame) {
                Iterator<Integer> pointsListsIterator = pointsLists[x].iterator();
                for (int y = 0; y < numRobustPointsPerFrame; y++) {
                    coordinates[x][y] = pointsListsIterator.next();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// use -1 to fill the empty byte
|
|
||||||
for (int y = 0; y < numRobustPointsPerFrame; y++) {
|
|
||||||
coordinates[x][y] = -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// end make fingerprint
|
|
||||||
|
|
||||||
// for each valid coordinate, append with its intensity
|
|
||||||
List<Byte> byteList = new LinkedList<Byte>();
|
|
||||||
for (int i = 0; i < numFrames; i++) {
|
|
||||||
for (int j = 0; j < numRobustPointsPerFrame; j++) {
|
|
||||||
if (coordinates[i][j] != -1) {
|
|
||||||
// first 2 bytes is x
|
|
||||||
byteList.add((byte) (i >> 8));
|
|
||||||
byteList.add((byte) i);
|
|
||||||
|
|
||||||
// next 2 bytes is y
|
|
||||||
int y = coordinates[i][j];
|
|
||||||
byteList.add((byte) (y >> 8));
|
|
||||||
byteList.add((byte) y);
|
|
||||||
|
|
||||||
// next 4 bytes is intensity
|
|
||||||
int intensity = (int) (spectorgramData[i][y] * Integer.MAX_VALUE); // spectorgramData is ranged from 0~1
|
|
||||||
byteList.add((byte) (intensity >> 24));
|
|
||||||
byteList.add((byte) (intensity >> 16));
|
|
||||||
byteList.add((byte) (intensity >> 8));
|
|
||||||
byteList.add((byte) intensity);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// end for each valid coordinate, append with its intensity
|
|
||||||
|
|
||||||
fingerprint = new byte[byteList.size()];
|
|
||||||
Iterator<Byte> byteListIterator = byteList.iterator();
|
|
||||||
int pointer = 0;
|
|
||||||
while (byteListIterator.hasNext()) {
|
|
||||||
fingerprint[pointer++] = byteListIterator.next();
|
|
||||||
}
|
|
||||||
|
|
||||||
return fingerprint;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get bytes from fingerprint file
|
|
||||||
*
|
|
||||||
* @param fingerprintFile fingerprint filename
|
|
||||||
* @return fingerprint in bytes
|
|
||||||
*/
|
|
||||||
public byte[] getFingerprintFromFile(String fingerprintFile) {
|
|
||||||
byte[] fingerprint = null;
|
|
||||||
try {
|
|
||||||
InputStream fis = new FileInputStream(fingerprintFile);
|
|
||||||
fingerprint = getFingerprintFromInputStream(fis);
|
|
||||||
fis.close();
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.error("",e);
|
|
||||||
}
|
|
||||||
return fingerprint;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get bytes from fingerprint inputstream
|
|
||||||
*
|
|
||||||
* @param inputStream fingerprint inputstream
|
|
||||||
* @return fingerprint in bytes
|
|
||||||
*/
|
|
||||||
public byte[] getFingerprintFromInputStream(InputStream inputStream) {
|
|
||||||
byte[] fingerprint = null;
|
|
||||||
try {
|
|
||||||
fingerprint = new byte[inputStream.available()];
|
|
||||||
inputStream.read(fingerprint);
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.error("",e);
|
|
||||||
}
|
|
||||||
return fingerprint;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Save fingerprint to a file
|
|
||||||
*
|
|
||||||
* @param fingerprint fingerprint bytes
|
|
||||||
* @param filename fingerprint filename
|
|
||||||
* @see FingerprintManager file saved
|
|
||||||
*/
|
|
||||||
public void saveFingerprintAsFile(byte[] fingerprint, String filename) {
|
|
||||||
|
|
||||||
FileOutputStream fileOutputStream;
|
|
||||||
try {
|
|
||||||
fileOutputStream = new FileOutputStream(filename);
|
|
||||||
fileOutputStream.write(fingerprint);
|
|
||||||
fileOutputStream.close();
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.error("",e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// robustLists[x]=y1,y2,y3,...
|
|
||||||
private List<Integer>[] getRobustPointList(double[][] spectrogramData) {
|
|
||||||
|
|
||||||
int numX = spectrogramData.length;
|
|
||||||
int numY = spectrogramData[0].length;
|
|
||||||
|
|
||||||
double[][] allBanksIntensities = new double[numX][numY];
|
|
||||||
int bandwidthPerBank = numY / numFilterBanks;
|
|
||||||
|
|
||||||
for (int b = 0; b < numFilterBanks; b++) {
|
|
||||||
|
|
||||||
double[][] bankIntensities = new double[numX][bandwidthPerBank];
|
|
||||||
|
|
||||||
for (int i = 0; i < numX; i++) {
|
|
||||||
System.arraycopy(spectrogramData[i], b * bandwidthPerBank, bankIntensities[i], 0, bandwidthPerBank);
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the most robust point in each filter bank
|
|
||||||
TopManyPointsProcessorChain processorChain = new TopManyPointsProcessorChain(bankIntensities, 1);
|
|
||||||
double[][] processedIntensities = processorChain.getIntensities();
|
|
||||||
|
|
||||||
for (int i = 0; i < numX; i++) {
|
|
||||||
System.arraycopy(processedIntensities[i], 0, allBanksIntensities[i], b * bandwidthPerBank,
|
|
||||||
bandwidthPerBank);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
List<int[]> robustPointList = new LinkedList<int[]>();
|
|
||||||
|
|
||||||
// find robust points
|
|
||||||
for (int i = 0; i < allBanksIntensities.length; i++) {
|
|
||||||
for (int j = 0; j < allBanksIntensities[i].length; j++) {
|
|
||||||
if (allBanksIntensities[i][j] > 0) {
|
|
||||||
|
|
||||||
int[] point = new int[] {i, j};
|
|
||||||
//System.out.println(i+","+frequency);
|
|
||||||
robustPointList.add(point);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// end find robust points
|
|
||||||
|
|
||||||
List<Integer>[] robustLists = new LinkedList[spectrogramData.length];
|
|
||||||
for (int i = 0; i < robustLists.length; i++) {
|
|
||||||
robustLists[i] = new LinkedList<>();
|
|
||||||
}
|
|
||||||
|
|
||||||
// robustLists[x]=y1,y2,y3,...
|
|
||||||
for (int[] coor : robustPointList) {
|
|
||||||
robustLists[coor[0]].add(coor[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// return the list per frame
|
|
||||||
return robustLists;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of frames in a fingerprint
|
|
||||||
* Each frame lengths 8 bytes
|
|
||||||
* Usually there is more than one point in each frame, so it cannot simply divide the bytes length by 8
|
|
||||||
* Last 8 byte of thisFingerprint is the last frame of this wave
|
|
||||||
* First 2 byte of the last 8 byte is the x position of this wave, i.e. (number_of_frames-1) of this wave
|
|
||||||
*
|
|
||||||
* @param fingerprint fingerprint bytes
|
|
||||||
* @return number of frames of the fingerprint
|
|
||||||
*/
|
|
||||||
public static int getNumFrames(byte[] fingerprint) {
|
|
||||||
|
|
||||||
if (fingerprint.length < 8) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the last x-coordinate (length-8&length-7)bytes from fingerprint
|
|
||||||
return ((fingerprint[fingerprint.length - 8] & 0xff) << 8 | (fingerprint[fingerprint.length - 7] & 0xff)) + 1;
|
|
||||||
}
|
|
||||||
}
|
|
|
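The fingerprint built above is a flat sequence of 8-byte point records: 2 bytes of x (frame), 2 bytes of y (frequency bin), and 4 bytes of intensity, all big-endian. A minimal sketch of decoding the i-th record, mirroring the unpacking done in PairManager.getSortedCoordinateList below:

    // Decode the i-th 8-byte point record of a fingerprint (sketch).
    int p = i * 8;
    int x = (fingerprint[p] & 0xff) << 8 | (fingerprint[p + 1] & 0xff);
    int y = (fingerprint[p + 2] & 0xff) << 8 | (fingerprint[p + 3] & 0xff);
    int intensity = (fingerprint[p + 4] & 0xff) << 24 | (fingerprint[p + 5] & 0xff) << 16
                    | (fingerprint[p + 6] & 0xff) << 8 | (fingerprint[p + 7] & 0xff);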
@ -1,107 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.fingerprint;

import org.datavec.audio.properties.FingerprintProperties;

public class FingerprintSimilarity {

    private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
    private int mostSimilarFramePosition;
    private float score;
    private float similarity;

    /**
     * Constructor
     */
    public FingerprintSimilarity() {
        mostSimilarFramePosition = Integer.MIN_VALUE;
        score = -1;
        similarity = -1;
    }

    /**
     * Get the most similar position in terms of frame number
     *
     * @return most similar frame position
     */
    public int getMostSimilarFramePosition() {
        return mostSimilarFramePosition;
    }

    /**
     * Set the most similar position in terms of frame number
     *
     * @param mostSimilarFramePosition most similar frame position
     */
    public void setMostSimilarFramePosition(int mostSimilarFramePosition) {
        this.mostSimilarFramePosition = mostSimilarFramePosition;
    }

    /**
     * Get the similarity of the fingerprints.
     * Similarity ranges from 0 to 1: 0 means no similar feature was found, and 1 means that,
     * on average, there is at least one match in every frame.
     *
     * @return fingerprints similarity
     */
    public float getSimilarity() {
        return similarity;
    }

    /**
     * Set the similarity of the fingerprints
     *
     * @param similarity similarity
     */
    public void setSimilarity(float similarity) {
        this.similarity = similarity;
    }

    /**
     * Get the similarity score of the fingerprints:
     * the number of matched features found in the fingerprints per frame
     *
     * @return fingerprints similarity score
     */
    public float getScore() {
        return score;
    }

    /**
     * Set the similarity score of the fingerprints
     *
     * @param score similarity score
     */
    public void setScore(float score) {
        this.score = score;
    }

    /**
     * Get the most similar position in terms of time, in seconds
     *
     * @return most similar starting time
     */
    public float getMostSimilarTimePosition() {
        return (float) mostSimilarFramePosition / fingerprintProperties.getNumRobustPointsPerFrame()
                        / fingerprintProperties.getFps();
    }
}
@ -1,134 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.fingerprint;

import java.util.HashMap;
import java.util.List;

public class FingerprintSimilarityComputer {

    private FingerprintSimilarity fingerprintSimilarity;
    byte[] fingerprint1, fingerprint2;

    /**
     * Constructor, ready to compute the similarity of two fingerprints
     *
     * @param fingerprint1 first fingerprint, in bytes
     * @param fingerprint2 second fingerprint, in bytes
     */
    public FingerprintSimilarityComputer(byte[] fingerprint1, byte[] fingerprint2) {

        this.fingerprint1 = fingerprint1;
        this.fingerprint2 = fingerprint2;

        fingerprintSimilarity = new FingerprintSimilarity();
    }

    /**
     * Get the fingerprint similarity of the input fingerprints
     *
     * @return fingerprint similarity object
     */
    public FingerprintSimilarity getFingerprintsSimilarity() {
        HashMap<Integer, Integer> offset_Score_Table = new HashMap<>(); // offset_Score_Table<offset,count>
        int numFrames;
        float score = 0;
        int mostSimilarFramePosition = Integer.MIN_VALUE;

        // one frame may contain several points; use the shorter fingerprint as the denominator
        if (fingerprint1.length > fingerprint2.length) {
            numFrames = FingerprintManager.getNumFrames(fingerprint2);
        } else {
            numFrames = FingerprintManager.getNumFrames(fingerprint1);
        }

        // get the pairs
        PairManager pairManager = new PairManager();
        HashMap<Integer, List<Integer>> this_Pair_PositionList_Table =
                        pairManager.getPair_PositionList_Table(fingerprint1);
        HashMap<Integer, List<Integer>> compareWave_Pair_PositionList_Table =
                        pairManager.getPair_PositionList_Table(fingerprint2);

        for (Integer compareWaveHashNumber : compareWave_Pair_PositionList_Table.keySet()) {
            // if the hash number does not appear in both tables, there is nothing to compare
            if (!this_Pair_PositionList_Table.containsKey(compareWaveHashNumber)) {
                continue;
            }

            // for each shared hash number, get the positions
            List<Integer> wavePositionList = this_Pair_PositionList_Table.get(compareWaveHashNumber);
            List<Integer> compareWavePositionList = compareWave_Pair_PositionList_Table.get(compareWaveHashNumber);

            for (Integer thisPosition : wavePositionList) {
                for (Integer compareWavePosition : compareWavePositionList) {
                    int offset = thisPosition - compareWavePosition;
                    if (offset_Score_Table.containsKey(offset)) {
                        offset_Score_Table.put(offset, offset_Score_Table.get(offset) + 1);
                    } else {
                        offset_Score_Table.put(offset, 1);
                    }
                }
            }
        }

        // map rank
        MapRank mapRank = new MapRankInteger(offset_Score_Table, false);

        // get the most similar positions and scores
        List<Integer> orderedKeyList = mapRank.getOrderedKeyList(100, true);
        if (orderedKeyList.size() > 0) {
            int key = orderedKeyList.get(0);
            // get the highest score position
            mostSimilarFramePosition = key;
            score = offset_Score_Table.get(key);

            // accumulate half of the scores from the neighbouring offsets
            if (offset_Score_Table.containsKey(key - 1)) {
                score += offset_Score_Table.get(key - 1) / 2f;
            }
            if (offset_Score_Table.containsKey(key + 1)) {
                score += offset_Score_Table.get(key + 1) / 2f;
            }
        }

        score /= numFrames;
        float similarity = score;
        // similarity > 1 means that on average there is at least one match in every frame
        if (similarity > 1) {
            similarity = 1;
        }

        fingerprintSimilarity.setMostSimilarFramePosition(mostSimilarFramePosition);
        fingerprintSimilarity.setScore(score);
        fingerprintSimilarity.setSimilarity(similarity);

        return fingerprintSimilarity;
    }
}
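A minimal usage sketch tying the classes above together; fp1 and fp2 are assumed to be fingerprints previously produced by FingerprintManager.extractFingerprint:

    FingerprintSimilarityComputer computer = new FingerprintSimilarityComputer(fp1, fp2);
    FingerprintSimilarity result = computer.getFingerprintsSimilarity();
    float similarity = result.getSimilarity(); // in [0, 1]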
@ -1,27 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.fingerprint;

import java.util.List;

public interface MapRank {
    public List getOrderedKeyList(int numKeys, boolean sharpLimit);
}
@ -1,179 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.fingerprint;

import java.util.*;
import java.util.Map.Entry;

public class MapRankDouble implements MapRank {

    private Map map;
    private boolean ascending = true;

    public MapRankDouble(Map<?, Double> map, boolean ascending) {
        this.map = map;
        this.ascending = ascending;
    }

    // if sharpLimit is true, return exactly numKeys keys; otherwise keep returning keys
    // while their values still equal the value of the numKeys-th key
    public List getOrderedKeyList(int numKeys, boolean sharpLimit) {

        Set mapEntrySet = map.entrySet();
        List keyList = new LinkedList();

        // if numKeys is larger than the map size, limit it
        if (numKeys > map.size()) {
            numKeys = map.size();
        }

        if (map.size() > 0) {
            double[] array = new double[map.size()];
            int count = 0;

            // collect the values
            Iterator<Entry> mapIterator = mapEntrySet.iterator();
            while (mapIterator.hasNext()) {
                Entry entry = mapIterator.next();
                array[count++] = (Double) entry.getValue();
            }
            // end collect the values

            int targetindex;
            if (ascending) {
                targetindex = numKeys;
            } else {
                targetindex = array.length - numKeys;
            }

            double passValue = getOrderedValue(array, targetindex); // the value of the numKeys-th element
            // get the passed keys and values
            Map passedMap = new HashMap();
            List<Double> valueList = new LinkedList<Double>();
            mapIterator = mapEntrySet.iterator();

            while (mapIterator.hasNext()) {
                Entry entry = mapIterator.next();
                double value = (Double) entry.getValue();
                if ((ascending && value <= passValue) || (!ascending && value >= passValue)) {
                    passedMap.put(entry.getKey(), value);
                    valueList.add(value);
                }
            }
            // end get the passed keys and values

            // sort the value list
            Double[] listArr = new Double[valueList.size()];
            valueList.toArray(listArr);
            Arrays.sort(listArr);
            // end sort the value list

            // get the list of keys
            int resultCount = 0;
            int index;
            if (ascending) {
                index = 0;
            } else {
                index = listArr.length - 1;
            }

            if (!sharpLimit) {
                numKeys = listArr.length;
            }

            while (true) {
                double targetValue = listArr[index];
                Iterator<Entry> passedMapIterator = passedMap.entrySet().iterator();
                while (passedMapIterator.hasNext()) {
                    Entry entry = passedMapIterator.next();
                    if ((Double) entry.getValue() == targetValue) {
                        keyList.add(entry.getKey());
                        passedMapIterator.remove();
                        resultCount++;
                        break;
                    }
                }

                if (ascending) {
                    index++;
                } else {
                    index--;
                }

                if (resultCount >= numKeys) {
                    break;
                }
            }
            // end get the list of keys
        }

        return keyList;
    }

    private double getOrderedValue(double[] array, int index) {
        locate(array, 0, array.length - 1, index);
        return array[index];
    }

    // quickselect: partition as in quicksort, recursing only into the side
    // that contains the target index
    private void locate(double[] array, int left, int right, int index) {

        int mid = (left + right) / 2;

        if (right == left) {
            return;
        }

        if (left < right) {
            double s = array[mid];
            int i = left - 1;
            int j = right + 1;

            while (true) {
                while (array[++i] < s);
                while (array[--j] > s);
                if (i >= j)
                    break;
                swap(array, i, j);
            }

            if (i > index) {
                // the target index is in the left partition
                locate(array, left, i - 1, index);
            } else {
                // the target index is in the right partition
                locate(array, j + 1, right, index);
            }
        }
    }

    private void swap(double[] array, int i, int j) {
        double t = array[i];
        array[i] = array[j];
        array[j] = t;
    }
}
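As a hedged illustration (the keys and values are hypothetical), ranking a small map ascending and keeping the two smallest entries:

    Map<String, Double> scores = new HashMap<>();
    scores.put("a", 0.3);
    scores.put("b", 0.1);
    scores.put("c", 0.7);
    List keys = new MapRankDouble(scores, true).getOrderedKeyList(2, true); // ["b", "a"]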
@ -1,179 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.fingerprint;

import java.util.*;
import java.util.Map.Entry;

public class MapRankInteger implements MapRank {

    private Map map;
    private boolean ascending = true;

    public MapRankInteger(Map<?, Integer> map, boolean ascending) {
        this.map = map;
        this.ascending = ascending;
    }

    // if sharpLimit is true, return exactly numKeys keys; otherwise keep returning keys
    // while their values still equal the value of the numKeys-th key
    public List getOrderedKeyList(int numKeys, boolean sharpLimit) {

        Set mapEntrySet = map.entrySet();
        List keyList = new LinkedList();

        // if numKeys is larger than the map size, limit it
        if (numKeys > map.size()) {
            numKeys = map.size();
        }

        if (map.size() > 0) {
            int[] array = new int[map.size()];
            int count = 0;

            // collect the values
            Iterator<Entry> mapIterator = mapEntrySet.iterator();
            while (mapIterator.hasNext()) {
                Entry entry = mapIterator.next();
                array[count++] = (Integer) entry.getValue();
            }
            // end collect the values

            int targetindex;
            if (ascending) {
                targetindex = numKeys;
            } else {
                targetindex = array.length - numKeys;
            }

            int passValue = getOrderedValue(array, targetindex); // the value of the numKeys-th element
            // get the passed keys and values
            Map passedMap = new HashMap();
            List<Integer> valueList = new LinkedList<Integer>();
            mapIterator = mapEntrySet.iterator();

            while (mapIterator.hasNext()) {
                Entry entry = mapIterator.next();
                int value = (Integer) entry.getValue();
                if ((ascending && value <= passValue) || (!ascending && value >= passValue)) {
                    passedMap.put(entry.getKey(), value);
                    valueList.add(value);
                }
            }
            // end get the passed keys and values

            // sort the value list
            Integer[] listArr = new Integer[valueList.size()];
            valueList.toArray(listArr);
            Arrays.sort(listArr);
            // end sort the value list

            // get the list of keys
            int resultCount = 0;
            int index;
            if (ascending) {
                index = 0;
            } else {
                index = listArr.length - 1;
            }

            if (!sharpLimit) {
                numKeys = listArr.length;
            }

            while (true) {
                int targetValue = listArr[index];
                Iterator<Entry> passedMapIterator = passedMap.entrySet().iterator();
                while (passedMapIterator.hasNext()) {
                    Entry entry = passedMapIterator.next();
                    if ((Integer) entry.getValue() == targetValue) {
                        keyList.add(entry.getKey());
                        passedMapIterator.remove();
                        resultCount++;
                        break;
                    }
                }

                if (ascending) {
                    index++;
                } else {
                    index--;
                }

                if (resultCount >= numKeys) {
                    break;
                }
            }
            // end get the list of keys
        }

        return keyList;
    }

    private int getOrderedValue(int[] array, int index) {
        locate(array, 0, array.length - 1, index);
        return array[index];
    }

    // quickselect: partition as in quicksort, recursing only into the side
    // that contains the target index
    private void locate(int[] array, int left, int right, int index) {

        int mid = (left + right) / 2;

        if (right == left) {
            return;
        }

        if (left < right) {
            int s = array[mid];
            int i = left - 1;
            int j = right + 1;

            while (true) {
                while (array[++i] < s);
                while (array[--j] > s);
                if (i >= j)
                    break;
                swap(array, i, j);
            }

            if (i > index) {
                // the target index is in the left partition
                locate(array, left, i - 1, index);
            } else {
                // the target index is in the right partition
                locate(array, j + 1, right, index);
            }
        }
    }

    private void swap(int[] array, int i, int j) {
        int t = array[i];
        array[i] = array[j];
        array[j] = t;
    }
}
@ -1,230 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.fingerprint;

import org.datavec.audio.properties.FingerprintProperties;

import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;

public class PairManager {

    FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
    private int numFilterBanks = fingerprintProperties.getNumFilterBanks();
    private int bandwidthPerBank = fingerprintProperties.getNumFrequencyUnits() / numFilterBanks;
    private int anchorPointsIntervalLength = fingerprintProperties.getAnchorPointsIntervalLength();
    private int numAnchorPointsPerInterval = fingerprintProperties.getNumAnchorPointsPerInterval();
    private int maxTargetZoneDistance = fingerprintProperties.getMaxTargetZoneDistance();
    private int numFrequencyUnits = fingerprintProperties.getNumFrequencyUnits();

    private int maxPairs;
    private boolean isReferencePairing;
    private HashMap<Integer, Boolean> stopPairTable = new HashMap<>();

    /**
     * Constructor
     */
    public PairManager() {
        maxPairs = fingerprintProperties.getRefMaxActivePairs();
        isReferencePairing = true;
    }

    /**
     * Constructor. The number of pairs of robust points depends on the parameter isReferencePairing:
     * the number of pairs for reference and sample can differ because of environmental influence on the source.
     *
     * @param isReferencePairing true when pairing a reference fingerprint, false for a sample
     */
    public PairManager(boolean isReferencePairing) {
        if (isReferencePairing) {
            maxPairs = fingerprintProperties.getRefMaxActivePairs();
        } else {
            maxPairs = fingerprintProperties.getSampleMaxActivePairs();
        }
        this.isReferencePairing = isReferencePairing;
    }

    /**
     * Get a pair-positionList table: a hash map whose key is the hashed pair and whose value
     * is the list of positions sharing that hashed pair.
     *
     * @param fingerprint fingerprint bytes
     * @return pair-positionList HashMap
     */
    public HashMap<Integer, List<Integer>> getPair_PositionList_Table(byte[] fingerprint) {

        List<int[]> pairPositionList = getPairPositionList(fingerprint);

        // table to store pair:pos,pos,pos,...;pair2:pos,pos,pos,....
        HashMap<Integer, List<Integer>> pair_positionList_table = new HashMap<>();

        // collect all pair_positions from the list, grouped by pair hashcode
        for (int[] pair_position : pairPositionList) {

            // group by pair-hashcode, i.e.: <pair,List<position>>
            if (pair_positionList_table.containsKey(pair_position[0])) {
                pair_positionList_table.get(pair_position[0]).add(pair_position[1]);
            } else {
                List<Integer> positionList = new LinkedList<>();
                positionList.add(pair_position[1]);
                pair_positionList_table.put(pair_position[0], positionList);
            }
            // end group by pair-hashcode, i.e.: <pair,List<position>>
        }
        // end collect all pair_positions from the list, grouped by pair hashcode

        return pair_positionList_table;
    }

    // the returned list contains: int[0]=pair_hashcode, int[1]=position
    private List<int[]> getPairPositionList(byte[] fingerprint) {

        int numFrames = FingerprintManager.getNumFrames(fingerprint);

        // table for paired frames
        byte[] pairedFrameTable = new byte[numFrames / anchorPointsIntervalLength + 1]; // each interval allows numAnchorPointsPerInterval anchor points only
        // end table for paired frames

        List<int[]> pairList = new LinkedList<>();
        List<int[]> sortedCoordinateList = getSortedCoordinateList(fingerprint);

        for (int[] anchorPoint : sortedCoordinateList) {
            int anchorX = anchorPoint[0];
            int anchorY = anchorPoint[1];
            int numPairs = 0;

            for (int[] aSortedCoordinateList : sortedCoordinateList) {

                if (numPairs >= maxPairs) {
                    break;
                }

                if (isReferencePairing && pairedFrameTable[anchorX
                                / anchorPointsIntervalLength] >= numAnchorPointsPerInterval) {
                    break;
                }

                int targetX = aSortedCoordinateList[0];
                int targetY = aSortedCoordinateList[1];

                if (anchorX == targetX && anchorY == targetY) {
                    continue;
                }

                // pair up the points
                int x1, y1, x2, y2; // x2 always >= x1
                if (targetX >= anchorX) {
                    x2 = targetX;
                    y2 = targetY;
                    x1 = anchorX;
                    y1 = anchorY;
                } else {
                    x2 = anchorX;
                    y2 = anchorY;
                    x1 = targetX;
                    y1 = targetY;
                }

                // check target zone
                if ((x2 - x1) > maxTargetZoneDistance) {
                    continue;
                }
                // end check target zone

                // check filter bank zone
                if (y1 / bandwidthPerBank != y2 / bandwidthPerBank) {
                    continue; // both points must fall within the same filter bank
                }
                // end check filter bank zone

                int pairHashcode = (x2 - x1) * numFrequencyUnits * numFrequencyUnits + y2 * numFrequencyUnits + y1;

                // the stop list applies to sample pairing only
                if (!isReferencePairing && stopPairTable.containsKey(pairHashcode)) {
                    numPairs++; // counted, but not reserved
                    continue; // skip this point only
                }
                // end the stop list applies to sample pairing only

                // passed all rules
                pairList.add(new int[] {pairHashcode, anchorX});
                pairedFrameTable[anchorX / anchorPointsIntervalLength]++;
                numPairs++;
                // end pair up the points
            }
        }

        return pairList;
    }

    private List<int[]> getSortedCoordinateList(byte[] fingerprint) {
        // each point record is 8 bytes:
        // first 2 bytes are x
        // next 2 bytes are y
        // next 4 bytes are the intensity

        // get all intensities
        int numCoordinates = fingerprint.length / 8;
        int[] intensities = new int[numCoordinates];
        for (int i = 0; i < numCoordinates; i++) {
            int pointer = i * 8 + 4;
            int intensity = (fingerprint[pointer] & 0xff) << 24 | (fingerprint[pointer + 1] & 0xff) << 16
                            | (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff);
            intensities[i] = intensity;
        }

        QuickSortIndexPreserved quicksort = new QuickSortIndexPreserved(intensities);
        int[] sortIndexes = quicksort.getSortIndexes();

        List<int[]> sortedCoordinateList = new LinkedList<>();
        for (int i = sortIndexes.length - 1; i >= 0; i--) {
            int pointer = sortIndexes[i] * 8;
            int x = (fingerprint[pointer] & 0xff) << 8 | (fingerprint[pointer + 1] & 0xff);
            int y = (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff);
            sortedCoordinateList.add(new int[] {x, y});
        }
        return sortedCoordinateList;
    }

    /**
     * Convert a hashed pair to bytes
     *
     * @param pairHashcode hashed pair
     * @return byte array
     */
    public static byte[] pairHashcodeToBytes(int pairHashcode) {
        return new byte[] {(byte) (pairHashcode >> 8), (byte) pairHashcode};
    }

    /**
     * Convert bytes to a hashed pair
     *
     * @param pairBytes two bytes encoding a hashed pair
     * @return hashed pair
     */
    public static int pairBytesToHashcode(byte[] pairBytes) {
        return (pairBytes[0] & 0xFF) << 8 | (pairBytes[1] & 0xFF);
    }
}
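To make the pair hash above concrete, a small worked example; numFrequencyUnits = 256 is assumed purely for illustration (the real value comes from FingerprintProperties):

    int numFrequencyUnits = 256;   // assumed for illustration
    int dx = 3, y1 = 10, y2 = 12;  // frame distance and the two frequency bins
    int pairHashcode = dx * numFrequencyUnits * numFrequencyUnits
                    + y2 * numFrequencyUnits + y1; // = 199690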
@ -1,25 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.fingerprint;

public abstract class QuickSort {
    public abstract int[] getSortIndexes();
}
@ -1,79 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.fingerprint;

public class QuickSortDouble extends QuickSort {

    private int[] indexes;
    private double[] array;

    public QuickSortDouble(double[] array) {
        this.array = array;
        indexes = new int[array.length];
        for (int i = 0; i < indexes.length; i++) {
            indexes[i] = i;
        }
    }

    public int[] getSortIndexes() {
        sort();
        return indexes;
    }

    private void sort() {
        quicksort(array, indexes, 0, indexes.length - 1);
    }

    // quicksort a[left] to a[right]
    private void quicksort(double[] a, int[] indexes, int left, int right) {
        if (right <= left)
            return;
        int i = partition(a, indexes, left, right);
        quicksort(a, indexes, left, i - 1);
        quicksort(a, indexes, i + 1, right);
    }

    // partition a[left] to a[right], assumes left < right
    private int partition(double[] a, int[] indexes, int left, int right) {
        int i = left - 1;
        int j = right;
        while (true) {
            while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
            while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
                if (j == left)
                    break; // don't go out-of-bounds
            }
            if (i >= j)
                break; // check if pointers cross
            swap(a, indexes, i, j); // swap two elements into place
        }
        swap(a, indexes, i, right); // swap with partition element
        return i;
    }

    // exchange a[i] and a[j]
    private void swap(double[] a, int[] indexes, int i, int j) {
        int swap = indexes[i];
        indexes[i] = indexes[j];
        indexes[j] = swap;
    }

}
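Note that only the index array is permuted; the value array itself is left untouched. A minimal usage sketch with illustrative values:

    double[] values = {0.5, 0.1, 0.9};
    int[] order = new QuickSortDouble(values).getSortIndexes(); // {1, 0, 2}: ascending order of values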
@ -1,43 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.fingerprint;

public class QuickSortIndexPreserved {

    private QuickSort quickSort;

    public QuickSortIndexPreserved(int[] array) {
        quickSort = new QuickSortInteger(array);
    }

    public QuickSortIndexPreserved(double[] array) {
        quickSort = new QuickSortDouble(array);
    }

    public QuickSortIndexPreserved(short[] array) {
        quickSort = new QuickSortShort(array);
    }

    public int[] getSortIndexes() {
        return quickSort.getSortIndexes();
    }

}
@ -1,79 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.fingerprint;

public class QuickSortInteger extends QuickSort {

    private int[] indexes;
    private int[] array;

    public QuickSortInteger(int[] array) {
        this.array = array;
        indexes = new int[array.length];
        for (int i = 0; i < indexes.length; i++) {
            indexes[i] = i;
        }
    }

    public int[] getSortIndexes() {
        sort();
        return indexes;
    }

    private void sort() {
        quicksort(array, indexes, 0, indexes.length - 1);
    }

    // quicksort a[left] to a[right]
    private void quicksort(int[] a, int[] indexes, int left, int right) {
        if (right <= left)
            return;
        int i = partition(a, indexes, left, right);
        quicksort(a, indexes, left, i - 1);
        quicksort(a, indexes, i + 1, right);
    }

    // partition a[left] to a[right], assumes left < right
    private int partition(int[] a, int[] indexes, int left, int right) {
        int i = left - 1;
        int j = right;
        while (true) {
            while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
            while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
                if (j == left)
                    break; // don't go out-of-bounds
            }
            if (i >= j)
                break; // check if pointers cross
            swap(a, indexes, i, j); // swap two elements into place
        }
        swap(a, indexes, i, right); // swap with partition element
        return i;
    }

    // exchange a[i] and a[j]
    private void swap(int[] a, int[] indexes, int i, int j) {
        int swap = indexes[i];
        indexes[i] = indexes[j];
        indexes[j] = swap;
    }

}
@ -1,79 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.fingerprint;

public class QuickSortShort extends QuickSort {

    private int[] indexes;
    private short[] array;

    public QuickSortShort(short[] array) {
        this.array = array;
        indexes = new int[array.length];
        for (int i = 0; i < indexes.length; i++) {
            indexes[i] = i;
        }
    }

    public int[] getSortIndexes() {
        sort();
        return indexes;
    }

    private void sort() {
        quicksort(array, indexes, 0, indexes.length - 1);
    }

    // quicksort a[left] to a[right]
    private void quicksort(short[] a, int[] indexes, int left, int right) {
        if (right <= left)
            return;
        int i = partition(a, indexes, left, right);
        quicksort(a, indexes, left, i - 1);
        quicksort(a, indexes, i + 1, right);
    }

    // partition a[left] to a[right], assumes left < right
    private int partition(short[] a, int[] indexes, int left, int right) {
        int i = left - 1;
        int j = right;
        while (true) {
            while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
            while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
                if (j == left)
                    break; // don't go out-of-bounds
            }
            if (i >= j)
                break; // check if pointers cross
            swap(a, indexes, i, j); // swap two elements into place
        }
        swap(a, indexes, i, right); // swap with partition element
        return i;
    }

    // exchange a[i] and a[j]
    private void swap(short[] a, int[] indexes, int i, int j) {
        int swap = indexes[i];
        indexes[i] = indexes[j];
        indexes[j] = swap;
    }

}
@ -1,51 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.formats.input;

import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.BaseInputFormat;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.audio.recordreader.WavFileRecordReader;

import java.io.IOException;

/**
 * Wave file input format
 *
 * @author Adam Gibson
 */
public class WavInputFormat extends BaseInputFormat {
    @Override
    public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
        return createReader(split);
    }

    @Override
    public RecordReader createReader(InputSplit split) throws IOException, InterruptedException {
        RecordReader waveRecordReader = new WavFileRecordReader();
        waveRecordReader.initialize(split);
        return waveRecordReader;
    }

}
@ -1,36 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.formats.output;

import org.datavec.api.conf.Configuration;
import org.datavec.api.exceptions.DataVecException;
import org.datavec.api.formats.output.OutputFormat;
import org.datavec.api.records.writer.RecordWriter;

/**
 * @author Adam Gibson
 */
public class WaveOutputFormat implements OutputFormat {
    @Override
    public RecordWriter createWriter(Configuration conf) throws DataVecException {
        return null;
    }
}
@ -1,139 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.processor;

public class ArrayRankDouble {

    /**
     * Get the index position of the maximum value in the given array
     * @param array an array
     * @return index of the max value in array
     */
    public int getMaxValueIndex(double[] array) {

        int index = 0;
        double max = Double.NEGATIVE_INFINITY; // sentinel below any finite value

        for (int i = 0; i < array.length; i++) {
            if (array[i] > max) {
                max = array[i];
                index = i;
            }
        }

        return index;
    }

    /**
     * Get the index position of the minimum value in the given array
     * @param array an array
     * @return index of the min value in array
     */
    public int getMinValueIndex(double[] array) {

        int index = 0;
        double min = Double.POSITIVE_INFINITY; // sentinel above any finite value

        for (int i = 0; i < array.length; i++) {
            if (array[i] < min) {
                min = array[i];
                index = i;
            }
        }

        return index;
    }

    /**
     * Get the n-th value in the array after sorting
     * @param array an array
     * @param n 1-based rank within the array
     * @param ascending whether the rank is taken in ascending order
     * @return value at the n-th position of the sorted array
     */
    public double getNthOrderedValue(double[] array, int n, boolean ascending) {

        if (n > array.length) {
            n = array.length;
        }

        int targetindex;
        if (ascending) {
            targetindex = n - 1; // convert 1-based rank to 0-based index
        } else {
            targetindex = array.length - n;
        }

        return getOrderedValue(array, targetindex);
    }

    private double getOrderedValue(double[] array, int index) {
        locate(array, 0, array.length - 1, index);
        return array[index];
    }

    // sort the partitions by quick sort, and locate the target index
    private void locate(double[] array, int left, int right, int index) {

        int mid = (left + right) / 2;

        if (right == left) {
            return;
        }

        if (left < right) {
            double s = array[mid];
            int i = left - 1;
            int j = right + 1;

            while (true) {
                while (array[++i] < s);
                while (array[--j] > s);
                if (i >= j)
                    break;
                swap(array, i, j);
            }

            if (i > index) {
                // the target index is in the left partition
                locate(array, left, i - 1, index);
            } else {
                // the target index is in the right partition
                locate(array, j + 1, right, index);
            }
        }
    }

    private void swap(double[] array, int i, int j) {
        double t = array[i];
        array[i] = array[j];
        array[j] = t;
    }
}
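ArrayRankDouble does rank selection by recursive partitioning rather than a full sort, so it is linear on average, and it partially reorders its input in place. A small sketch with made-up values, using the 1-based reading of n above (cloning to keep the original untouched):

    double[] vals = {0.3, 0.9, 0.1, 0.7};
    ArrayRankDouble rank = new ArrayRankDouble();
    double secondLargest = rank.getNthOrderedValue(vals.clone(), 2, false); // 0.7
    double smallest = rank.getNthOrderedValue(vals.clone(), 1, true);       // 0.1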
@ -1,28 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.processor;

public interface IntensityProcessor {

    void execute();

    double[][] getIntensities();
}
@ -1,48 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.processor;

import java.util.LinkedList;
import java.util.List;

public class ProcessorChain {

    private double[][] intensities;
    List<IntensityProcessor> processorList = new LinkedList<>();

    public ProcessorChain(double[][] intensities) {
        this.intensities = intensities;
        RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, 1);
        processorList.add(robustProcessor);
        process();
    }

    private void process() {
        for (IntensityProcessor processor : processorList) {
            processor.execute();
            intensities = processor.getIntensities();
        }
    }

    public double[][] getIntensities() {
        return intensities;
    }
}
@ -1,61 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.processor;

public class RobustIntensityProcessor implements IntensityProcessor {

    private double[][] intensities;
    private int numPointsPerFrame;

    public RobustIntensityProcessor(double[][] intensities, int numPointsPerFrame) {
        this.intensities = intensities;
        this.numPointsPerFrame = numPointsPerFrame;
    }

    @Override
    public void execute() {

        int numX = intensities.length;
        int numY = intensities[0].length;
        double[][] processedIntensities = new double[numX][numY];

        for (int i = 0; i < numX; i++) {
            double[] tmpArray = new double[numY];
            System.arraycopy(intensities[i], 0, tmpArray, 0, numY);

            // the pass value is the numPointsPerFrame-th largest element of the frame
            ArrayRankDouble arrayRankDouble = new ArrayRankDouble();
            double passValue = arrayRankDouble.getNthOrderedValue(tmpArray, numPointsPerFrame, false);

            // only elements that reach the pass value keep their intensity
            for (int j = 0; j < numY; j++) {
                if (intensities[i][j] >= passValue) {
                    processedIntensities[i][j] = intensities[i][j];
                }
            }
        }
        intensities = processedIntensities;
    }

    @Override
    public double[][] getIntensities() {
        return intensities;
    }
}
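Per frame, the effect is: keep the numPointsPerFrame strongest intensities and zero everything else. A sketch with a hypothetical two-frame, three-bin spectrogram:

    double[][] spec = {{0.2, 0.9, 0.5},
                       {0.8, 0.1, 0.4}};
    IntensityProcessor p = new RobustIntensityProcessor(spec, 1);
    p.execute();
    double[][] kept = p.getIntensities(); // {{0, 0.9, 0}, {0.8, 0, 0}}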
@ -1,49 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.processor;

import java.util.LinkedList;
import java.util.List;

public class TopManyPointsProcessorChain {

    private double[][] intensities;
    List<IntensityProcessor> processorList = new LinkedList<>();

    public TopManyPointsProcessorChain(double[][] intensities, int numPoints) {
        this.intensities = intensities;
        RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, numPoints);
        processorList.add(robustProcessor);
        process();
    }

    private void process() {
        for (IntensityProcessor processor : processorList) {
            processor.execute();
            intensities = processor.getIntensities();
        }
    }

    public double[][] getIntensities() {
        return intensities;
    }
}
@ -1,121 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.properties;

public class FingerprintProperties {

    protected static volatile FingerprintProperties instance = null; // volatile for safe double-checked locking

    private int numRobustPointsPerFrame = 4; // number of points kept per frame, i.e. the top 4 intensities in the fingerprint
    private int sampleSizePerFrame = 2048; // number of audio samples in a frame; suggested to match the FFT size
    private int overlapFactor = 4; // 8 means each frame advances 1/8 of the sample length; 1 means no overlap; powers of two (1, 2, 4, 8, ... 32) work best
    private int numFilterBanks = 4;

    private int upperBoundedFrequency = 1500; // low pass
    private int lowerBoundedFrequency = 400; // high pass
    private int fps = 5; // to get 5 fps with a 2048-sample frame, the wave's sample rate needs to be 10240 (sampleSizePerFrame * fps)
    private int sampleRate = sampleSizePerFrame * fps; // audio must be resampled to this rate to fit the sampleSizePerFrame and fps
    private int numFramesInOneSecond = overlapFactor * fps; // the overlap factor changes the effective frame rate, so this gives the actual number of frames per second

    private int refMaxActivePairs = 1; // max. active pairs per anchor point for reference songs
    private int sampleMaxActivePairs = 10; // max. active pairs per anchor point for the sample clip
    private int numAnchorPointsPerInterval = 10;
    private int anchorPointsIntervalLength = 4; // in frames (at 5 fps with overlap factor 4)
    private int maxTargetZoneDistance = 4; // in frames (at 5 fps with overlap factor 4)

    private int numFrequencyUnits = (upperBoundedFrequency - lowerBoundedFrequency + 1) / fps + 1; // number of frequency units

    public static FingerprintProperties getInstance() {
        if (instance == null) {
            synchronized (FingerprintProperties.class) {
                if (instance == null) {
                    instance = new FingerprintProperties();
                }
            }
        }
        return instance;
    }

    public int getNumRobustPointsPerFrame() {
        return numRobustPointsPerFrame;
    }

    public int getSampleSizePerFrame() {
        return sampleSizePerFrame;
    }

    public int getOverlapFactor() {
        return overlapFactor;
    }

    public int getNumFilterBanks() {
        return numFilterBanks;
    }

    public int getUpperBoundedFrequency() {
        return upperBoundedFrequency;
    }

    public int getLowerBoundedFrequency() {
        return lowerBoundedFrequency;
    }

    public int getFps() {
        return fps;
    }

    public int getRefMaxActivePairs() {
        return refMaxActivePairs;
    }

    public int getSampleMaxActivePairs() {
        return sampleMaxActivePairs;
    }

    public int getNumAnchorPointsPerInterval() {
        return numAnchorPointsPerInterval;
    }

    public int getAnchorPointsIntervalLength() {
        return anchorPointsIntervalLength;
    }

    public int getMaxTargetZoneDistance() {
        return maxTargetZoneDistance;
    }

    public int getNumFrequencyUnits() {
        return numFrequencyUnits;
    }

    public int getMaxPossiblePairHashcode() {
        return maxTargetZoneDistance * numFrequencyUnits * numFrequencyUnits + numFrequencyUnits * numFrequencyUnits
                        + numFrequencyUnits;
    }

    public int getSampleRate() {
        return sampleRate;
    }

    public int getNumFramesInOneSecond() {
        return numFramesInOneSecond;
    }
}
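The derived fields follow directly from the frame size and frame rate; a quick worked check of the defaults:

    FingerprintProperties props = FingerprintProperties.getInstance();
    // sampleRate           = sampleSizePerFrame * fps = 2048 * 5 = 10240 Hz
    // numFramesInOneSecond = overlapFactor * fps      = 4 * 5    = 20
    // numFrequencyUnits    = (1500 - 400 + 1) / 5 + 1 = 221 (integer division)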
@ -1,225 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.recordreader;

import org.apache.commons.io.FileUtils;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.BaseRecordReader;
import org.datavec.api.split.BaseInputSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;

import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/**
 * Base audio file loader
 * @author Adam Gibson
 */
public abstract class BaseAudioRecordReader extends BaseRecordReader {
    private Iterator<File> iter;
    private List<Writable> record;
    private boolean hitRecord = false; // true once the single stream-backed record has been returned
    private boolean appendLabel = false;
    private List<String> labels = new ArrayList<>();
    private Configuration conf;
    protected InputSplit inputSplit;

    public BaseAudioRecordReader() {}

    public BaseAudioRecordReader(boolean appendLabel, List<String> labels) {
        this.appendLabel = appendLabel;
        this.labels = labels;
    }

    public BaseAudioRecordReader(List<String> labels) {
        this.labels = labels;
    }

    public BaseAudioRecordReader(boolean appendLabel) {
        this.appendLabel = appendLabel;
    }

    protected abstract List<Writable> loadData(File file, InputStream inputStream) throws IOException;

    @Override
    public void initialize(InputSplit split) throws IOException, InterruptedException {
        inputSplit = split;
        if (split instanceof BaseInputSplit) {
            URI[] locations = split.locations();
            if (locations != null && locations.length >= 1) {
                if (locations.length > 1) {
                    List<File> allFiles = new ArrayList<>();
                    for (URI location : locations) {
                        File file = new File(location);
                        if (file.isDirectory()) {
                            Iterator<File> allFiles2 = FileUtils.iterateFiles(file, null, true);
                            while (allFiles2.hasNext())
                                allFiles.add(allFiles2.next());
                        } else
                            allFiles.add(file);
                    }

                    iter = allFiles.iterator();
                } else {
                    File curr = new File(locations[0]);
                    if (curr.isDirectory())
                        iter = FileUtils.iterateFiles(curr, null, true);
                    else
                        iter = Collections.singletonList(curr).iterator();
                }
            }
        } else if (split instanceof InputStreamInputSplit) {
            record = new ArrayList<>();
            InputStreamInputSplit split2 = (InputStreamInputSplit) split;
            InputStream is = split2.getIs();
            URI[] locations = split2.locations();
            if (appendLabel) {
                Path path = Paths.get(locations[0]);
                String parent = path.getParent().toString();
                record.add(new DoubleWritable(labels.indexOf(parent)));
            }

            is.close();
        }

    }

    @Override
    public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
        this.conf = conf;
        this.appendLabel = conf.getBoolean(APPEND_LABEL, false);
        this.labels = new ArrayList<>(conf.getStringCollection(LABELS));
        initialize(split);
    }

    @Override
    public List<Writable> next() {
        if (iter != null) {
            File next = iter.next();
            invokeListeners(next);
            try {
                return loadData(next, null);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        } else if (record != null) {
            hitRecord = true;
            return record;
        }

        throw new IllegalStateException("Indeterminate state: record must not be null, or a file iterator must exist");
    }

    @Override
    public boolean hasNext() {
        if (iter != null) {
            return iter.hasNext();
        } else if (record != null) {
            return !hitRecord;
        }
        throw new IllegalStateException("Indeterminate state: record must not be null, or a file iterator must exist");
    }

    @Override
    public void close() throws IOException {

    }

    @Override
    public void setConf(Configuration conf) {
        this.conf = conf;
    }

    @Override
    public Configuration getConf() {
        return conf;
    }

    @Override
    public List<String> getLabels() {
        return null;
    }

    @Override
    public void reset() {
        if (inputSplit == null)
            throw new UnsupportedOperationException("Cannot reset without first initializing");
        try {
            initialize(inputSplit);
        } catch (Exception e) {
            throw new RuntimeException("Error during BaseAudioRecordReader reset", e);
        }
    }

    @Override
    public boolean resetSupported() {
        if (inputSplit == null) {
            return false;
        }
        return inputSplit.resetSupported();
    }

    @Override
    public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
        invokeListeners(uri);
        try {
            return loadData(null, dataInputStream);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public Record nextRecord() {
        return new org.datavec.api.records.impl.Record(next(), null);
    }

    @Override
    public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        throw new UnsupportedOperationException("Loading from metadata not yet implemented");
    }

    @Override
    public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
        throw new UnsupportedOperationException("Loading from metadata not yet implemented");
    }
}
@ -1,76 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.recordreader;

import org.bytedeco.javacv.FFmpegFrameGrabber;
import org.bytedeco.javacv.Frame;
import org.datavec.api.writable.FloatWritable;
import org.datavec.api.writable.Writable;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.List;

import static org.bytedeco.ffmpeg.global.avutil.AV_SAMPLE_FMT_FLT;

/**
 * Native audio file loader using FFmpeg.
 *
 * @author saudet
 */
public class NativeAudioRecordReader extends BaseAudioRecordReader {

    public NativeAudioRecordReader() {}

    public NativeAudioRecordReader(boolean appendLabel, List<String> labels) {
        super(appendLabel, labels);
    }

    public NativeAudioRecordReader(List<String> labels) {
        super(labels);
    }

    public NativeAudioRecordReader(boolean appendLabel) {
        super(appendLabel);
    }

    @Override
    protected List<Writable> loadData(File file, InputStream inputStream) throws IOException {
        List<Writable> ret = new ArrayList<>();
        try (FFmpegFrameGrabber grabber = inputStream != null ? new FFmpegFrameGrabber(inputStream)
                        : new FFmpegFrameGrabber(file.getAbsolutePath())) {
            grabber.setSampleFormat(AV_SAMPLE_FMT_FLT);
            grabber.start();
            Frame frame;
            while ((frame = grabber.grab()) != null) {
                while (frame.samples != null && frame.samples[0].hasRemaining()) {
                    for (int i = 0; i < frame.samples.length; i++) {
                        ret.add(new FloatWritable(((FloatBuffer) frame.samples[i]).get()));
                    }
                }
            }
        }
        return ret;
    }

}
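A minimal read loop for the FFmpeg-backed reader, assuming a local audio file (the path is hypothetical) and the FFmpeg presets on the classpath:

    RecordReader reader = new NativeAudioRecordReader();
    reader.initialize(new FileSplit(new File("/tmp/clip.wav")));
    while (reader.hasNext()) {
        List<Writable> samples = reader.next(); // all decoded samples for one file, one FloatWritable each
    }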
@ -1,57 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio.recordreader;

import org.datavec.api.util.RecordUtils;
import org.datavec.api.writable.Writable;
import org.datavec.audio.Wave;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;

/**
 * Wav file loader
 * @author Adam Gibson
 */
public class WavFileRecordReader extends BaseAudioRecordReader {

    public WavFileRecordReader() {}

    public WavFileRecordReader(boolean appendLabel, List<String> labels) {
        super(appendLabel, labels);
    }

    public WavFileRecordReader(List<String> labels) {
        super(labels);
    }

    public WavFileRecordReader(boolean appendLabel) {
        super(appendLabel);
    }

    @Override
    protected List<Writable> loadData(File file, InputStream inputStream) throws IOException {
        Wave wave = inputStream != null ? new Wave(inputStream) : new Wave(file.getAbsolutePath());
        return RecordUtils.toRecord(wave.getNormalizedAmplitudes());
    }

}
@ -1,51 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.datavec.audio;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;

import java.util.*;

@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {

    @Override
    public long getTimeoutMilliseconds() {
        return 60000;
    }

    @Override
    protected Set<Class<?>> getExclusions() {
        //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
        return new HashSet<>();
    }

    @Override
    protected String getPackageName() {
        return "org.datavec.audio";
    }

    @Override
    protected Class<?> getBaseClass() {
        return BaseND4JTest.class;
    }
}
@ -1,68 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio;

import org.bytedeco.javacv.FFmpegFrameRecorder;
import org.bytedeco.javacv.Frame;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.audio.recordreader.NativeAudioRecordReader;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.common.tests.BaseND4JTest;

import java.io.File;
import java.nio.ShortBuffer;
import java.util.List;

import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_VORBIS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * @author saudet
 */
public class AudioReaderTest extends BaseND4JTest {
    @Ignore
    @Test
    public void testNativeAudioReader() throws Exception {
        File tempFile = File.createTempFile("testNativeAudioReader", ".ogg");
        FFmpegFrameRecorder recorder = new FFmpegFrameRecorder(tempFile, 2);
        recorder.setAudioCodec(AV_CODEC_ID_VORBIS);
        recorder.setSampleRate(44100);
        recorder.start();
        Frame audioFrame = new Frame();
        ShortBuffer audioBuffer = ShortBuffer.allocate(64 * 1024);
        audioFrame.sampleRate = 44100;
        audioFrame.audioChannels = 2;
        audioFrame.samples = new ShortBuffer[] {audioBuffer};
        recorder.record(audioFrame);
        recorder.stop();
        recorder.release();

        RecordReader reader = new NativeAudioRecordReader();
        reader.initialize(new FileSplit(tempFile));
        assertTrue(reader.hasNext());
        List<Writable> record = reader.next();
        assertEquals(audioBuffer.limit(), record.size());
    }
}
@ -1,69 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.audio;

import org.datavec.audio.dsp.FastFourierTransform;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.common.tests.BaseND4JTest;

public class TestFastFourierTransform extends BaseND4JTest {

    @Test
    public void testFastFourierTransformComplex() {
        FastFourierTransform fft = new FastFourierTransform();
        double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6};
        double[] frequencies = fft.getMagnitudes(amplitudes);

        Assert.assertEquals(2, frequencies.length);
        Assert.assertArrayEquals(new double[] {21.335, 18.513}, frequencies, 0.005);
    }

    @Test
    public void testFastFourierTransformComplexLong() {
        FastFourierTransform fft = new FastFourierTransform();
        double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6};
        double[] frequencies = fft.getMagnitudes(amplitudes, true);

        Assert.assertEquals(4, frequencies.length);
        Assert.assertArrayEquals(new double[] {21.335, 18.5132, 14.927, 7.527}, frequencies, 0.005);
    }

    @Test
    public void testFastFourierTransformReal() {
        FastFourierTransform fft = new FastFourierTransform();
        double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6};
        double[] frequencies = fft.getMagnitudes(amplitudes, false);

        Assert.assertEquals(4, frequencies.length);
        Assert.assertArrayEquals(new double[] {28.8, 2.107, 14.927, 19.874}, frequencies, 0.005);
    }

    @Test
    public void testFastFourierTransformRealOddSize() {
        FastFourierTransform fft = new FastFourierTransform();
        double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5};
        double[] frequencies = fft.getMagnitudes(amplitudes, false);

        Assert.assertEquals(3, frequencies.length);
        Assert.assertArrayEquals(new double[] {24.2, 3.861, 16.876}, frequencies, 0.005);
    }
}
@ -1,71 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-data</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-data-codec</artifactId>

    <name>datavec-data-codec</name>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jcodec</groupId>
            <artifactId>jcodec</artifactId>
            <version>0.1.5</version>
        </dependency>
        <!-- Do not depend on FFmpeg by default due to licensing concerns. -->
        <!--
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>ffmpeg-platform</artifactId>
            <version>${ffmpeg.version}-${javacpp-presets.version}</version>
        </dependency>
        -->
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@ -1,41 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.codec.format.input;

import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.BaseInputFormat;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.codec.reader.CodecRecordReader;

import java.io.IOException;

/**
 * @author Adam Gibson
 */
public class CodecInputFormat extends BaseInputFormat {
    @Override
    public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
        RecordReader reader = new CodecRecordReader();
        reader.initialize(conf, split);
        return reader;
    }
}
@ -1,144 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.codec.reader;

import org.datavec.api.conf.Configuration;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.FileRecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable;

import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public abstract class BaseCodecRecordReader extends FileRecordReader implements SequenceRecordReader {
    protected int startFrame = 0;
    protected int numFrames = -1;
    protected int totalFrames = -1;
    protected double framesPerSecond = -1;
    protected double videoLength = -1;
    protected int rows = 28, cols = 28;
    protected boolean ravel = false;

    public static final String NAME_SPACE = "org.datavec.codec.reader";
    public static final String ROWS = NAME_SPACE + ".rows";
    public static final String COLUMNS = NAME_SPACE + ".columns";
    public static final String START_FRAME = NAME_SPACE + ".startframe";
    public static final String TOTAL_FRAMES = NAME_SPACE + ".frames";
    public static final String TIME_SLICE = NAME_SPACE + ".time";
    public static final String RAVEL = NAME_SPACE + ".ravel";
    public static final String VIDEO_DURATION = NAME_SPACE + ".duration";

    @Override
    public List<List<Writable>> sequenceRecord() {
        URI next = locationsIterator.next();

        try (InputStream s = streamCreatorFn.apply(next)) {
            return loadData(null, s);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public List<List<Writable>> sequenceRecord(URI uri, DataInputStream dataInputStream) throws IOException {
        return loadData(null, dataInputStream);
    }

    protected abstract List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException;

    @Override
    public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
        setConf(conf);
        initialize(split);
    }

    @Override
    public List<Writable> next() {
        throw new UnsupportedOperationException("next() not supported for CodecRecordReader (use: sequenceRecord)");
    }

    @Override
    public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
        throw new UnsupportedOperationException("record(URI,DataInputStream) not supported for CodecRecordReader");
    }

    @Override
    public void setConf(Configuration conf) {
        super.setConf(conf);
        startFrame = conf.getInt(START_FRAME, 0);
        numFrames = conf.getInt(TOTAL_FRAMES, -1);
        rows = conf.getInt(ROWS, 28);
        cols = conf.getInt(COLUMNS, 28);
        framesPerSecond = conf.getFloat(TIME_SLICE, -1);
        videoLength = conf.getFloat(VIDEO_DURATION, -1);
        ravel = conf.getBoolean(RAVEL, false);
        totalFrames = conf.getInt(TOTAL_FRAMES, -1);
    }

    @Override
    public Configuration getConf() {
        return super.getConf();
    }

    @Override
    public SequenceRecord nextSequence() {
        URI next = locationsIterator.next();

        List<List<Writable>> list;
        try (InputStream s = streamCreatorFn.apply(next)) {
            list = loadData(null, s);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return new org.datavec.api.records.impl.SequenceRecord(list,
                        new RecordMetaDataURI(next, CodecRecordReader.class));
    }

    @Override
    public SequenceRecord loadSequenceFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return loadSequenceFromMetaData(Collections.singletonList(recordMetaData)).get(0);
    }

    @Override
    public List<SequenceRecord> loadSequenceFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
        List<SequenceRecord> out = new ArrayList<>();
        for (RecordMetaData meta : recordMetaDatas) {
            try (InputStream s = streamCreatorFn.apply(meta.getURI())) {
                List<List<Writable>> list = loadData(null, s);
                out.add(new org.datavec.api.records.impl.SequenceRecord(list, meta));
            }
        }

        return out;
    }
}
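The codec readers are driven entirely by these Configuration keys. A minimal sketch grabbing ten raveled 46x80 frames from frame five onward (the file name is hypothetical):

    Configuration conf = new Configuration();
    conf.set(CodecRecordReader.START_FRAME, "5");
    conf.set(CodecRecordReader.TOTAL_FRAMES, "10");
    conf.set(CodecRecordReader.ROWS, "46");
    conf.set(CodecRecordReader.COLUMNS, "80");
    conf.set(CodecRecordReader.RAVEL, "true");
    SequenceRecordReader reader = new CodecRecordReader();
    reader.initialize(conf, new FileSplit(new File("video.mp4")));
    List<List<Writable>> frames = reader.sequenceRecord(); // one List<Writable> per grabbed frame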
@ -1,138 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.datavec.codec.reader;

import org.apache.commons.compress.utils.IOUtils;
import org.datavec.api.conf.Configuration;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.ImageLoader;
import org.jcodec.api.FrameGrab;
import org.jcodec.api.JCodecException;
import org.jcodec.common.ByteBufferSeekableByteChannel;
import org.jcodec.common.NIOUtils;
import org.jcodec.common.SeekableByteChannel;

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

public class CodecRecordReader extends BaseCodecRecordReader {

    private ImageLoader imageLoader;

    @Override
    public void setConf(Configuration conf) {
        super.setConf(conf);
        imageLoader = new ImageLoader(rows, cols);
    }

    @Override
    protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException {
        SeekableByteChannel seekableByteChannel;
        if (inputStream != null) {
            //Reading video from DataInputStream: need the data from this stream in a SeekableByteChannel.
            //Approach used here: load the entire video into memory -> ByteBufferSeekableByteChannel
            byte[] data = IOUtils.toByteArray(inputStream);
            ByteBuffer bb = ByteBuffer.wrap(data);
            seekableByteChannel = new FixedByteBufferSeekableByteChannel(bb);
        } else {
            seekableByteChannel = NIOUtils.readableFileChannel(file);
        }

        List<List<Writable>> record = new ArrayList<>();

        if (numFrames >= 1) {
            FrameGrab fg;
            try {
                fg = new FrameGrab(seekableByteChannel);
                if (startFrame != 0)
                    fg.seekToFramePrecise(startFrame);
            } catch (JCodecException e) {
                throw new RuntimeException(e);
            }

            for (int i = startFrame; i < startFrame + numFrames; i++) {
                try {
                    BufferedImage grab = fg.getFrame();
                    if (ravel)
                        record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab)));
                    else
                        record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab)));
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        } else {
            if (framesPerSecond < 1)
                throw new IllegalStateException("No frames or frame time intervals specified");
            else {
                for (double i = 0; i < videoLength; i += framesPerSecond) {
                    try {
                        BufferedImage grab = FrameGrab.getFrame(seekableByteChannel, i);
                        if (ravel)
                            record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab)));
                        else
                            record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab)));
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
            }
        }

        return record;
    }

    /** Ugly workaround for a bug in JCodec: https://github.com/jcodec/jcodec/issues/24 */
    private static class FixedByteBufferSeekableByteChannel extends ByteBufferSeekableByteChannel {
        private ByteBuffer backing;

        public FixedByteBufferSeekableByteChannel(ByteBuffer backing) {
            super(backing);
            try {
                Field f = this.getClass().getSuperclass().getDeclaredField("maxPos");
                f.setAccessible(true);
                f.set(this, backing.limit());
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
            this.backing = backing;
        }

        @Override
        public int read(ByteBuffer dst) throws IOException {
            if (!backing.hasRemaining())
                return -1;
            return super.read(dst);
        }
    }
}
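//Minimal usage sketch for CodecRecordReader, mirroring the tests later in this commit
//(the mp4 file name is illustrative):
//
//  Configuration conf = new Configuration();
//  conf.set(CodecRecordReader.RAVEL, "true");
//  conf.set(CodecRecordReader.START_FRAME, "160");
//  conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
//  conf.set(CodecRecordReader.ROWS, "80");
//  conf.set(CodecRecordReader.COLUMNS, "46");
//  SequenceRecordReader reader = new CodecRecordReader();
//  reader.initialize(new FileSplit(new File("fire_lowres.mp4")));
//  reader.setConf(conf);
//  List<List<Writable>> video = reader.sequenceRecord();  //one row-vector Writable per frame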
@@ -1,82 +0,0 @@
package org.datavec.codec.reader;

import org.bytedeco.javacv.FFmpegFrameGrabber;
import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.api.conf.Configuration;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;

public class NativeCodecRecordReader extends BaseCodecRecordReader {

    private OpenCVFrameConverter.ToMat converter;
    private NativeImageLoader imageLoader;

    @Override
    public void setConf(Configuration conf) {
        super.setConf(conf);
        converter = new OpenCVFrameConverter.ToMat();
        imageLoader = new NativeImageLoader(rows, cols);
    }

    @Override
    protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException {
        List<List<Writable>> record = new ArrayList<>();

        try (FFmpegFrameGrabber fg =
                        inputStream != null ? new FFmpegFrameGrabber(inputStream) : new FFmpegFrameGrabber(file)) {
            if (numFrames >= 1) {
                fg.start();
                if (startFrame != 0)
                    fg.setFrameNumber(startFrame);

                for (int i = startFrame; i < startFrame + numFrames; i++) {
                    Frame grab = fg.grabImage();
                    record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab))));
                }
            } else {
                if (framesPerSecond < 1)
                    throw new IllegalStateException("No frames or frame time intervals specified");
                else {
                    fg.start();

                    for (double i = 0; i < videoLength; i += framesPerSecond) {
                        fg.setTimestamp(Math.round(i * 1000000L));
                        Frame grab = fg.grabImage();
                        record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab))));
                    }
                }
            }
        }

        return record;
    }
}
@@ -1,46 +0,0 @@
package org.datavec.codec.reader;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;

import java.util.*;

@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {

    @Override
    protected Set<Class<?>> getExclusions() {
        //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
        return new HashSet<>();
    }

    @Override
    protected String getPackageName() {
        return "org.datavec.codec.reader";
    }

    @Override
    protected Class<?> getBaseClass() {
        return BaseND4JTest.class;
    }
}
@@ -1,212 +0,0 @@
package org.datavec.codec.reader;

import org.datavec.api.conf.Configuration;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.ArrayWritable;
import org.datavec.api.writable.Writable;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.common.io.ClassPathResource;

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.util.Iterator;
import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * @author Adam Gibson
 */
public class CodecReaderTest {
    @Test
    public void testCodecReader() throws Exception {
        File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
        SequenceRecordReader reader = new CodecRecordReader();
        Configuration conf = new Configuration();
        conf.set(CodecRecordReader.RAVEL, "true");
        conf.set(CodecRecordReader.START_FRAME, "160");
        conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
        conf.set(CodecRecordReader.ROWS, "80");
        conf.set(CodecRecordReader.COLUMNS, "46");
        reader.initialize(new FileSplit(file));
        reader.setConf(conf);
        assertTrue(reader.hasNext());
        List<List<Writable>> record = reader.sequenceRecord();

        Iterator<List<Writable>> it = record.iterator();
        List<Writable> first = it.next();

        //Expected size: 80x46x3
        assertEquals(1, first.size());
        assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length());
    }

    @Test
    public void testCodecReaderMeta() throws Exception {
        File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
        SequenceRecordReader reader = new CodecRecordReader();
        Configuration conf = new Configuration();
        conf.set(CodecRecordReader.RAVEL, "true");
        conf.set(CodecRecordReader.START_FRAME, "160");
        conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
        conf.set(CodecRecordReader.ROWS, "80");
        conf.set(CodecRecordReader.COLUMNS, "46");
        reader.initialize(new FileSplit(file));
        reader.setConf(conf);
        assertTrue(reader.hasNext());
        List<List<Writable>> record = reader.sequenceRecord();
        assertEquals(500, record.size()); //500 frames

        reader.reset();
        SequenceRecord seqR = reader.nextSequence();
        assertEquals(record, seqR.getSequenceRecord());
        RecordMetaData meta = seqR.getMetaData();
        assertTrue(meta.getURI().toString().endsWith(file.getName()));

        SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta);
        assertEquals(seqR, fromMeta);
    }

    @Test
    public void testViaDataInputStream() throws Exception {
        File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
        SequenceRecordReader reader = new CodecRecordReader();
        Configuration conf = new Configuration();
        conf.set(CodecRecordReader.RAVEL, "true");
        conf.set(CodecRecordReader.START_FRAME, "160");
        conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
        conf.set(CodecRecordReader.ROWS, "80");
        conf.set(CodecRecordReader.COLUMNS, "46");

        Configuration conf2 = new Configuration(conf);

        reader.initialize(new FileSplit(file));
        reader.setConf(conf);
        assertTrue(reader.hasNext());
        List<List<Writable>> expected = reader.sequenceRecord();

        SequenceRecordReader reader2 = new CodecRecordReader();
        reader2.setConf(conf2);

        DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file));
        List<List<Writable>> actual = reader2.sequenceRecord(null, dataInputStream);

        assertEquals(expected, actual);
    }

    @Ignore
    @Test
    public void testNativeCodecReader() throws Exception {
        File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
        SequenceRecordReader reader = new NativeCodecRecordReader();
        Configuration conf = new Configuration();
        conf.set(CodecRecordReader.RAVEL, "true");
        conf.set(CodecRecordReader.START_FRAME, "160");
        conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
        conf.set(CodecRecordReader.ROWS, "80");
        conf.set(CodecRecordReader.COLUMNS, "46");
        reader.initialize(new FileSplit(file));
        reader.setConf(conf);
        assertTrue(reader.hasNext());
        List<List<Writable>> record = reader.sequenceRecord();

        Iterator<List<Writable>> it = record.iterator();
        List<Writable> first = it.next();

        //Expected size: 80x46x3
        assertEquals(1, first.size());
        assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length());
    }

    @Ignore
    @Test
    public void testNativeCodecReaderMeta() throws Exception {
        File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
        SequenceRecordReader reader = new NativeCodecRecordReader();
        Configuration conf = new Configuration();
        conf.set(CodecRecordReader.RAVEL, "true");
        conf.set(CodecRecordReader.START_FRAME, "160");
        conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
        conf.set(CodecRecordReader.ROWS, "80");
        conf.set(CodecRecordReader.COLUMNS, "46");
        reader.initialize(new FileSplit(file));
        reader.setConf(conf);
        assertTrue(reader.hasNext());
        List<List<Writable>> record = reader.sequenceRecord();
        assertEquals(500, record.size()); //500 frames

        reader.reset();
        SequenceRecord seqR = reader.nextSequence();
        assertEquals(record, seqR.getSequenceRecord());
        RecordMetaData meta = seqR.getMetaData();
        assertTrue(meta.getURI().toString().endsWith("fire_lowres.mp4"));

        SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta);
        assertEquals(seqR, fromMeta);
    }

    @Ignore
    @Test
    public void testNativeViaDataInputStream() throws Exception {
        File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
        SequenceRecordReader reader = new NativeCodecRecordReader();
        Configuration conf = new Configuration();
        conf.set(CodecRecordReader.RAVEL, "true");
        conf.set(CodecRecordReader.START_FRAME, "160");
        conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
        conf.set(CodecRecordReader.ROWS, "80");
        conf.set(CodecRecordReader.COLUMNS, "46");

        Configuration conf2 = new Configuration(conf);

        reader.initialize(new FileSplit(file));
        reader.setConf(conf);
        assertTrue(reader.hasNext());
        List<List<Writable>> expected = reader.sequenceRecord();

        SequenceRecordReader reader2 = new NativeCodecRecordReader();
        reader2.setConf(conf2);

        DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file));
        List<List<Writable>> actual = reader2.sequenceRecord(null, dataInputStream);

        assertEquals(expected, actual);
    }
}
@@ -1,77 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-data</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-data-nlp</artifactId>

    <name>datavec-data-nlp</name>

    <properties>
        <cleartk.version>2.0.0</cleartk.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-local</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
        </dependency>
        <dependency>
            <groupId>org.cleartk</groupId>
            <artifactId>cleartk-snowball</artifactId>
            <version>${cleartk.version}</version>
        </dependency>
        <dependency>
            <groupId>org.cleartk</groupId>
            <artifactId>cleartk-opennlp-tools</artifactId>
            <version>${cleartk.version}</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@@ -1,237 +0,0 @@
package org.datavec.nlp.annotator;

import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSTaggerME;
import opennlp.uima.postag.POSModelResource;
import opennlp.uima.postag.POSModelResourceImpl;
import opennlp.uima.util.AnnotationComboIterator;
import opennlp.uima.util.AnnotationIteratorPair;
import opennlp.uima.util.AnnotatorUtil;
import opennlp.uima.util.UimaUtil;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.Feature;
import org.apache.uima.cas.Type;
import org.apache.uima.cas.TypeSystem;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.fit.component.CasAnnotator_ImplBase;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.ExternalResourceFactory;
import org.apache.uima.resource.ResourceAccessException;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.util.Level;
import org.apache.uima.util.Logger;
import org.cleartk.token.type.Sentence;
import org.cleartk.token.type.Token;
import org.datavec.nlp.movingwindow.Util;

import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

public class PoStagger extends CasAnnotator_ImplBase {

    static {
        //Disable UIMA logging
        Util.disableLogging();
    }

    private POSTaggerME posTagger;

    private Type sentenceType;

    private Type tokenType;

    private Feature posFeature;

    private Feature probabilityFeature;

    private UimaContext context;

    private Logger logger;

    /**
     * Initializes a new instance.
     *
     * Note: Use {@link #initialize(org.apache.uima.UimaContext)} to initialize this instance,
     * not the constructor.
     */
    public PoStagger() {
        //must not be implemented!
    }

    /**
     * Initializes the current instance with the given context.
     *
     * Note: Do all initialization in this method; do not use the constructor.
     */
    @Override
    public void initialize(UimaContext context) throws ResourceInitializationException {

        super.initialize(context);

        this.context = context;

        this.logger = context.getLogger();

        if (this.logger.isLoggable(Level.INFO)) {
            this.logger.log(Level.INFO, "Initializing the OpenNLP Part of Speech annotator.");
        }

        POSModel model;

        try {
            POSModelResource modelResource = (POSModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER);
            model = modelResource.getModel();
        } catch (ResourceAccessException e) {
            throw new ResourceInitializationException(e);
        }

        Integer beamSize = AnnotatorUtil.getOptionalIntegerParameter(context, UimaUtil.BEAM_SIZE_PARAMETER);

        if (beamSize == null)
            beamSize = POSTaggerME.DEFAULT_BEAM_SIZE;

        this.posTagger = new POSTaggerME(model, beamSize, 0);
    }

    /**
     * Initializes the type system.
     */
    @Override
    public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException {

        //sentence type
        this.sentenceType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem,
                        UimaUtil.SENTENCE_TYPE_PARAMETER);

        //token type
        this.tokenType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem,
                        UimaUtil.TOKEN_TYPE_PARAMETER);

        //pos feature
        this.posFeature = AnnotatorUtil.getRequiredFeatureParameter(this.context, this.tokenType,
                        UimaUtil.POS_FEATURE_PARAMETER, CAS.TYPE_NAME_STRING);

        this.probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(this.context, this.tokenType,
                        UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE);
    }

    /**
     * Performs POS tagging on the given CAS object.
     */
    @Override
    public synchronized void process(CAS tcas) {

        final AnnotationComboIterator comboIterator =
                        new AnnotationComboIterator(tcas, this.sentenceType, this.tokenType);

        for (AnnotationIteratorPair annotationIteratorPair : comboIterator) {

            final List<AnnotationFS> sentenceTokenAnnotationList = new LinkedList<>();

            final List<String> sentenceTokenList = new LinkedList<>();

            for (AnnotationFS tokenAnnotation : annotationIteratorPair.getSubIterator()) {

                sentenceTokenAnnotationList.add(tokenAnnotation);

                sentenceTokenList.add(tokenAnnotation.getCoveredText());
            }

            final List<String> posTags = this.posTagger.tag(sentenceTokenList);

            double[] posProbabilities = null;

            if (this.probabilityFeature != null) {
                posProbabilities = this.posTagger.probs();
            }

            final Iterator<String> posTagIterator = posTags.iterator();
            final Iterator<AnnotationFS> sentenceTokenIterator = sentenceTokenAnnotationList.iterator();

            int index = 0;
            while (posTagIterator.hasNext() && sentenceTokenIterator.hasNext()) {
                final String posTag = posTagIterator.next();
                final AnnotationFS tokenAnnotation = sentenceTokenIterator.next();

                tokenAnnotation.setStringValue(this.posFeature, posTag);

                if (posProbabilities != null) {
                    //Bug fix: write the probability to the (double-typed) probability feature;
                    //the original wrote it to the string-typed posFeature
                    tokenAnnotation.setDoubleValue(this.probabilityFeature, posProbabilities[index]);
                }

                index++;
            }

            //log tokens with pos
            if (this.logger.isLoggable(Level.FINER)) {

                final StringBuilder sentenceWithPos = new StringBuilder();

                sentenceWithPos.append("\"");

                for (final Iterator<AnnotationFS> it = sentenceTokenAnnotationList.iterator(); it.hasNext();) {
                    final AnnotationFS token = it.next();
                    sentenceWithPos.append(token.getCoveredText());
                    sentenceWithPos.append('\\');
                    sentenceWithPos.append(token.getStringValue(this.posFeature));
                    sentenceWithPos.append(' ');
                }

                //delete the last whitespace
                if (sentenceWithPos.length() > 1) //not 0, because it already contains the " char
                    sentenceWithPos.setLength(sentenceWithPos.length() - 1);

                sentenceWithPos.append("\"");

                this.logger.log(Level.FINER, sentenceWithPos.toString());
            }
        }
    }

    /**
     * Releases allocated resources.
     */
    @Override
    public void destroy() {
        this.posTagger = null;
    }

    public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException {
        String modelPath = String.format("/models/%s-pos-maxent.bin", languageCode);
        return AnalysisEngineFactory.createEngineDescription(PoStagger.class, UimaUtil.MODEL_PARAMETER,
                        ExternalResourceFactory.createExternalResourceDescription(POSModelResourceImpl.class,
                                        PoStagger.class.getResource(modelPath).toString()),
                        UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), UimaUtil.TOKEN_TYPE_PARAMETER,
                        Token.class.getName(), UimaUtil.POS_FEATURE_PARAMETER, "pos");
    }

}
@@ -1,52 +0,0 @@
package org.datavec.nlp.annotator;

import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.util.ParamUtil;
import org.datavec.nlp.movingwindow.Util;

public class SentenceAnnotator extends org.cleartk.opennlp.tools.SentenceAnnotator {

    static {
        //Disable UIMA logging
        Util.disableLogging();
    }

    public static AnalysisEngineDescription getDescription() throws ResourceInitializationException {
        return AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.class, PARAM_SENTENCE_MODEL_PATH,
                        ParamUtil.getParameterValue(PARAM_SENTENCE_MODEL_PATH, "/models/en-sent.bin"),
                        PARAM_WINDOW_CLASS_NAMES, ParamUtil.getParameterValue(PARAM_WINDOW_CLASS_NAMES, null));
    }

    @Override
    public synchronized void process(JCas jCas) throws AnalysisEngineProcessException {
        super.process(jCas);
    }

}
@@ -1,58 +0,0 @@
package org.datavec.nlp.annotator;

import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.snowball.SnowballStemmer;
import org.cleartk.token.type.Token;

public class StemmerAnnotator extends SnowballStemmer<Token> {

    public static AnalysisEngineDescription getDescription() throws ResourceInitializationException {
        return getDescription("English");
    }

    public static AnalysisEngineDescription getDescription(String language) throws ResourceInitializationException {
        return AnalysisEngineFactory.createEngineDescription(StemmerAnnotator.class, SnowballStemmer.PARAM_STEMMER_NAME,
                        language);
    }

    @SuppressWarnings("unchecked")
    @Override
    public synchronized void process(JCas jCas) throws AnalysisEngineProcessException {
        super.process(jCas);
    }

    @Override
    public void setStem(Token token, String stem) {
        token.setStem(stem);
    }

}
@@ -1,70 +0,0 @@
package org.datavec.nlp.annotator;

import opennlp.uima.tokenize.TokenizerModelResourceImpl;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.ExternalResourceFactory;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.opennlp.tools.Tokenizer;
import org.cleartk.token.type.Sentence;
import org.cleartk.token.type.Token;
import org.datavec.nlp.movingwindow.Util;
import org.datavec.nlp.tokenization.tokenizer.ConcurrentTokenizer;

/**
 * Overrides the OpenNLP tokenizer to be thread safe
 */
public class TokenizerAnnotator extends Tokenizer {

    static {
        //Disable UIMA logging
        Util.disableLogging();
    }

    public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException {
        String modelPath = String.format("/models/%s-token.bin", languageCode);
        return AnalysisEngineFactory.createEngineDescription(ConcurrentTokenizer.class,
                        opennlp.uima.util.UimaUtil.MODEL_PARAMETER,
                        ExternalResourceFactory.createExternalResourceDescription(TokenizerModelResourceImpl.class,
                                        ConcurrentTokenizer.class.getResource(modelPath).toString()),
                        opennlp.uima.util.UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(),
                        opennlp.uima.util.UimaUtil.TOKEN_TYPE_PARAMETER, Token.class.getName());
    }

    public static AnalysisEngineDescription getDescription() throws ResourceInitializationException {
        //Same as the method above, defaulting to the English ("en") token model
        return getDescription("en");
    }

}
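//Sketch of composing the annotators above into one UIMA pipeline. This assumes uimaFIT's
//standard AggregateBuilder (org.apache.uima.fit.factory.AggregateBuilder), which is not
//part of this module:
//
//  AggregateBuilder builder = new AggregateBuilder();
//  builder.add(SentenceAnnotator.getDescription());
//  builder.add(TokenizerAnnotator.getDescription());
//  builder.add(StemmerAnnotator.getDescription());
//  AnalysisEngine engine = builder.createAggregate();
//  engine.process(jCas);  //jCas: a JCas holding the document text to annotate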
@@ -1,41 +0,0 @@
package org.datavec.nlp.input;

import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.BaseInputFormat;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.nlp.reader.TfidfRecordReader;

import java.io.IOException;

/**
 * @author Adam Gibson
 */
public class TextInputFormat extends BaseInputFormat {
    @Override
    public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
        RecordReader reader = new TfidfRecordReader();
        reader.initialize(conf, split);
        return reader;
    }
}
@@ -1,148 +0,0 @@
package org.datavec.nlp.metadata;

import org.nd4j.common.primitives.Counter;
import org.datavec.api.conf.Configuration;
import org.datavec.nlp.vectorizer.TextVectorizer;
import org.nd4j.common.util.MathUtils;
import org.nd4j.common.util.Index;

public class DefaultVocabCache implements VocabCache {

    private Counter<String> wordFrequencies = new Counter<>();
    private Counter<String> docFrequencies = new Counter<>();
    private int minWordFrequency;
    private Index vocabWords = new Index();
    private double numDocs = 0;

    /**
     * Instantiate with a given minimum word frequency
     * @param minWordFrequency the minimum frequency a word must reach to be included in the vocab
     */
    public DefaultVocabCache(int minWordFrequency) {
        this.minWordFrequency = minWordFrequency;
    }

    /*
     * Constructor for use with initialize()
     */
    public DefaultVocabCache() {
    }

    @Override
    public void incrementNumDocs(double by) {
        numDocs += by;
    }

    @Override
    public double numDocs() {
        return numDocs;
    }

    @Override
    public String wordAt(int i) {
        return vocabWords.get(i).toString();
    }

    @Override
    public int wordIndex(String word) {
        return vocabWords.indexOf(word);
    }

    @Override
    public void initialize(Configuration conf) {
        minWordFrequency = conf.getInt(TextVectorizer.MIN_WORD_FREQUENCY, 5);
    }

    @Override
    public double wordFrequency(String word) {
        return wordFrequencies.getCount(word);
    }

    @Override
    public int minWordFrequency() {
        return minWordFrequency;
    }

    @Override
    public Index vocabWords() {
        return vocabWords;
    }

    @Override
    public void incrementDocCount(String word) {
        incrementDocCount(word, 1.0);
    }

    @Override
    public void incrementDocCount(String word, double by) {
        docFrequencies.incrementCount(word, by);
    }

    @Override
    public void incrementCount(String word) {
        incrementCount(word, 1.0);
    }

    @Override
    public void incrementCount(String word, double by) {
        wordFrequencies.incrementCount(word, by);
        if (wordFrequencies.getCount(word) >= minWordFrequency && vocabWords.indexOf(word) < 0)
            vocabWords.add(word);
    }

    @Override
    public double idf(String word) {
        return docFrequencies.getCount(word);
    }

    @Override
    public double tfidf(String word, double frequency, boolean smoothIdf) {
        double tf = tf((int) frequency);
        double docFreq = docFrequencies.getCount(word);
        double idf = idf(numDocs, docFreq, smoothIdf);
        return MathUtils.tfidf(tf, idf);
    }

    public double idf(double totalDocs, double numTimesWordAppearedInADocument, boolean smooth) {
        if (smooth) {
            return Math.log((1 + totalDocs) / (1 + numTimesWordAppearedInADocument)) + 1.0;
        } else {
            return Math.log(totalDocs / numTimesWordAppearedInADocument) + 1.0;
        }
    }

    public static double tf(int count) {
        return count;
    }

    public int getMinWordFrequency() {
        return minWordFrequency;
    }

    public void setMinWordFrequency(int minWordFrequency) {
        this.minWordFrequency = minWordFrequency;
    }
}
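//Worked example of the scoring above, assuming numDocs = 10 and a word that appears in
//2 documents with term frequency 3, using the smoothed idf:
//  idf   = ln((1 + 10) / (1 + 2)) + 1 = ln(11/3) + 1 ≈ 2.299
//  tfidf = MathUtils.tfidf(3, 2.299) ≈ 6.90
//(assuming MathUtils.tfidf multiplies the two terms; raw tf is the count itself here)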
@@ -1,121 +0,0 @@
package org.datavec.nlp.metadata;

import org.datavec.api.conf.Configuration;
import org.nd4j.common.util.Index;

public interface VocabCache {

    /**
     * Increment the number of documents
     * @param by the amount to increment by
     */
    void incrementNumDocs(double by);

    /**
     * Number of documents
     * @return the number of documents
     */
    double numDocs();

    /**
     * Returns the word in the vocab at a particular index
     * @param i the index to get
     * @return the word at that index in the vocab
     */
    String wordAt(int i);

    int wordIndex(String word);

    /**
     * Configuration for initializing
     * @param conf the configuration to initialize with
     */
    void initialize(Configuration conf);

    /**
     * Get the word frequency for a word
     * @param word the word to get the frequency for
     * @return the frequency of the given word
     */
    double wordFrequency(String word);

    /**
     * The minimum word frequency a word needs
     * to be included in the vocab (default 5)
     * @return the minimum word frequency for
     * inclusion in the vocab
     */
    int minWordFrequency();

    /**
     * All of the vocab words (ordered);
     * note that these are not all of the possible tokens
     * @return the list of vocab words
     */
    Index vocabWords();

    /**
     * Increment the doc count for a word by 1
     * @param word the word to increment the count for
     */
    void incrementDocCount(String word);

    /**
     * Increment the document count for a particular word
     * @param word the word to increment the count for
     * @param by the amount to increment by
     */
    void incrementDocCount(String word, double by);

    /**
     * Increment a word count by 1
     * @param word the word to increment the count for
     */
    void incrementCount(String word);

    /**
     * Increment the count for a word
     * @param word the word to increment the count for
     * @param by the amount to increment by
     */
    void incrementCount(String word, double by);

    /**
     * Number of documents the word has occurred in
     * @param word the word to get the document frequency for
     */
    double idf(String word);

    /**
     * Calculate the tf-idf of the word given its frequency
     * @param word the word to compute tf-idf for
     * @param frequency the term frequency of the word
     * @param smoothIdf whether to smooth the idf term
     * @return the tf-idf for the word
     */
    double tfidf(String word, double frequency, boolean smoothIdf);

}
@@ -1,125 +0,0 @@
package org.datavec.nlp.movingwindow;

import org.apache.commons.lang3.StringUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.collection.MultiDimensionalMap;
import org.nd4j.common.primitives.Pair;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;

import java.util.ArrayList;
import java.util.List;

public class ContextLabelRetriever {

    private static String BEGIN_LABEL = "<([A-Za-z]+|\\d+)>";
    private static String END_LABEL = "</([A-Za-z]+|\\d+)>";

    private ContextLabelRetriever() {}

    /**
     * Returns a stripped sentence with the indices of words
     * with certain kinds of labels.
     *
     * @param sentence the sentence to process
     * @param tokenizerFactory the tokenizer factory used to tokenize the sentence
     * @return a pair of the post-processed sentence
     * with labels stripped and the spans of
     * the labels
     */
    public static Pair<String, MultiDimensionalMap<Integer, Integer, String>> stringWithLabels(String sentence,
                    TokenizerFactory tokenizerFactory) {
        MultiDimensionalMap<Integer, Integer, String> map = MultiDimensionalMap.newHashBackedMap();
        Tokenizer t = tokenizerFactory.create(sentence);
        List<String> currTokens = new ArrayList<>();
        String currLabel = null;
        String endLabel = null;
        List<Pair<String, List<String>>> tokensWithSameLabel = new ArrayList<>();
        while (t.hasMoreTokens()) {
            String token = t.nextToken();
            if (token.matches(BEGIN_LABEL)) {
                currLabel = token;

                //unlabeled tokens seen so far: add these as NONE, then begin the new label
                if (!currTokens.isEmpty()) {
                    tokensWithSameLabel.add(new Pair<>("NONE", (List<String>) new ArrayList<>(currTokens)));
                    currTokens.clear();
                }

            } else if (token.matches(END_LABEL)) {
                if (currLabel == null)
                    throw new IllegalStateException("Found an ending label with no matching begin label");
                endLabel = token;
            } else
                currTokens.add(token);

            if (currLabel != null && endLabel != null) {
                currLabel = currLabel.replaceAll("[<>/]", "");
                endLabel = endLabel.replaceAll("[<>/]", "");
                Preconditions.checkState(!currLabel.isEmpty(), "Current label is empty!");
                Preconditions.checkState(!endLabel.isEmpty(), "End label is empty!");
                Preconditions.checkState(currLabel.equals(endLabel), "Current label begin and end did not match for the parse. Was: %s ending with %s",
                                currLabel, endLabel);

                tokensWithSameLabel.add(new Pair<>(currLabel, (List<String>) new ArrayList<>(currTokens)));
                currTokens.clear();

                //clear out the labels
                currLabel = null;
                endLabel = null;
            }
        }

        //trailing unlabeled tokens: add these as NONE as well
        //(fix: the original used lowercase "none" here, inconsistent with the "NONE" label above)
        if (!currTokens.isEmpty()) {
            tokensWithSameLabel.add(new Pair<>("NONE", (List<String>) new ArrayList<>(currTokens)));
            currTokens.clear();
        }

        //now join the output
        StringBuilder strippedSentence = new StringBuilder();
        for (Pair<String, List<String>> tokensWithLabel : tokensWithSameLabel) {
            String joinedSentence = StringUtils.join(tokensWithLabel.getSecond(), " ");
            //spaces between separate parts of the sentence
            if (strippedSentence.length() > 0)
                strippedSentence.append(" ");
            strippedSentence.append(joinedSentence);
            int begin = strippedSentence.toString().indexOf(joinedSentence);
            int end = begin + joinedSentence.length();
            map.put(begin, end, tokensWithLabel.getFirst());
        }

        return new Pair<>(strippedSentence.toString(), map);
    }

}
|
|
|
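
A hypothetical usage sketch (the sentence, label, and expected offsets are illustrative; the whitespace tokenizer must see the markup as standalone tokens):

    // Strips inline markup such as <LOC> ... </LOC> and records each label's character span.
    TokenizerFactory factory = new DefaultTokenizerFactory();
    Pair<String, MultiDimensionalMap<Integer, Integer, String>> result =
                    ContextLabelRetriever.stringWithLabels("I live in <LOC> New York </LOC>", factory);
    String stripped = result.getFirst();           // "I live in New York"
    String label = result.getSecond().get(10, 18); // "LOC", the span covering "New York"
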
@@ -1,60 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.nlp.movingwindow;

import org.nd4j.common.primitives.Counter;

import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

public class Util {

    /**
     * Returns a thread safe counter
     *
     * @return a new counter
     */
    public static Counter<String> parallelCounter() {
        return new Counter<>();
    }

    public static boolean matchesAnyStopWord(List<String> stopWords, String word) {
        for (String s : stopWords)
            if (s.equalsIgnoreCase(word))
                return true;
        return false;
    }

    public static Level disableLogging() {
        Logger logger = Logger.getLogger("org.apache.uima");
        while (logger.getLevel() == null) {
            logger = logger.getParent();
        }
        Level level = logger.getLevel();
        logger.setLevel(Level.OFF);
        return level;
    }

}
@@ -1,177 +0,0 @@
package org.datavec.nlp.movingwindow;

import org.apache.commons.lang3.StringUtils;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

public class Window implements Serializable {

    private static final long serialVersionUID = 6359906393699230579L;
    private List<String> words;
    private String label = "NONE";
    private boolean beginLabel;
    private boolean endLabel;
    private int median;
    private static String BEGIN_LABEL = "<([A-Z]+|\\d+)>";
    private static String END_LABEL = "</([A-Z]+|\\d+)>";
    private int begin, end;

    /**
     * Creates a window with the default window size of 5
     * @param words the words belonging to this window
     * @param begin the begin index for the window
     * @param end the end index for the window
     */
    public Window(Collection<String> words, int begin, int end) {
        this(words, 5, begin, end);
    }

    public String asTokens() {
        return StringUtils.join(words, " ");
    }

    /**
     * Initialize a window with the given size
     * @param words the words to use
     * @param windowSize the size of the window
     * @param begin the begin index for the window
     * @param end the end index for the window
     */
    public Window(Collection<String> words, int windowSize, int begin, int end) {
        if (words == null)
            throw new IllegalArgumentException("Words must not be null");

        this.words = new ArrayList<>(words);
        this.begin = begin;
        this.end = end;
        initContext();
    }

    private void initContext() {
        int median = words.size() / 2;
        List<String> begin = words.subList(0, median);
        List<String> after = words.subList(median + 1, words.size());

        for (String s : begin) {
            if (s.matches(BEGIN_LABEL)) {
                this.label = s.replaceAll("[<>/]", "");
                beginLabel = true;
            } else if (s.matches(END_LABEL)) {
                endLabel = true;
                this.label = s.replaceAll("[<>/]", "");
            }
        }

        for (String s1 : after) {
            if (s1.matches(BEGIN_LABEL)) {
                this.label = s1.replaceAll("[<>/]", "");
                beginLabel = true;
            }

            if (s1.matches(END_LABEL)) {
                endLabel = true;
                this.label = s1.replaceAll("[<>/]", "");
            }
        }
        this.median = median;
    }

    @Override
    public String toString() {
        return words.toString();
    }

    public List<String> getWords() {
        return words;
    }

    public void setWords(List<String> words) {
        this.words = words;
    }

    public String getWord(int i) {
        return words.get(i);
    }

    public String getFocusWord() {
        return words.get(median);
    }

    public boolean isBeginLabel() {
        return !label.equals("NONE") && beginLabel;
    }

    public boolean isEndLabel() {
        return !label.equals("NONE") && endLabel;
    }

    public String getLabel() {
        return label.replace("/", "");
    }

    public int getWindowSize() {
        return words.size();
    }

    public int getMedian() {
        return median;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    public int getBegin() {
        return begin;
    }

    public void setBegin(int begin) {
        this.begin = begin;
    }

    public int getEnd() {
        return end;
    }

    public void setEnd(int end) {
        this.end = end;
    }

}
@@ -1,188 +0,0 @@
package org.datavec.nlp.movingwindow;

import org.apache.commons.lang3.StringUtils;
import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

public class Windows {

    /**
     * Constructs a list of windows of size windowSize.
     * Note that padding for each window is created as well.
     * @param words the words to tokenize and construct windows from
     * @param windowSize the window size to generate
     * @return the list of windows for the tokenized string
     */
    public static List<Window> windows(InputStream words, int windowSize) {
        Tokenizer tokenizer = new DefaultStreamTokenizer(words);
        List<String> list = new ArrayList<>();
        while (tokenizer.hasMoreTokens())
            list.add(tokenizer.nextToken());
        return windows(list, windowSize);
    }

    /**
     * Constructs a list of windows of size windowSize.
     * Note that padding for each window is created as well.
     * @param words the words to tokenize and construct windows from
     * @param tokenizerFactory tokenizer factory to use
     * @param windowSize the window size to generate
     * @return the list of windows for the tokenized string
     */
    public static List<Window> windows(InputStream words, TokenizerFactory tokenizerFactory, int windowSize) {
        Tokenizer tokenizer = tokenizerFactory.create(words);
        List<String> list = new ArrayList<>();
        while (tokenizer.hasMoreTokens())
            list.add(tokenizer.nextToken());

        if (list.isEmpty())
            throw new IllegalStateException("No tokens found for windows");

        return windows(list, windowSize);
    }

    /**
     * Constructs a list of windows of size windowSize.
     * Note that padding for each window is created as well.
     * @param words the words to tokenize and construct windows from
     * @param windowSize the window size to generate
     * @return the list of windows for the tokenized string
     */
    public static List<Window> windows(String words, int windowSize) {
        StringTokenizer tokenizer = new StringTokenizer(words);
        List<String> list = new ArrayList<>();
        while (tokenizer.hasMoreTokens())
            list.add(tokenizer.nextToken());
        return windows(list, windowSize);
    }

    /**
     * Constructs a list of windows of size windowSize.
     * Note that padding for each window is created as well.
     * @param words the words to tokenize and construct windows from
     * @param tokenizerFactory tokenizer factory to use
     * @param windowSize the window size to generate
     * @return the list of windows for the tokenized string
     */
    public static List<Window> windows(String words, TokenizerFactory tokenizerFactory, int windowSize) {
        Tokenizer tokenizer = tokenizerFactory.create(words);
        List<String> list = new ArrayList<>();
        while (tokenizer.hasMoreTokens())
            list.add(tokenizer.nextToken());

        if (list.isEmpty())
            throw new IllegalStateException("No tokens found for windows");

        return windows(list, windowSize);
    }

    /**
     * Constructs a list of windows with the default window size of 5.
     * Note that padding for each window is created as well.
     * @param words the words to tokenize and construct windows from
     * @return the list of windows for the tokenized string
     */
    public static List<Window> windows(String words) {
        StringTokenizer tokenizer = new StringTokenizer(words);
        List<String> list = new ArrayList<>();
        while (tokenizer.hasMoreTokens())
            list.add(tokenizer.nextToken());
        return windows(list, 5);
    }

    /**
     * Constructs a list of windows with the default window size of 5.
     * Note that padding for each window is created as well.
     * @param words the words to tokenize and construct windows from
     * @param tokenizerFactory tokenizer factory to use
     * @return the list of windows for the tokenized string
     */
    public static List<Window> windows(String words, TokenizerFactory tokenizerFactory) {
        Tokenizer tokenizer = tokenizerFactory.create(words);
        List<String> list = new ArrayList<>();
        while (tokenizer.hasMoreTokens())
            list.add(tokenizer.nextToken());
        return windows(list, 5);
    }

    /**
     * Creates a sliding window centered on the word at the given position
     * @param windowSize the window size to use
     * @param wordPos the position of the word to center
     * @param sentence the sentence to create a window for
     * @return a window based on the given sentence
     */
    public static Window windowForWordInPosition(int windowSize, int wordPos, List<String> sentence) {
        List<String> window = new ArrayList<>();
        List<String> onlyTokens = new ArrayList<>();
        int contextSize = (windowSize - 1) / 2;

        for (int i = wordPos - contextSize; i <= wordPos + contextSize; i++) {
            if (i < 0)
                window.add("<s>");
            else if (i >= sentence.size())
                window.add("</s>");
            else {
                onlyTokens.add(sentence.get(i));
                window.add(sentence.get(i));
            }
        }

        //join with spaces so begin/end are offsets into the space separated sentence
        String wholeSentence = StringUtils.join(sentence, " ");
        String window2 = StringUtils.join(onlyTokens, " ");
        int begin = wholeSentence.indexOf(window2);
        int end = begin + window2.length();
        return new Window(window, begin, end);
    }

    /**
     * Constructs a list of windows of size windowSize
     * @param words the words to construct windows from
     * @param windowSize the window size to generate
     * @return the list of windows for the tokenized string
     */
    public static List<Window> windows(List<String> words, int windowSize) {
        List<Window> ret = new ArrayList<>();

        for (int i = 0; i < words.size(); i++)
            ret.add(windowForWordInPosition(windowSize, i, words));

        return ret;
    }

}
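
A minimal sketch of the sliding-window behavior, assuming the default window size of 5 and the sentinel padding tokens above:

    List<Window> windows = Windows.windows("the quick brown fox jumps");
    Window first = windows.get(0);        // [<s>, <s>, the, quick, brown]
    String focus = first.getFocusWord();  // "the"
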
@@ -1,189 +0,0 @@
package org.datavec.nlp.reader;

import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.impl.FileRecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.vector.Vectorizer;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.vectorizer.TfidfVectorizer;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.IOException;
import java.util.*;

public class TfidfRecordReader extends FileRecordReader {
    private TfidfVectorizer tfidfVectorizer;
    private List<Record> records = new ArrayList<>();
    private Iterator<Record> recordIter;
    private int numFeatures;
    private boolean initialized = false;

    @Override
    public void initialize(InputSplit split) throws IOException, InterruptedException {
        initialize(new Configuration(), split);
    }

    @Override
    public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
        super.initialize(conf, split);
        //train a new vectorizer since none has been specified
        if (tfidfVectorizer == null) {
            tfidfVectorizer = new TfidfVectorizer();
            tfidfVectorizer.initialize(conf);

            //clear out old records
            records.clear();

            INDArray ret = tfidfVectorizer.fitTransform(this, new Vectorizer.RecordCallBack() {
                @Override
                public void onRecord(Record fullRecord) {
                    records.add(fullRecord);
                }
            });

            //cache the number of features used for each document
            numFeatures = ret.columns();
            recordIter = records.iterator();
        } else {
            records = new ArrayList<>();

            //the record reader has 2 phases; we are skipping the
            //document frequency phase and just using super() to get the file contents,
            //which are passed to the already existing vectorizer.
            while (super.hasNext()) {
                Record fileContents = super.nextRecord();
                INDArray transform = tfidfVectorizer.transform(fileContents);

                org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record(
                                new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))),
                                new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class));

                if (appendLabel)
                    record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1));

                records.add(record);
            }

            recordIter = records.iterator();
        }

        this.initialized = true;
    }

    @Override
    public void reset() {
        if (inputSplit == null)
            throw new UnsupportedOperationException("Cannot reset without first initializing");
        recordIter = records.iterator();
    }

    @Override
    public Record nextRecord() {
        if (recordIter == null)
            return super.nextRecord();
        return recordIter.next();
    }

    @Override
    public List<Writable> next() {
        return nextRecord().getRecord();
    }

    @Override
    public boolean hasNext() {
        //we aren't done vectorizing yet
        if (recordIter == null)
            return super.hasNext();
        return recordIter.hasNext();
    }

    @Override
    public void close() throws IOException {
        //no resources to release
    }

    @Override
    public void setConf(Configuration conf) {
        this.conf = conf;
    }

    @Override
    public Configuration getConf() {
        return conf;
    }

    public TfidfVectorizer getTfidfVectorizer() {
        return tfidfVectorizer;
    }

    public void setTfidfVectorizer(TfidfVectorizer tfidfVectorizer) {
        if (initialized) {
            throw new IllegalArgumentException(
                            "Setting the TfidfVectorizer after TfidfRecordReader initialization has no effect");
        }
        this.tfidfVectorizer = tfidfVectorizer;
    }

    public int getNumFeatures() {
        return numFeatures;
    }

    public void shuffle() {
        this.shuffle(new Random());
    }

    public void shuffle(Random random) {
        Collections.shuffle(this.records, random);
        this.reset();
    }

    @Override
    public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return loadFromMetaData(Collections.singletonList(recordMetaData)).get(0);
    }

    @Override
    public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
        List<Record> out = new ArrayList<>();

        for (Record fileContents : super.loadFromMetaData(recordMetaDatas)) {
            INDArray transform = tfidfVectorizer.transform(fileContents);

            org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record(
                            new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))),
                            new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class));

            if (appendLabel)
                record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1));
            out.add(record);
        }

        return out;
    }
}
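
A hypothetical usage sketch (the corpus path is illustrative; FileSplit and File are the usual org.datavec.api.split and java.io classes):

    // Vectorizes every file under a directory; each record carries one TF-IDF document vector.
    TfidfRecordReader reader = new TfidfRecordReader();
    reader.initialize(new Configuration(), new FileSplit(new File("/path/to/corpus")));
    while (reader.hasNext()) {
        List<Writable> record = reader.next(); // first writable: NDArrayWritable with the TF-IDF row
    }
    reader.close();
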
@@ -1,44 +0,0 @@
package org.datavec.nlp.stopwords;

import org.apache.commons.io.IOUtils;

import java.io.IOException;
import java.util.List;

public class StopWords {

    private static List<String> stopWords;

    @SuppressWarnings("unchecked")
    public static List<String> getStopWords() {

        try {
            if (stopWords == null)
                stopWords = IOUtils.readLines(StopWords.class.getResourceAsStream("/stopwords"));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return stopWords;
    }

}
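
A small illustrative pairing with Util; the result assumes the bundled /stopwords resource contains "the":

    List<String> stopWords = StopWords.getStopWords();
    boolean skip = Util.matchesAnyStopWord(stopWords, "The"); // true; the comparison is case-insensitive
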
@@ -1,125 +0,0 @@
package org.datavec.nlp.tokenization.tokenizer;

import opennlp.tools.tokenize.TokenizerME;
import opennlp.tools.tokenize.TokenizerModel;
import opennlp.tools.util.Span;
import opennlp.uima.tokenize.AbstractTokenizer;
import opennlp.uima.tokenize.TokenizerModelResource;
import opennlp.uima.util.AnnotatorUtil;
import opennlp.uima.util.UimaUtil;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.Feature;
import org.apache.uima.cas.TypeSystem;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.resource.ResourceAccessException;
import org.apache.uima.resource.ResourceInitializationException;

public class ConcurrentTokenizer extends AbstractTokenizer {

    /**
     * The OpenNLP tokenizer.
     */
    private TokenizerME tokenizer;

    private Feature probabilityFeature;

    @Override
    public synchronized void process(CAS cas) throws AnalysisEngineProcessException {
        super.process(cas);
    }

    /**
     * Initializes a new instance.
     *
     * Note: Use {@link #initialize(UimaContext)} to initialize
     * this instance. Do not use the constructor.
     */
    public ConcurrentTokenizer() {
        super("OpenNLP Tokenizer");

        // intentionally empty; all initialization happens in initialize(UimaContext)
    }

    /**
     * Initializes the current instance with the given context.
     *
     * Note: Do all initialization in this method, do not use the constructor.
     */
    public void initialize(UimaContext context) throws ResourceInitializationException {

        super.initialize(context);

        TokenizerModel model;

        try {
            TokenizerModelResource modelResource =
                            (TokenizerModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER);

            model = modelResource.getModel();
        } catch (ResourceAccessException e) {
            throw new ResourceInitializationException(e);
        }

        tokenizer = new TokenizerME(model);
    }

    /**
     * Initializes the type system.
     */
    public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException {

        super.typeSystemInit(typeSystem);

        probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(context, tokenType,
                        UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE);
    }

    @Override
    protected Span[] tokenize(CAS cas, AnnotationFS sentence) {
        return tokenizer.tokenizePos(sentence.getCoveredText());
    }

    @Override
    protected void postProcessAnnotations(Span[] tokens, AnnotationFS[] tokenAnnotations) {
        // attach probabilities only if the probability feature was requested
        if (probabilityFeature != null) {
            double[] tokenProbabilities = tokenizer.getTokenProbabilities();

            for (int i = 0; i < tokenAnnotations.length; i++) {
                tokenAnnotations[i].setDoubleValue(probabilityFeature, tokenProbabilities[i]);
            }
        }
    }

    /**
     * Releases allocated resources.
     */
    public void destroy() {
        // dereference model to allow garbage collection
        tokenizer = null;
    }
}
@@ -1,107 +0,0 @@
package org.datavec.nlp.tokenization.tokenizer;

import java.io.*;
import java.util.ArrayList;
import java.util.List;

/**
 * Tokenizer based on the {@link java.io.StreamTokenizer}
 * @author Adam Gibson
 */
public class DefaultStreamTokenizer implements Tokenizer {

    private StreamTokenizer streamTokenizer;
    private TokenPreProcess tokenPreProcess;

    public DefaultStreamTokenizer(InputStream is) {
        Reader r = new BufferedReader(new InputStreamReader(is));
        streamTokenizer = new StreamTokenizer(r);
    }

    @Override
    public boolean hasMoreTokens() {
        if (streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
            try {
                streamTokenizer.nextToken();
            } catch (IOException e1) {
                throw new RuntimeException(e1);
            }
        }
        return streamTokenizer.ttype != StreamTokenizer.TT_EOF && streamTokenizer.ttype != -1;
    }

    @Override
    public int countTokens() {
        return getTokens().size();
    }

    @Override
    public String nextToken() {
        StringBuilder sb = new StringBuilder();

        if (streamTokenizer.ttype == StreamTokenizer.TT_WORD) {
            sb.append(streamTokenizer.sval);
        } else if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
            sb.append(streamTokenizer.nval);
        } else if (streamTokenizer.ttype == StreamTokenizer.TT_EOL) {
            try {
                while (streamTokenizer.ttype == StreamTokenizer.TT_EOL)
                    streamTokenizer.nextToken();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        } else if (hasMoreTokens())
            return nextToken();

        String ret = sb.toString();

        if (tokenPreProcess != null)
            ret = tokenPreProcess.preProcess(ret);
        return ret;
    }

    @Override
    public List<String> getTokens() {
        List<String> tokens = new ArrayList<>();
        while (hasMoreTokens()) {
            tokens.add(nextToken());
        }
        return tokens;
    }

    @Override
    public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
        this.tokenPreProcess = tokenPreProcessor;
    }

}
@@ -1,75 +0,0 @@
package org.datavec.nlp.tokenization.tokenizer;

import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

/**
 * Default tokenizer
 * @author Adam Gibson
 */
public class DefaultTokenizer implements Tokenizer {

    private StringTokenizer tokenizer;
    private TokenPreProcess tokenPreProcess;

    public DefaultTokenizer(String tokens) {
        tokenizer = new StringTokenizer(tokens);
    }

    @Override
    public boolean hasMoreTokens() {
        return tokenizer.hasMoreTokens();
    }

    @Override
    public int countTokens() {
        return tokenizer.countTokens();
    }

    @Override
    public String nextToken() {
        String base = tokenizer.nextToken();
        if (tokenPreProcess != null)
            base = tokenPreProcess.preProcess(base);
        return base;
    }

    @Override
    public List<String> getTokens() {
        List<String> tokens = new ArrayList<>();
        while (hasMoreTokens()) {
            tokens.add(nextToken());
        }
        return tokens;
    }

    @Override
    public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
        this.tokenPreProcess = tokenPreProcessor;
    }

}
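
A brief sketch combining the tokenizer with a pre processor (the input string is illustrative):

    Tokenizer tokenizer = new DefaultTokenizer("The Quick Brown Fox");
    tokenizer.setTokenPreProcessor(new LowerCasePreProcessor());
    List<String> tokens = tokenizer.getTokens(); // [the, quick, brown, fox]
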
@@ -1,136 +0,0 @@
package org.datavec.nlp.tokenization.tokenizer;

import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.cas.CAS;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.util.JCasUtil;
import org.cleartk.token.type.Sentence;
import org.cleartk.token.type.Token;
import org.datavec.nlp.annotator.PoStagger;
import org.datavec.nlp.annotator.SentenceAnnotator;
import org.datavec.nlp.annotator.StemmerAnnotator;
import org.datavec.nlp.annotator.TokenizerAnnotator;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

public class PosUimaTokenizer implements Tokenizer {

    private static AnalysisEngine engine;
    private List<String> tokens;
    private Collection<String> allowedPosTags;
    private int index;
    private static CAS cas;

    public PosUimaTokenizer(String tokens, AnalysisEngine engine, Collection<String> allowedPosTags) {
        //cache the analysis engine on first use
        if (PosUimaTokenizer.engine == null)
            PosUimaTokenizer.engine = engine;
        this.allowedPosTags = allowedPosTags;
        this.tokens = new ArrayList<>();
        try {
            if (cas == null)
                cas = engine.newCAS();

            cas.reset();
            cas.setDocumentText(tokens);
            PosUimaTokenizer.engine.process(cas);
            for (Sentence s : JCasUtil.select(cas.getJCas(), Sentence.class)) {
                for (Token t : JCasUtil.selectCovered(Token.class, s)) {
                    //add NONE for each invalid token
                    if (valid(t))
                        if (t.getLemma() != null)
                            this.tokens.add(t.getLemma());
                        else if (t.getStem() != null)
                            this.tokens.add(t.getStem());
                        else
                            this.tokens.add(t.getCoveredText());
                    else
                        this.tokens.add("NONE");
                }
            }

        } catch (Exception e) {
            throw new RuntimeException(e);
        }

    }

    private boolean valid(Token token) {
        String check = token.getCoveredText();
        if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>"))
            return false;
        else if (token.getPos() != null && !this.allowedPosTags.contains(token.getPos()))
            return false;
        return true;
    }

    @Override
    public boolean hasMoreTokens() {
        return index < tokens.size();
    }

    @Override
    public int countTokens() {
        return tokens.size();
    }

    @Override
    public String nextToken() {
        String ret = tokens.get(index);
        index++;
        return ret;
    }

    @Override
    public List<String> getTokens() {
        List<String> tokens = new ArrayList<>();
        while (hasMoreTokens()) {
            tokens.add(nextToken());
        }
        return tokens;
    }

    public static AnalysisEngine defaultAnalysisEngine() {
        try {
            return AnalysisEngineFactory.createEngine(AnalysisEngineFactory.createEngineDescription(
                            SentenceAnnotator.getDescription(), TokenizerAnnotator.getDescription(),
                            PoStagger.getDescription("en"), StemmerAnnotator.getDescription("English")));
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
        //token pre processing is not supported by this tokenizer
    }

}
@@ -1,34 +0,0 @@
package org.datavec.nlp.tokenization.tokenizer;

public interface TokenPreProcess {

    /**
     * Pre process a token
     * @param token the token to pre process
     * @return the preprocessed token
     */
    String preProcess(String token);

}
@@ -1,61 +0,0 @@
package org.datavec.nlp.tokenization.tokenizer;

import java.util.List;

public interface Tokenizer {

    /**
     * Checks whether the tokenizer has
     * any tokens left to iterate over
     * @return whether there are any more tokens
     * to iterate over
     */
    boolean hasMoreTokens();

    /**
     * The number of tokens in the tokenizer
     * @return the number of tokens
     */
    int countTokens();

    /**
     * The next token (usually a word) in the string
     * @return the next token in the string, if any
     */
    String nextToken();

    /**
     * Returns a list of all the tokens
     * @return a list of all the tokens
     */
    List<String> getTokens();

    /**
     * Set the token pre processor
     * @param tokenPreProcessor the token pre processor to set
     */
    void setTokenPreProcessor(TokenPreProcess tokenPreProcessor);

}
@@ -1,123 +0,0 @@
package org.datavec.nlp.tokenization.tokenizer;

import org.apache.uima.cas.CAS;
import org.apache.uima.fit.util.JCasUtil;
import org.cleartk.token.type.Token;
import org.datavec.nlp.uima.UimaResource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
 * Tokenizer based on the passed in analysis engine
 * @author Adam Gibson
 */
public class UimaTokenizer implements Tokenizer {

    private List<String> tokens;
    private int index;
    private static Logger log = LoggerFactory.getLogger(UimaTokenizer.class);
    private boolean checkForLabel;
    private TokenPreProcess tokenPreProcessor;

    public UimaTokenizer(String tokens, UimaResource resource, boolean checkForLabel) {

        this.checkForLabel = checkForLabel;
        this.tokens = new ArrayList<>();
        try {
            CAS cas = resource.process(tokens);

            Collection<Token> tokenList = JCasUtil.select(cas.getJCas(), Token.class);

            for (Token t : tokenList) {

                if (!checkForLabel || valid(t.getCoveredText()))
                    if (t.getLemma() != null)
                        this.tokens.add(t.getLemma());
                    else if (t.getStem() != null)
                        this.tokens.add(t.getStem());
                    else
                        this.tokens.add(t.getCoveredText());
            }

            resource.release(cas);

        } catch (Exception e) {
            log.error("Unable to tokenize with the UIMA resource", e);
            throw new RuntimeException(e);
        }

    }

    private boolean valid(String check) {
        if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>"))
            return false;
        return true;
    }

    @Override
    public boolean hasMoreTokens() {
        return index < tokens.size();
    }

    @Override
    public int countTokens() {
        return tokens.size();
    }

    @Override
    public String nextToken() {
        String ret = tokens.get(index);
        index++;
        if (tokenPreProcessor != null) {
            ret = tokenPreProcessor.preProcess(ret);
        }
        return ret;
    }

    @Override
    public List<String> getTokens() {
        List<String> tokens = new ArrayList<>();
        while (hasMoreTokens()) {
            tokens.add(nextToken());
        }
        return tokens;
    }

    @Override
    public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
        this.tokenPreProcessor = tokenPreProcessor;
    }

}
@@ -1,47 +0,0 @@
package org.datavec.nlp.tokenization.tokenizer.preprocessor;

import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;

/**
 * Strips common English endings from a token:
 * ed, ing, ly, s, and a trailing period.
 * @author Adam Gibson
 */
public class EndingPreProcessor implements TokenPreProcess {
    @Override
    public String preProcess(String token) {
        if (token.endsWith("s") && !token.endsWith("ss"))
            token = token.substring(0, token.length() - 1);
        if (token.endsWith("."))
            token = token.substring(0, token.length() - 1);
        if (token.endsWith("ed"))
            token = token.substring(0, token.length() - 2);
        if (token.endsWith("ing"))
            token = token.substring(0, token.length() - 3);
        if (token.endsWith("ly"))
            token = token.substring(0, token.length() - 2);
        return token;
    }
}
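
An illustrative example; each rule strips at most one ending, applied in the order above:

    TokenPreProcess endings = new EndingPreProcessor();
    String run = endings.preProcess("runs");      // "run"
    String quick = endings.preProcess("quickly"); // "quick"
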
@@ -1,30 +0,0 @@
package org.datavec.nlp.tokenization.tokenizer.preprocessor;

import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;

public class LowerCasePreProcessor implements TokenPreProcess {
    @Override
    public String preProcess(String token) {
        return token.toLowerCase();
    }
}
@@ -1,60 +0,0 @@
package org.datavec.nlp.tokenization.tokenizerfactory;

import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer;
import org.datavec.nlp.tokenization.tokenizer.DefaultTokenizer;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;

import java.io.InputStream;

/**
 * Default tokenizer factory based on the string tokenizer or stream tokenizer
 * @author Adam Gibson
 */
public class DefaultTokenizerFactory implements TokenizerFactory {

    private TokenPreProcess tokenPreProcess;

    @Override
    public Tokenizer create(String toTokenize) {
        DefaultTokenizer t = new DefaultTokenizer(toTokenize);
        t.setTokenPreProcessor(tokenPreProcess);
        return t;
    }

    @Override
    public Tokenizer create(InputStream toTokenize) {
        Tokenizer t = new DefaultStreamTokenizer(toTokenize);
        t.setTokenPreProcessor(tokenPreProcess);
        return t;
    }

    @Override
    public void setTokenPreProcessor(TokenPreProcess preProcessor) {
        this.tokenPreProcess = preProcessor;
    }

}
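
A short sketch of the factory pattern used throughout this package: set a pre processor once and every tokenizer the factory creates applies it.

    TokenizerFactory factory = new DefaultTokenizerFactory();
    factory.setTokenPreProcessor(new LowerCasePreProcessor());
    Tokenizer t = factory.create("Some Input Text"); // tokens come back lower cased
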
@@ -1,85 +0,0 @@
package org.datavec.nlp.tokenization.tokenizerfactory;
|
|
||||||
|
|
||||||
|
|
||||||
import org.apache.uima.analysis_engine.AnalysisEngine;
|
|
||||||
import org.datavec.nlp.annotator.PoStagger;
|
|
||||||
import org.datavec.nlp.annotator.SentenceAnnotator;
|
|
||||||
import org.datavec.nlp.annotator.StemmerAnnotator;
|
|
||||||
import org.datavec.nlp.annotator.TokenizerAnnotator;
|
|
||||||
import org.datavec.nlp.tokenization.tokenizer.PosUimaTokenizer;
|
|
||||||
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
|
|
||||||
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
|
|
||||||
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.util.Collection;
|
|
||||||
|
|
||||||
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngine;
|
|
||||||
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription;
|
|
||||||
|
|
||||||
public class PosUimaTokenizerFactory implements TokenizerFactory {
|
|
||||||
|
|
||||||
private AnalysisEngine tokenizer;
|
|
||||||
private Collection<String> allowedPoSTags;
|
|
||||||
private TokenPreProcess tokenPreProcess;
|
|
||||||
|
|
||||||
|
|
||||||
public PosUimaTokenizerFactory(Collection<String> allowedPoSTags) {
|
|
||||||
this(defaultAnalysisEngine(), allowedPoSTags);
|
|
||||||
}
|
|
||||||
|
|
||||||
public PosUimaTokenizerFactory(AnalysisEngine tokenizer, Collection<String> allowedPosTags) {
|
|
||||||
this.tokenizer = tokenizer;
|
|
||||||
this.allowedPoSTags = allowedPosTags;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public static AnalysisEngine defaultAnalysisEngine() {
|
|
||||||
try {
|
|
||||||
return createEngine(createEngineDescription(SentenceAnnotator.getDescription(),
|
|
||||||
TokenizerAnnotator.getDescription(), PoStagger.getDescription("en"),
|
|
||||||
StemmerAnnotator.getDescription("English")));
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Tokenizer create(String toTokenize) {
|
|
||||||
PosUimaTokenizer t = new PosUimaTokenizer(toTokenize, tokenizer, allowedPoSTags);
|
|
||||||
t.setTokenPreProcessor(tokenPreProcess);
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Tokenizer create(InputStream toTokenize) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setTokenPreProcessor(TokenPreProcess preProcessor) {
|
|
||||||
this.tokenPreProcess = preProcessor;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
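// Illustrative sketch (added for this removal log; not part of the removed
// sources): keeping only tokens whose part-of-speech tag is in the allowed
// set. "NN" and "NNS" are Penn Treebank noun tags, an assumption about the
// tag set produced by the "en" POS model loaded above.
class PosUimaTokenizerFactoryExample {
    public static void main(String[] args) {
        PosUimaTokenizerFactory posFactory =
                new PosUimaTokenizerFactory(java.util.Arrays.asList("NN", "NNS"));
        Tokenizer nounsOnly = posFactory.create("The quick brown fox jumps over the lazy dog");
        while (nounsOnly.hasMoreTokens()) {
            System.out.println(nounsOnly.nextToken()); // expected: the nouns only
        }
    }
}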
@@ -1,62 +0,0 @@
/* Apache License, Version 2.0 banner (identical to the headers above); SPDX-License-Identifier: Apache-2.0 */

package org.datavec.nlp.tokenization.tokenizerfactory;

import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

import java.io.InputStream;

/**
 * Generates a tokenizer for a given string.
 * @author Adam Gibson
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface TokenizerFactory {

    /**
     * Create a tokenizer for the given string.
     * @param toTokenize the string to create the tokenizer with
     * @return the new tokenizer
     */
    Tokenizer create(String toTokenize);

    /**
     * Create a tokenizer based on an input stream.
     * @param toTokenize the input stream to read text from
     * @return the new tokenizer
     */
    Tokenizer create(InputStream toTokenize);

    /**
     * Sets a token pre processor to be used with every tokenizer.
     * @param preProcessor the token pre processor to use
     */
    void setTokenPreProcessor(TokenPreProcess preProcessor);
}
@@ -1,138 +0,0 @@
/* Apache License, Version 2.0 banner (identical to the headers above); SPDX-License-Identifier: Apache-2.0 */

package org.datavec.nlp.tokenization.tokenizerfactory;

import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.resource.ResourceInitializationException;
import org.datavec.nlp.annotator.SentenceAnnotator;
import org.datavec.nlp.annotator.TokenizerAnnotator;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizer.UimaTokenizer;
import org.datavec.nlp.uima.UimaResource;

import java.io.InputStream;

/**
 * Uses a UIMA {@link AnalysisEngine} to tokenize text.
 * @author Adam Gibson
 */
public class UimaTokenizerFactory implements TokenizerFactory {

    private UimaResource uimaResource;
    private boolean checkForLabel;
    private static AnalysisEngine defaultAnalysisEngine;
    private TokenPreProcess preProcess;

    public UimaTokenizerFactory() throws ResourceInitializationException {
        this(defaultAnalysisEngine(), true);
    }

    public UimaTokenizerFactory(UimaResource resource) {
        this(resource, true);
    }

    public UimaTokenizerFactory(AnalysisEngine tokenizer) {
        this(tokenizer, true);
    }

    public UimaTokenizerFactory(UimaResource resource, boolean checkForLabel) {
        this.uimaResource = resource;
        this.checkForLabel = checkForLabel;
    }

    public UimaTokenizerFactory(boolean checkForLabel) throws ResourceInitializationException {
        this(defaultAnalysisEngine(), checkForLabel);
    }

    public UimaTokenizerFactory(AnalysisEngine tokenizer, boolean checkForLabel) {
        super();
        this.checkForLabel = checkForLabel;
        try {
            this.uimaResource = new UimaResource(tokenizer);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public Tokenizer create(String toTokenize) {
        if (toTokenize == null || toTokenize.isEmpty())
            throw new IllegalArgumentException("Unable to proceed: no sentence to tokenize");
        Tokenizer ret = new UimaTokenizer(toTokenize, uimaResource, checkForLabel);
        ret.setTokenPreProcessor(preProcess);
        return ret;
    }

    public UimaResource getUimaResource() {
        return uimaResource;
    }

    /**
     * Creates the default sentence segmentation and tokenization pipeline.
     * @return the shared default analysis engine
     */
    public static AnalysisEngine defaultAnalysisEngine() {
        try {
            if (defaultAnalysisEngine == null)
                defaultAnalysisEngine = AnalysisEngineFactory.createEngine(
                                AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.getDescription(),
                                                TokenizerAnnotator.getDescription()));

            return defaultAnalysisEngine;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public Tokenizer create(InputStream toTokenize) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void setTokenPreProcessor(TokenPreProcess preProcessor) {
        this.preProcess = preProcessor;
    }
}
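// Illustrative sketch (added for this removal log; not part of the removed
// sources): UIMA-backed tokenization through the default sentence + token
// pipeline above. getTokens() is the Tokenizer accessor used elsewhere in
// this diff.
class UimaTokenizerFactoryExample {
    public static void main(String[] args) throws Exception { // ResourceInitializationException
        UimaTokenizerFactory uimaFactory = new UimaTokenizerFactory();
        Tokenizer tokenizer = uimaFactory.create("One sentence. Another sentence.");
        System.out.println(tokenizer.getTokens());
    }
}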
@@ -1,69 +0,0 @@
/* Apache License, Version 2.0 banner (identical to the headers above); SPDX-License-Identifier: Apache-2.0 */

package org.datavec.nlp.transforms;

import org.datavec.api.transform.Transform;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

import java.util.List;

@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface BagOfWordsTransform extends Transform {

    /**
     * The output shape of the transform (usually 1 x number of words).
     * @return the output shape
     */
    long[] outputShape();

    /**
     * The vocabulary words in the transform: the words accumulated when the
     * vocabulary was built, generally after some form of minimum word
     * frequency scan, mapped on to a list.
     * @return the vocab words for the transform
     */
    List<String> vocabWords();

    /**
     * Transform for a list of tokens that are plain objects. This allows
     * loose typing for tokens that are not strings.
     * @param tokens the token objects to transform
     * @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array)
     */
    INDArray transformFromObject(List<List<Object>> tokens);

    /**
     * Transform for a list of tokens that are {@link Writable} (generally {@link org.datavec.api.writable.Text}).
     * @param tokens the token writables to transform
     * @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array)
     */
    INDArray transformFrom(List<List<Writable>> tokens);
}
@@ -1,24 +0,0 @@
/* Apache License, Version 2.0 banner (identical to the headers above); SPDX-License-Identifier: Apache-2.0 */

package org.datavec.nlp.transforms;

public class BaseWordMapTransform {
}
@@ -1,146 +0,0 @@
/* Apache License, Version 2.0 banner (identical to the headers above); SPDX-License-Identifier: Apache-2.0 */

package org.datavec.nlp.transforms;

import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

@Data
@EqualsAndHashCode(callSuper = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonIgnoreProperties({"gazeteer"})
public class GazeteerTransform extends BaseColumnTransform implements BagOfWordsTransform {

    private String newColumnName;
    private List<String> wordList;
    private Set<String> gazeteer;

    @JsonCreator
    public GazeteerTransform(@JsonProperty("columnName") String columnName,
                    @JsonProperty("newColumnName") String newColumnName,
                    @JsonProperty("wordList") List<String> wordList) {
        super(columnName);
        this.newColumnName = newColumnName;
        this.wordList = wordList;
        this.gazeteer = new HashSet<>(wordList);
    }

    @Override
    public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
        return new NDArrayMetaData(newName, new long[]{wordList.size()});
    }

    @Override
    public Writable map(Writable columnWritable) {
        throw new UnsupportedOperationException();
    }

    @Override
    public Object mapSequence(Object sequence) {
        List<List<Object>> sequenceInput = (List<List<Object>>) sequence;
        INDArray ret = Nd4j.create(DataType.FLOAT, wordList.size());

        for (List<Object> list : sequenceInput) {
            for (Object token : list) {
                String s = token.toString();
                if (gazeteer.contains(s)) {
                    ret.putScalar(wordList.indexOf(s), 1);
                }
            }
        }
        return ret;
    }

    @Override
    public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
        INDArray arr = (INDArray) mapSequence((Object) sequence);
        return Collections.singletonList(Collections.<Writable>singletonList(new NDArrayWritable(arr)));
    }

    @Override
    public String toString() {
        return newColumnName;
    }

    @Override
    public Object map(Object input) {
        return gazeteer.contains(input.toString());
    }

    @Override
    public String outputColumnName() {
        return newColumnName;
    }

    @Override
    public String[] outputColumnNames() {
        return new String[]{newColumnName};
    }

    @Override
    public String[] columnNames() {
        return new String[]{columnName()};
    }

    @Override
    public String columnName() {
        return columnName;
    }

    @Override
    public long[] outputShape() {
        return new long[]{wordList.size()};
    }

    @Override
    public List<String> vocabWords() {
        return wordList;
    }

    @Override
    public INDArray transformFromObject(List<List<Object>> tokens) {
        return (INDArray) mapSequence(tokens);
    }

    @Override
    public INDArray transformFrom(List<List<Writable>> tokens) {
        return (INDArray) mapSequence((Object) tokens);
    }
}
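// Illustrative sketch (added for this removal log; not part of the removed
// sources): GazeteerTransform as a bag-of-words transform. The word list and
// token sequence are made-up values; the expected output follows from
// mapSequence above (binary hits, duplicates not double-counted).
class GazeteerTransformExample {
    public static void main(String[] args) {
        GazeteerTransform gazeteer = new GazeteerTransform("text", "flags",
                java.util.Arrays.asList("rain", "sun", "snow"));
        java.util.List<java.util.List<Object>> tokens = java.util.Arrays.asList(
                java.util.Arrays.<Object>asList("rain", "rain", "fog"));
        INDArray flags = gazeteer.transformFromObject(tokens);
        System.out.println(flags); // [1.0, 0.0, 0.0]
    }
}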
@@ -1,150 +0,0 @@
/* Apache License, Version 2.0 banner (identical to the headers above); SPDX-License-Identifier: Apache-2.0 */

package org.datavec.nlp.transforms;

import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.list.NDArrayList;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.Collections;
import java.util.List;

public class MultiNlpTransform extends BaseColumnTransform implements BagOfWordsTransform {

    private BagOfWordsTransform[] transforms;
    private String newColumnName;
    private List<String> vocabWords;

    /**
     * @param columnName the input column to transform
     * @param transforms the bag-of-words transforms to concatenate; all must share the same vocabulary
     * @param newColumnName the name of the output column
     */
    @JsonCreator
    public MultiNlpTransform(@JsonProperty("columnName") String columnName,
                    @JsonProperty("transforms") BagOfWordsTransform[] transforms,
                    @JsonProperty("newColumnName") String newColumnName) {
        super(columnName);
        this.transforms = transforms;
        this.vocabWords = transforms[0].vocabWords();
        if (transforms.length > 1) {
            for (int i = 1; i < transforms.length; i++) {
                if (!transforms[i].vocabWords().equals(vocabWords)) {
                    throw new IllegalArgumentException("Vocab words not consistent across transforms!");
                }
            }
        }

        this.newColumnName = newColumnName;
    }

    @Override
    public Object mapSequence(Object sequence) {
        NDArrayList ndArrayList = new NDArrayList();
        for (BagOfWordsTransform bagOfWordsTransform : transforms) {
            ndArrayList.addAll(new NDArrayList(bagOfWordsTransform.transformFromObject((List<List<Object>>) sequence)));
        }

        return ndArrayList.array();
    }

    @Override
    public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
        return Collections.singletonList(Collections.<Writable>singletonList(new NDArrayWritable(transformFrom(sequence))));
    }

    @Override
    public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
        return new NDArrayMetaData(newName, outputShape());
    }

    @Override
    public Writable map(Writable columnWritable) {
        throw new UnsupportedOperationException("Only able to add for time series");
    }

    @Override
    public String toString() {
        return newColumnName;
    }

    @Override
    public Object map(Object input) {
        throw new UnsupportedOperationException("Only able to add for time series");
    }

    @Override
    public long[] outputShape() {
        long[] ret = new long[transforms[0].outputShape().length];
        int validatedRank = transforms[0].outputShape().length;
        for (int i = 1; i < transforms.length; i++) {
            if (transforms[i].outputShape().length != validatedRank) {
                throw new IllegalArgumentException("Inconsistent shape length at transform " + i + ", should have been: " + validatedRank);
            }
        }
        for (int i = 0; i < transforms.length; i++) {
            for (int j = 0; j < validatedRank; j++)
                ret[j] += transforms[i].outputShape()[j];
        }

        return ret;
    }

    @Override
    public List<String> vocabWords() {
        return vocabWords;
    }

    @Override
    public INDArray transformFromObject(List<List<Object>> tokens) {
        NDArrayList ndArrayList = new NDArrayList();
        for (BagOfWordsTransform bagOfWordsTransform : transforms) {
            INDArray arr2 = bagOfWordsTransform.transformFromObject(tokens);
            arr2 = arr2.reshape(arr2.length());
            ndArrayList.addAll(new NDArrayList(arr2, (int) arr2.length()));
        }

        return ndArrayList.array();
    }

    @Override
    public INDArray transformFrom(List<List<Writable>> tokens) {
        NDArrayList ndArrayList = new NDArrayList();
        for (BagOfWordsTransform bagOfWordsTransform : transforms) {
            INDArray arr2 = bagOfWordsTransform.transformFrom(tokens);
            arr2 = arr2.reshape(arr2.length());
            ndArrayList.addAll(new NDArrayList(arr2, (int) arr2.length()));
        }

        return ndArrayList.array();
    }
}
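// Illustrative sketch (added for this removal log; not part of the removed
// sources): concatenating two bag-of-words transforms that share one
// vocabulary. With two transforms of output shape [3], outputShape() is [6]
// and the per-transform vectors are concatenated. gazeteerA and gazeteerB are
// hypothetical instances built over the same word list.
class MultiNlpTransformExample {
    static INDArray combine(BagOfWordsTransform gazeteerA, BagOfWordsTransform gazeteerB,
                    java.util.List<java.util.List<Object>> tokens) {
        MultiNlpTransform multi = new MultiNlpTransform("text",
                new BagOfWordsTransform[] {gazeteerA, gazeteerB}, "textFeatures");
        return multi.transformFromObject(tokens); // length = 3 + 3
    }
}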
@@ -1,226 +0,0 @@
/* Apache License, Version 2.0 banner (identical to the headers above); SPDX-License-Identifier: Apache-2.0 */

package org.datavec.nlp.transforms;

import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Counter;
import org.nd4j.common.util.MathUtils;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

@Data
@EqualsAndHashCode(callSuper = true, exclude = {"tokenizerFactory"})
@JsonInclude(JsonInclude.Include.NON_NULL)
public class TokenizerBagOfWordsTermSequenceIndexTransform extends BaseColumnTransform {

    private String newColumnName;
    private Map<String, Integer> wordIndexMap;
    private Map<String, Double> weightMap;
    private boolean exceptionOnUnknown;
    private String tokenizerFactoryClass;
    private String preprocessorClass;
    private TokenizerFactory tokenizerFactory;

    @JsonCreator
    public TokenizerBagOfWordsTermSequenceIndexTransform(@JsonProperty("columnName") String columnName,
                    @JsonProperty("newColumnName") String newColumnName,
                    @JsonProperty("wordIndexMap") Map<String, Integer> wordIndexMap,
                    @JsonProperty("idfMap") Map<String, Double> idfMap,
                    @JsonProperty("exceptionOnUnknown") boolean exceptionOnUnknown,
                    @JsonProperty("tokenizerFactoryClass") String tokenizerFactoryClass,
                    @JsonProperty("preprocessorClass") String preprocessorClass) {
        super(columnName);
        this.newColumnName = newColumnName;
        this.wordIndexMap = wordIndexMap;
        this.exceptionOnUnknown = exceptionOnUnknown;
        this.weightMap = idfMap;
        this.tokenizerFactoryClass = tokenizerFactoryClass;
        this.preprocessorClass = preprocessorClass;
        if (this.tokenizerFactoryClass == null) {
            this.tokenizerFactoryClass = DefaultTokenizerFactory.class.getName();
        }
        try {
            tokenizerFactory = (TokenizerFactory) Class.forName(this.tokenizerFactoryClass).newInstance();
        } catch (Exception e) {
            throw new IllegalStateException("Unable to instantiate tokenizer factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?");
        }

        if (preprocessorClass != null) {
            try {
                TokenPreProcess tpp = (TokenPreProcess) Class.forName(this.preprocessorClass).newInstance();
                tokenizerFactory.setTokenPreProcessor(tpp);
            } catch (Exception e) {
                throw new IllegalStateException("Unable to instantiate token preprocessor with empty constructor. Does the preprocessor class contain a default empty constructor?");
            }
        }

    }

    @Override
    public List<Writable> map(List<Writable> writables) {
        Text text = (Text) writables.get(inputSchema.getIndexOfColumn(columnName));
        List<Writable> ret = new ArrayList<>(writables);
        ret.set(inputSchema.getIndexOfColumn(columnName), new NDArrayWritable(convert(text.toString())));
        return ret;
    }

    @Override
    public Object map(Object input) {
        return convert(input.toString());
    }

    @Override
    public Object mapSequence(Object sequence) {
        return convert(sequence.toString());
    }

    @Override
    public Schema transform(Schema inputSchema) {
        Schema.Builder newSchema = new Schema.Builder();
        for (int i = 0; i < inputSchema.numColumns(); i++) {
            if (inputSchema.getName(i).equals(this.columnName)) {
                newSchema.addColumnNDArray(newColumnName, new long[]{1, wordIndexMap.size()});
            } else {
                newSchema.addColumn(inputSchema.getMetaData(i));
            }
        }

        return newSchema.build();
    }

    /**
     * Convert the given text in to an {@link INDArray} using the
     * {@link TokenizerFactory} specified in the constructor.
     * @param text the text to transform
     * @return the created {@link INDArray}, with column indices for each word
     * taken from {@link #wordIndexMap}
     */
    public INDArray convert(String text) {
        Tokenizer tokenizer = tokenizerFactory.create(text);
        List<String> tokens = tokenizer.getTokens();
        INDArray create = Nd4j.create(1, wordIndexMap.size());
        Counter<String> tokenizedCounter = new Counter<>();

        for (int i = 0; i < tokens.size(); i++) {
            tokenizedCounter.incrementCount(tokens.get(i), 1.0);
        }

        for (int i = 0; i < tokens.size(); i++) {
            if (wordIndexMap.containsKey(tokens.get(i))) {
                int idx = wordIndexMap.get(tokens.get(i));
                int count = (int) tokenizedCounter.getCount(tokens.get(i));
                double weight = tfidfWord(tokens.get(i), count, tokens.size());
                create.putScalar(idx, weight);
            }
        }

        return create;
    }

    /**
     * Calculate the tf-idf for a word given the word, word count, and document length.
     * @param word the word to calculate
     * @param wordCount the word frequency
     * @param documentLength the number of words in the document
     * @return the tf-idf weight for the given word
     */
    public double tfidfWord(String word, long wordCount, long documentLength) {
        double tf = tfForWord(wordCount, documentLength);
        double idf = idfForWord(word);
        return MathUtils.tfidf(tf, idf);
    }

    /**
     * Term frequency for a given word. Note that although the document length
     * is passed in, this implementation returns the raw count, unnormalized.
     * @param wordCount the word frequency
     * @param documentLength the number of words in the document
     * @return the term frequency
     */
    private double tfForWord(long wordCount, long documentLength) {
        return wordCount;
    }

    private double idfForWord(String word) {
        if (weightMap.containsKey(word))
            return weightMap.get(word);
        return 0;
    }

    @Override
    public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
        return new NDArrayMetaData(outputColumnName(), new long[]{1, wordIndexMap.size()});
    }

    @Override
    public String outputColumnName() {
        return newColumnName;
    }

    @Override
    public String[] outputColumnNames() {
        return new String[]{newColumnName};
    }

    @Override
    public String[] columnNames() {
        return new String[]{columnName()};
    }

    @Override
    public String columnName() {
        return columnName;
    }

    @Override
    public Writable map(Writable columnWritable) {
        return new NDArrayWritable(convert(columnWritable.toString()));
    }
}
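// Illustrative sketch (added for this removal log; not part of the removed
// sources): converting a text column into a 1 x |vocab| TF-IDF row. The maps
// hold made-up values; passing null for tokenizerFactoryClass falls back to
// DefaultTokenizerFactory, as in the constructor above.
class TfidfTransformExample {
    public static void main(String[] args) {
        java.util.Map<String, Integer> index = new java.util.HashMap<>();
        index.put("rain", 0);
        index.put("sun", 1);
        java.util.Map<String, Double> idf = new java.util.HashMap<>();
        idf.put("rain", 0.5);
        idf.put("sun", 1.2);
        TokenizerBagOfWordsTermSequenceIndexTransform tfidf =
                new TokenizerBagOfWordsTermSequenceIndexTransform(
                        "text", "textVector", index, idf, false, null, null);
        INDArray row = tfidf.convert("rain rain sun"); // shape [1, 2], weights via MathUtils.tfidf
        System.out.println(row);
    }
}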
@@ -1,107 +0,0 @@
/* Apache License, Version 2.0 banner (identical to the headers above); SPDX-License-Identifier: Apache-2.0 */

package org.datavec.nlp.uima;

import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.util.CasPool;

public class UimaResource {

    private AnalysisEngine analysisEngine;
    private CasPool casPool;

    public UimaResource(AnalysisEngine analysisEngine) throws ResourceInitializationException {
        this.analysisEngine = analysisEngine;
        this.casPool = new CasPool(Runtime.getRuntime().availableProcessors() * 10, analysisEngine);
    }

    public UimaResource(AnalysisEngine analysisEngine, CasPool casPool) {
        this.analysisEngine = analysisEngine;
        this.casPool = casPool;
    }

    public AnalysisEngine getAnalysisEngine() {
        return analysisEngine;
    }

    public void setAnalysisEngine(AnalysisEngine analysisEngine) {
        this.analysisEngine = analysisEngine;
    }

    public CasPool getCasPool() {
        return casPool;
    }

    public void setCasPool(CasPool casPool) {
        this.casPool = casPool;
    }

    /**
     * Use the given analysis engine to process the given text.
     * The caller must release the returned CAS via {@link #release(CAS)}.
     * @param text the text to process
     * @return the processed CAS
     */
    public CAS process(String text) {
        CAS cas = retrieve();

        cas.setDocumentText(text);
        try {
            analysisEngine.process(cas);
        } catch (AnalysisEngineProcessException e) {
            // Release the failed CAS back to the pool before surfacing the
            // error, rather than retrying indefinitely.
            release(cas);
            throw new RuntimeException(e);
        }

        return cas;
    }

    public CAS retrieve() {
        CAS ret = casPool.getCas();
        try {
            return ret == null ? analysisEngine.newCAS() : ret;
        } catch (ResourceInitializationException e) {
            throw new RuntimeException(e);
        }
    }

    public void release(CAS cas) {
        casPool.releaseCas(cas);
    }
}
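// Illustrative sketch (added for this removal log; not part of the removed
// sources): borrowing a CAS through the pool wrapper above and returning it
// afterwards, as the process() javadoc requires. getDocumentText() is the
// standard UIMA CAS accessor.
class UimaResourceExample {
    public static void main(String[] args) throws Exception {
        UimaResource resource = new UimaResource(UimaTokenizerFactory.defaultAnalysisEngine());
        CAS cas = resource.process("Some text to annotate.");
        try {
            System.out.println(cas.getDocumentText());
        } finally {
            resource.release(cas);
        }
    }
}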
@@ -1,77 +0,0 @@
/* Apache License, Version 2.0 banner (identical to the headers above); SPDX-License-Identifier: Apache-2.0 */

package org.datavec.nlp.vectorizer;

import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;

import java.util.HashSet;
import java.util.Set;

public abstract class AbstractTfidfVectorizer<VECTOR_TYPE> extends TextVectorizer<VECTOR_TYPE> {

    @Override
    public void doWithTokens(Tokenizer tokenizer) {
        Set<String> seen = new HashSet<>();
        while (tokenizer.hasMoreTokens()) {
            String token = tokenizer.nextToken();
            if (!stopWords.contains(token)) {
                cache.incrementCount(token);
                if (!seen.contains(token)) {
                    cache.incrementDocCount(token);
                }
                seen.add(token);
            }
        }
    }

    @Override
    public TokenizerFactory createTokenizerFactory(Configuration conf) {
        String clazz = conf.get(TOKENIZER, DefaultTokenizerFactory.class.getName());
        try {
            Class<? extends TokenizerFactory> tokenizerFactoryClazz =
                            (Class<? extends TokenizerFactory>) Class.forName(clazz);
            TokenizerFactory tf = tokenizerFactoryClazz.newInstance();
            String preproc = conf.get(PREPROCESSOR, null);
            if (preproc != null) {
                TokenPreProcess tpp = (TokenPreProcess) Class.forName(preproc).newInstance();
                tf.setTokenPreProcessor(tpp);
            }
            return tf;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public abstract VECTOR_TYPE createVector(Object[] args);

    @Override
    public abstract VECTOR_TYPE fitTransform(RecordReader reader);

    @Override
    public abstract VECTOR_TYPE transform(Record record);
}
@@ -1,121 +0,0 @@
/* Apache License, Version 2.0 banner (identical to the headers above); SPDX-License-Identifier: Apache-2.0 */

package org.datavec.nlp.vectorizer;

import lombok.Getter;
import org.nd4j.common.primitives.Counter;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.vector.Vectorizer;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.metadata.DefaultVocabCache;
import org.datavec.nlp.metadata.VocabCache;
import org.datavec.nlp.stopwords.StopWords;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;

import java.util.Collection;

public abstract class TextVectorizer<VECTOR_TYPE> implements Vectorizer<VECTOR_TYPE> {

    protected TokenizerFactory tokenizerFactory;
    protected int minWordFrequency = 0;
    public final static String MIN_WORD_FREQUENCY = "org.nd4j.nlp.minwordfrequency";
    public final static String STOP_WORDS = "org.nd4j.nlp.stopwords";
    public final static String TOKENIZER = "org.datavec.nlp.tokenizerfactory";
    public static final String PREPROCESSOR = "org.datavec.nlp.preprocessor";
    public final static String VOCAB_CACHE = "org.datavec.nlp.vocabcache";
    protected Collection<String> stopWords;
    @Getter
    protected VocabCache cache;

    @Override
    public void initialize(Configuration conf) {
        tokenizerFactory = createTokenizerFactory(conf);
        minWordFrequency = conf.getInt(MIN_WORD_FREQUENCY, 5);
        if (conf.get(STOP_WORDS) != null)
            stopWords = conf.getStringCollection(STOP_WORDS);
        if (stopWords == null)
            stopWords = StopWords.getStopWords();

        String clazz = conf.get(VOCAB_CACHE, DefaultVocabCache.class.getName());
        try {
            Class<? extends VocabCache> vocabCacheClazz = (Class<? extends VocabCache>) Class.forName(clazz);
            cache = vocabCacheClazz.newInstance();
            cache.initialize(conf);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void fit(RecordReader reader) {
        fit(reader, null);
    }

    @Override
    public void fit(RecordReader reader, RecordCallBack callBack) {
        while (reader.hasNext()) {
            Record record = reader.nextRecord();
            String s = toString(record.getRecord());
            Tokenizer tokenizer = tokenizerFactory.create(s);
            doWithTokens(tokenizer);
            if (callBack != null)
                callBack.onRecord(record);
            cache.incrementNumDocs(1);
        }
    }

    protected Counter<String> wordFrequenciesForRecord(Collection<Writable> record) {
        String s = toString(record);
        Tokenizer tokenizer = tokenizerFactory.create(s);
        Counter<String> ret = new Counter<>();
        while (tokenizer.hasMoreTokens())
            ret.incrementCount(tokenizer.nextToken(), 1.0);
        return ret;
    }

    protected String toString(Collection<Writable> record) {
        StringBuilder sb = new StringBuilder();
        for (Writable w : record) {
            sb.append(w.toString());
        }
        return sb.toString();
    }

    /**
     * Increment counts, add to a collection, or otherwise accumulate
     * vocabulary statistics from the given tokenizer.
     * @param tokenizer the tokenizer to consume
     */
    public abstract void doWithTokens(Tokenizer tokenizer);

    /**
     * Create a tokenizer factory based on the configuration.
     * @param conf the configuration to use
     * @return the tokenizer factory based on the configuration
     */
    public abstract TokenizerFactory createTokenizerFactory(Configuration conf);
}
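// Illustrative sketch (added for this removal log; not part of the removed
// sources): configuring a vectorizer through the keys above. set/setInt are
// assumed to be the usual counterparts of the get/getInt/getStringCollection
// calls used in initialize(); values are examples only.
class TextVectorizerConfigExample {
    static Configuration buildConf() {
        Configuration conf = new Configuration();
        conf.set(TextVectorizer.TOKENIZER, DefaultTokenizerFactory.class.getName());
        conf.set(TextVectorizer.PREPROCESSOR, LowerCasePreProcessor.class.getName());
        conf.setInt(TextVectorizer.MIN_WORD_FREQUENCY, 2);
        return conf;
    }
}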
@@ -1,105 +0,0 @@
/* Apache License, Version 2.0 banner (identical to the headers above); SPDX-License-Identifier: Apache-2.0 */

package org.datavec.nlp.vectorizer;

import org.datavec.api.conf.Configuration;
import org.nd4j.common.primitives.Counter;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.writable.NDArrayWritable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class TfidfVectorizer extends AbstractTfidfVectorizer<INDArray> {
    /**
     * Default: True.<br>
     * If true: use idf(d, t) = log [ (1 + n) / (1 + df(d, t)) ] + 1<br>
     * If false: use idf(t) = log [ n / df(t) ] + 1<br>
     */
    public static final String SMOOTH_IDF = "org.datavec.nlp.TfidfVectorizer.smooth_idf";

    protected boolean smooth_idf;

    @Override
    public INDArray createVector(Object[] args) {
        Counter<String> docFrequencies = (Counter<String>) args[0];
        double[] vector = new double[cache.vocabWords().size()];
        for (int i = 0; i < cache.vocabWords().size(); i++) {
            String word = cache.wordAt(i);
            double freq = docFrequencies.getCount(word);
            vector[i] = cache.tfidf(word, freq, smooth_idf);
        }
        return Nd4j.create(vector);
    }

    @Override
    public INDArray fitTransform(RecordReader reader) {
        return fitTransform(reader, null);
    }

    @Override
    public INDArray fitTransform(final RecordReader reader, RecordCallBack callBack) {
        final List<Record> records = new ArrayList<>();
        fit(reader, new RecordCallBack() {
            @Override
            public void onRecord(Record record) {
                records.add(record);
            }
        });

        if (records.isEmpty())
            throw new IllegalStateException("No records found!");
        INDArray ret = Nd4j.create(records.size(), cache.vocabWords().size());
        int i = 0;
        for (Record record : records) {
            INDArray transformed = transform(record);
            org.datavec.api.records.impl.Record transformedRecord = new org.datavec.api.records.impl.Record(
                            Arrays.asList(new NDArrayWritable(transformed),
                                            record.getRecord().get(record.getRecord().size() - 1)),
                            new RecordMetaDataURI(record.getMetaData().getURI(), reader.getClass()));
            ret.putRow(i++, transformed);
            if (callBack != null) {
                callBack.onRecord(transformedRecord);
            }
        }

        return ret;
    }

    @Override
    public INDArray transform(Record record) {
        Counter<String> wordFrequencies = wordFrequenciesForRecord(record.getRecord());
        return createVector(new Object[] {wordFrequencies});
    }

    @Override
    public void initialize(Configuration conf) {
        super.initialize(conf);
        this.smooth_idf = conf.getBoolean(SMOOTH_IDF, true);
    }
}
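// Illustrative sketch (added for this removal log; not part of the removed
// sources): fitting the TF-IDF vectorizer over a record reader. The reader is
// assumed to be any initialized RecordReader of text records; buildConf() is
// the hypothetical helper sketched after TextVectorizer above.
class TfidfVectorizerExample {
    static INDArray vectorize(RecordReader reader) {
        TfidfVectorizer vectorizer = new TfidfVectorizer();
        vectorizer.initialize(TextVectorizerConfigExample.buildConf());
        return vectorizer.fitTransform(reader); // rows = records, columns = vocab words
    }
}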
@@ -1,194 +0,0 @@
a
----s
act
"the
"The
about
above
after
again
against
all
am
an
and
any
are
aren't
as
at
be
because
been
before
being
below
between
both
but
by
can't
cannot
could
couldn't
did
didn't
do
does
doesn't
doing
don't
down
during
each
few
for
from
further
had
hadn't
has
hasn't
have
haven't
having
he
he'd
he'll
he's
her
here
here's
hers
herself
him
himself
his
how
how's
i
i'd
i'll
i'm
i've
if
in
into
is
isn't
it
it's
its
itself
let's
me
more
most
mustn't
my
myself
no
nor
not
of
off
on
once
only
or
other
ought
our
ours
ourselves
out
over
own
put
same
shan't
she
she'd
she'll
she's
should
somebody
something
shouldn't
so
some
such
take
than
that
that's
the
their
theirs
them
themselves
then
there
there's
these
they
they'd
they'll
they're
they've
this
those
through
to
too
under
until
up
very
was
wasn't
we
we'd
we'll
we're
we've
were
weren't
what
what's
when
when's
where
where's
which
while
who
who's
whom
why
why's
will
with
without
won't
would
wouldn't
you
you'd
you'll
you're
you've
your
yours
yourself
yourselves
.
?
!
,
+
=
also
-
;
:
@@ -1,46 +0,0 @@
/* Apache License, Version 2.0 banner (identical to the headers above); SPDX-License-Identifier: Apache-2.0 */

package org.datavec.nlp;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;

import java.util.*;

@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {

    @Override
    protected Set<Class<?>> getExclusions() {
        // Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
        return new HashSet<>();
    }

    @Override
    protected String getPackageName() {
        return "org.datavec.nlp";
    }

    @Override
    protected Class<?> getBaseClass() {
        return BaseND4JTest.class;
    }
}
@@ -1,130 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.nlp.reader;

import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.vectorizer.TfidfVectorizer;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.io.ClassPathResource;

import java.io.File;
import java.net.URI;
import java.util.*;

import static org.junit.Assert.*;

/**
 * @author Adam Gibson
 */
public class TfidfRecordReaderTest {

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    @Test
    public void testReader() throws Exception {
        TfidfVectorizer vectorizer = new TfidfVectorizer();
        Configuration conf = new Configuration();
        conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1);
        conf.setBoolean(RecordReader.APPEND_LABEL, true);
        vectorizer.initialize(conf);
        TfidfRecordReader reader = new TfidfRecordReader();
        File f = testDir.newFolder();
        new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f);
        List<URI> u = new ArrayList<>();
        for(File f2 : f.listFiles()){
            if(f2.isDirectory()){
                for(File f3 : f2.listFiles()){
                    u.add(f3.toURI());
                }
            } else {
                u.add(f2.toURI());
            }
        }
        Collections.sort(u);
        CollectionInputSplit c = new CollectionInputSplit(u);
        reader.initialize(conf, c);
        int count = 0;
        int[] labelAssertions = new int[3];
        while (reader.hasNext()) {
            Collection<Writable> record = reader.next();
            Iterator<Writable> recordIter = record.iterator();
            NDArrayWritable writable = (NDArrayWritable) recordIter.next();
            labelAssertions[count] = recordIter.next().toInt();
            count++;
        }

        assertArrayEquals(new int[] {0, 1, 2}, labelAssertions);
        assertEquals(3, reader.getLabels().size());
        assertEquals(3, count);
    }

    @Test
    public void testRecordMetaData() throws Exception {
        TfidfVectorizer vectorizer = new TfidfVectorizer();
        Configuration conf = new Configuration();
        conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1);
        conf.setBoolean(RecordReader.APPEND_LABEL, true);
        vectorizer.initialize(conf);
        TfidfRecordReader reader = new TfidfRecordReader();
        File f = testDir.newFolder();
        new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f);
        reader.initialize(conf, new FileSplit(f));

        while (reader.hasNext()) {
            Record record = reader.nextRecord();
            assertNotNull(record.getMetaData().getURI());
            assertEquals(record.getMetaData().getReaderClass(), TfidfRecordReader.class);
        }
    }


    @Test
    public void testReadRecordFromMetaData() throws Exception {
        TfidfVectorizer vectorizer = new TfidfVectorizer();
        Configuration conf = new Configuration();
        conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1);
        conf.setBoolean(RecordReader.APPEND_LABEL, true);
        vectorizer.initialize(conf);
        TfidfRecordReader reader = new TfidfRecordReader();
        File f = testDir.newFolder();
        new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f);
        reader.initialize(conf, new FileSplit(f));

        Record record = reader.nextRecord();

        Record reread = reader.loadFromMetaData(record.getMetaData());

        assertEquals(record.getRecord().size(), 2);
        assertEquals(reread.getRecord().size(), 2);
        assertEquals(record.getRecord().get(0), reread.getRecord().get(0));
        assertEquals(record.getRecord().get(1), reread.getRecord().get(1));
        assertEquals(record.getMetaData(), reread.getMetaData());
    }
}
@@ -1,96 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.nlp.transforms;

import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.junit.Assert.assertEquals;

public class TestGazeteerTransform {

    @Test
    public void testGazeteerTransform(){

        String[] corpus = {
                "hello I like apple".toLowerCase(),
                "cherry date eggplant potato".toLowerCase()
        };

        //Gazeteer transform: basically 0/1 if word is present. Assumes already tokenized input
        List<String> words = Arrays.asList("apple", "banana", "cherry", "date", "eggplant");

        GazeteerTransform t = new GazeteerTransform("words", "out", words);

        SequenceSchema schema = (SequenceSchema) new SequenceSchema.Builder()
                .addColumnString("words").build();

        TransformProcess tp = new TransformProcess.Builder(schema)
                .transform(t)
                .build();

        List<List<List<Writable>>> input = new ArrayList<>();
        for(String s : corpus){
            String[] split = s.split(" ");
            List<List<Writable>> seq = new ArrayList<>();
            for(String s2 : split){
                seq.add(Collections.<Writable>singletonList(new Text(s2)));
            }
            input.add(seq);
        }

        List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, tp);

        INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get();
        INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get();

        INDArray exp0 = Nd4j.create(new float[]{1, 0, 0, 0, 0});
        INDArray exp1 = Nd4j.create(new float[]{0, 0, 1, 1, 1});

        assertEquals(exp0, arr0);
        assertEquals(exp1, arr1);


        String json = tp.toJson();
        TransformProcess tp2 = TransformProcess.fromJson(json);
        assertEquals(tp, tp2);

        List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, tp);
        INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get();
        INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get();

        assertEquals(exp0, arr0a);
        assertEquals(exp1, arr1a);
    }

}
@@ -1,96 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.nlp.transforms;

import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;

import static org.junit.Assert.assertEquals;

public class TestMultiNLPTransform {

    @Test
    public void test(){

        List<String> words = Arrays.asList("apple", "banana", "cherry", "date", "eggplant");
        GazeteerTransform t1 = new GazeteerTransform("words", "out", words);
        GazeteerTransform t2 = new GazeteerTransform("out", "out", words);


        MultiNlpTransform multi = new MultiNlpTransform("text", new BagOfWordsTransform[]{t1, t2}, "out");

        String[] corpus = {
                "hello I like apple".toLowerCase(),
                "date eggplant potato".toLowerCase()
        };

        List<List<List<Writable>>> input = new ArrayList<>();
        for(String s : corpus){
            String[] split = s.split(" ");
            List<List<Writable>> seq = new ArrayList<>();
            for(String s2 : split){
                seq.add(Collections.<Writable>singletonList(new Text(s2)));
            }
            input.add(seq);
        }

        SequenceSchema schema = (SequenceSchema) new SequenceSchema.Builder()
                .addColumnString("text").build();

        TransformProcess tp = new TransformProcess.Builder(schema)
                .transform(multi)
                .build();

        List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, tp);

        INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get();
        INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get();

        INDArray exp0 = Nd4j.create(new float[]{1, 0, 0, 0, 0, 1, 0, 0, 0, 0});
        INDArray exp1 = Nd4j.create(new float[]{0, 0, 0, 1, 1, 0, 0, 0, 1, 1});

        assertEquals(exp0, arr0);
        assertEquals(exp1, arr1);


        String json = tp.toJson();
        TransformProcess tp2 = TransformProcess.fromJson(json);
        assertEquals(tp, tp2);

        List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, tp);
        INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get();
        INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get();

        assertEquals(exp0, arr0a);
        assertEquals(exp1, arr1a);

    }

}
@@ -1,414 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.nlp.transforms;

import org.datavec.api.conf.Configuration;
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.nlp.metadata.VocabCache;
import org.datavec.nlp.tokenization.tokenizer.preprocessor.LowerCasePreProcessor;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.vectorizer.TfidfVectorizer;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Triple;

import java.util.*;

import static org.datavec.nlp.vectorizer.TextVectorizer.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

public class TokenizerBagOfWordsTermSequenceIndexTransformTest {

    @Test
    public void testSequenceExecution() {
        //credit: https://stackoverflow.com/questions/23792781/tf-idf-feature-weights-using-sklearn-feature-extraction-text-tfidfvectorizer
        String[] corpus = {
                "This is very strange".toLowerCase(),
                "This is very nice".toLowerCase()
        };
        //{'is': 1.0, 'nice': 1.4054651081081644, 'strange': 1.4054651081081644, 'this': 1.0, 'very': 1.0}

        /*
        ## Reproduce with:
        from sklearn.feature_extraction.text import TfidfVectorizer
        corpus = ["This is very strange", "This is very nice"]

        ## SMOOTH = FALSE case:
        vectorizer = TfidfVectorizer(min_df=0, norm=None, smooth_idf=False)
        X = vectorizer.fit_transform(corpus)
        idf = vectorizer.idf_
        print(dict(zip(vectorizer.get_feature_names(), idf)))

        newText = ["This is very strange", "This is very nice"]
        out = vectorizer.transform(newText)
        print(out)

        {'is': 1.0, 'nice': 1.6931471805599454, 'strange': 1.6931471805599454, 'this': 1.0, 'very': 1.0}
          (0, 4)    1.0
          (0, 3)    1.0
          (0, 2)    1.6931471805599454
          (0, 0)    1.0
          (1, 4)    1.0
          (1, 3)    1.0
          (1, 1)    1.6931471805599454
          (1, 0)    1.0

        ## SMOOTH = TRUE case:
        {'is': 1.0, 'nice': 1.4054651081081644, 'strange': 1.4054651081081644, 'this': 1.0, 'very': 1.0}
          (0, 4)    1.0
          (0, 3)    1.0
          (0, 2)    1.4054651081081644
          (0, 0)    1.0
          (1, 4)    1.0
          (1, 3)    1.0
          (1, 1)    1.4054651081081644
          (1, 0)    1.0
        */

        List<List<List<Writable>>> input = new ArrayList<>();
        input.add(Arrays.asList(Arrays.<Writable>asList(new Text(corpus[0])),Arrays.<Writable>asList(new Text(corpus[1]))));

        // First: Check TfidfVectorizer vs. scikit:

        Map<String,Double> idfMapNoSmooth = new HashMap<>();
        idfMapNoSmooth.put("is",1.0);
        idfMapNoSmooth.put("nice",1.6931471805599454);
        idfMapNoSmooth.put("strange",1.6931471805599454);
        idfMapNoSmooth.put("this",1.0);
        idfMapNoSmooth.put("very",1.0);

        Map<String,Double> idfMapSmooth = new HashMap<>();
        idfMapSmooth.put("is",1.0);
        idfMapSmooth.put("nice",1.4054651081081644);
        idfMapSmooth.put("strange",1.4054651081081644);
        idfMapSmooth.put("this",1.0);
        idfMapSmooth.put("very",1.0);


        TfidfVectorizer tfidfVectorizer = new TfidfVectorizer();
        Configuration configuration = new Configuration();
        configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName());
        configuration.set(MIN_WORD_FREQUENCY,"1");
        configuration.set(STOP_WORDS,"");
        configuration.set(TfidfVectorizer.SMOOTH_IDF, "false");

        tfidfVectorizer.initialize(configuration);

        CollectionRecordReader collectionRecordReader = new CollectionRecordReader(input.get(0));
        INDArray array = tfidfVectorizer.fitTransform(collectionRecordReader);

        INDArray expNoSmooth = Nd4j.create(DataType.FLOAT, 2, 5);
        VocabCache vc = tfidfVectorizer.getCache();
        expNoSmooth.putScalar(0, vc.wordIndex("very"), 1.0);
        expNoSmooth.putScalar(0, vc.wordIndex("this"), 1.0);
        expNoSmooth.putScalar(0, vc.wordIndex("strange"), 1.6931471805599454);
        expNoSmooth.putScalar(0, vc.wordIndex("is"), 1.0);

        expNoSmooth.putScalar(1, vc.wordIndex("very"), 1.0);
        expNoSmooth.putScalar(1, vc.wordIndex("this"), 1.0);
        expNoSmooth.putScalar(1, vc.wordIndex("nice"), 1.6931471805599454);
        expNoSmooth.putScalar(1, vc.wordIndex("is"), 1.0);

        assertEquals(expNoSmooth, array);


        //------------------------------------------------------------
        //Smooth version:
        tfidfVectorizer = new TfidfVectorizer();
        configuration = new Configuration();
        configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName());
        configuration.set(MIN_WORD_FREQUENCY,"1");
        configuration.set(STOP_WORDS,"");
        configuration.set(TfidfVectorizer.SMOOTH_IDF, "true");

        tfidfVectorizer.initialize(configuration);

        collectionRecordReader.reset();
        array = tfidfVectorizer.fitTransform(collectionRecordReader);

        INDArray expSmooth = Nd4j.create(DataType.FLOAT, 2, 5);
        expSmooth.putScalar(0, vc.wordIndex("very"), 1.0);
        expSmooth.putScalar(0, vc.wordIndex("this"), 1.0);
        expSmooth.putScalar(0, vc.wordIndex("strange"), 1.4054651081081644);
        expSmooth.putScalar(0, vc.wordIndex("is"), 1.0);

        expSmooth.putScalar(1, vc.wordIndex("very"), 1.0);
        expSmooth.putScalar(1, vc.wordIndex("this"), 1.0);
        expSmooth.putScalar(1, vc.wordIndex("nice"), 1.4054651081081644);
        expSmooth.putScalar(1, vc.wordIndex("is"), 1.0);

        assertEquals(expSmooth, array);


        //////////////////////////////////////////////////////////

        //Second: Check transform vs scikit/TfidfVectorizer

        List<String> vocab = new ArrayList<>(5);        //Arrays.asList("is","nice","strange","this","very");
        for( int i=0; i<5; i++ ){
            vocab.add(vc.wordAt(i));
        }

        String inputColumnName = "input";
        String outputColumnName = "output";
        Map<String,Integer> wordIndexMap = new HashMap<>();
        for(int i = 0; i < vocab.size(); i++) {
            wordIndexMap.put(vocab.get(i),i);
        }

        TokenizerBagOfWordsTermSequenceIndexTransform tokenizerBagOfWordsTermSequenceIndexTransform = new TokenizerBagOfWordsTermSequenceIndexTransform(
                inputColumnName,
                outputColumnName,
                wordIndexMap,
                idfMapNoSmooth,
                false,
                null, null);

        SequenceSchema.Builder sequenceSchemaBuilder = new SequenceSchema.Builder();
        sequenceSchemaBuilder.addColumnString("input");
        SequenceSchema schema = sequenceSchemaBuilder.build();
        assertEquals("input",schema.getName(0));

        TransformProcess transformProcess = new TransformProcess.Builder(schema)
                .transform(tokenizerBagOfWordsTermSequenceIndexTransform)
                .build();

        List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess);


        //System.out.println(execute);
        INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get();
        INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get();

        assertEquals(expNoSmooth.getRow(0, true), arr0);
        assertEquals(expNoSmooth.getRow(1, true), arr1);


        //--------------------------------
        //Check smooth:

        tokenizerBagOfWordsTermSequenceIndexTransform = new TokenizerBagOfWordsTermSequenceIndexTransform(
                inputColumnName,
                outputColumnName,
                wordIndexMap,
                idfMapSmooth,
                false,
                null, null);

        schema = (SequenceSchema) new SequenceSchema.Builder().addColumnString("input").build();

        transformProcess = new TransformProcess.Builder(schema)
                .transform(tokenizerBagOfWordsTermSequenceIndexTransform)
                .build();

        execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess);

        arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get();
        arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get();

        assertEquals(expSmooth.getRow(0, true), arr0);
        assertEquals(expSmooth.getRow(1, true), arr1);


        //Test JSON serialization:

        String json = transformProcess.toJson();
        TransformProcess fromJson = TransformProcess.fromJson(json);
        assertEquals(transformProcess, fromJson);
        List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, fromJson);

        INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get();
        INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get();

        assertEquals(expSmooth.getRow(0, true), arr0a);
        assertEquals(expSmooth.getRow(1, true), arr1a);
    }

    @Test
    public void additionalTest(){
        /*
        ## To reproduce:
        from sklearn.feature_extraction.text import TfidfVectorizer
        corpus = [
            'This is the first document',
            'This document is the second document',
            'And this is the third one',
            'Is this the first document',
        ]
        vectorizer = TfidfVectorizer(min_df=0, norm=None, smooth_idf=False)
        X = vectorizer.fit_transform(corpus)
        print(vectorizer.get_feature_names())

        out = vectorizer.transform(corpus)
        print(out)

        ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this']
          (0, 8)    1.0
          (0, 6)    1.0
          (0, 3)    1.0
          (0, 2)    1.6931471805599454
          (0, 1)    1.2876820724517808
          (1, 8)    1.0
          (1, 6)    1.0
          (1, 5)    2.386294361119891
          (1, 3)    1.0
          (1, 1)    2.5753641449035616
          (2, 8)    1.0
          (2, 7)    2.386294361119891
          (2, 6)    1.0
          (2, 4)    2.386294361119891
          (2, 3)    1.0
          (2, 0)    2.386294361119891
          (3, 8)    1.0
          (3, 6)    1.0
          (3, 3)    1.0
          (3, 2)    1.6931471805599454
          (3, 1)    1.2876820724517808
        {'and': 2.386294361119891, 'document': 1.2876820724517808, 'first': 1.6931471805599454, 'is': 1.0, 'one': 2.386294361119891, 'second': 2.386294361119891, 'the': 1.0, 'third': 2.386294361119891, 'this': 1.0}
        */

        String[] corpus = {
                "This is the first document",
                "This document is the second document",
                "And this is the third one",
                "Is this the first document"};

        TfidfVectorizer tfidfVectorizer = new TfidfVectorizer();
        Configuration configuration = new Configuration();
        configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName());
        configuration.set(MIN_WORD_FREQUENCY,"1");
        configuration.set(STOP_WORDS,"");
        configuration.set(TfidfVectorizer.SMOOTH_IDF, "false");
        configuration.set(PREPROCESSOR, LowerCasePreProcessor.class.getName());

        tfidfVectorizer.initialize(configuration);

        List<List<List<Writable>>> input = new ArrayList<>();
        //input.add(Arrays.asList(Arrays.<Writable>asList(new Text(corpus[0])),Arrays.<Writable>asList(new Text(corpus[1]))));
        List<List<Writable>> seq = new ArrayList<>();
        for(String s : corpus){
            seq.add(Collections.<Writable>singletonList(new Text(s)));
        }
        input.add(seq);

        CollectionRecordReader crr = new CollectionRecordReader(seq);
        INDArray arr = tfidfVectorizer.fitTransform(crr);

        //System.out.println(arr);
        assertArrayEquals(new long[]{4, 9}, arr.shape());

        List<String> pyVocab = Arrays.asList("and", "document", "first", "is", "one", "second", "the", "third", "this");
        List<Triple<Integer,Integer,Double>> l = new ArrayList<>();
        l.add(new Triple<>(0, 8, 1.0));
        l.add(new Triple<>(0, 6, 1.0));
        l.add(new Triple<>(0, 3, 1.0));
        l.add(new Triple<>(0, 2, 1.6931471805599454));
        l.add(new Triple<>(0, 1, 1.2876820724517808));
        l.add(new Triple<>(1, 8, 1.0));
        l.add(new Triple<>(1, 6, 1.0));
        l.add(new Triple<>(1, 5, 2.386294361119891));
        l.add(new Triple<>(1, 3, 1.0));
        l.add(new Triple<>(1, 1, 2.5753641449035616));
        l.add(new Triple<>(2, 8, 1.0));
        l.add(new Triple<>(2, 7, 2.386294361119891));
        l.add(new Triple<>(2, 6, 1.0));
        l.add(new Triple<>(2, 4, 2.386294361119891));
        l.add(new Triple<>(2, 3, 1.0));
        l.add(new Triple<>(2, 0, 2.386294361119891));
        l.add(new Triple<>(3, 8, 1.0));
        l.add(new Triple<>(3, 6, 1.0));
        l.add(new Triple<>(3, 3, 1.0));
        l.add(new Triple<>(3, 2, 1.6931471805599454));
        l.add(new Triple<>(3, 1, 1.2876820724517808));

        INDArray exp = Nd4j.create(DataType.FLOAT, 4, 9);
        for(Triple<Integer,Integer,Double> t : l){
            //Work out word index, accounting for different vocab/word orders:
            int wIdx = tfidfVectorizer.getCache().wordIndex(pyVocab.get(t.getSecond()));
            exp.putScalar(t.getFirst(), wIdx, t.getThird());
        }

        assertEquals(exp, arr);


        Map<String,Double> idfWeights = new HashMap<>();
        idfWeights.put("and", 2.386294361119891);
        idfWeights.put("document", 1.2876820724517808);
        idfWeights.put("first", 1.6931471805599454);
        idfWeights.put("is", 1.0);
        idfWeights.put("one", 2.386294361119891);
        idfWeights.put("second", 2.386294361119891);
        idfWeights.put("the", 1.0);
        idfWeights.put("third", 2.386294361119891);
        idfWeights.put("this", 1.0);


        List<String> vocab = new ArrayList<>(9);        //Arrays.asList("is","nice","strange","this","very");
        for( int i=0; i<9; i++ ){
            vocab.add(tfidfVectorizer.getCache().wordAt(i));
        }

        String inputColumnName = "input";
        String outputColumnName = "output";
        Map<String,Integer> wordIndexMap = new HashMap<>();
        for(int i = 0; i < vocab.size(); i++) {
            wordIndexMap.put(vocab.get(i),i);
        }

        TokenizerBagOfWordsTermSequenceIndexTransform transform = new TokenizerBagOfWordsTermSequenceIndexTransform(
                inputColumnName,
                outputColumnName,
                wordIndexMap,
                idfWeights,
                false,
                null, LowerCasePreProcessor.class.getName());

        SequenceSchema.Builder sequenceSchemaBuilder = new SequenceSchema.Builder();
        sequenceSchemaBuilder.addColumnString("input");
        SequenceSchema schema = sequenceSchemaBuilder.build();
        assertEquals("input",schema.getName(0));

        TransformProcess transformProcess = new TransformProcess.Builder(schema)
                .transform(transform)
                .build();

        List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess);

        INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get();
        INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get();

        assertEquals(exp.getRow(0, true), arr0);
        assertEquals(exp.getRow(1, true), arr1);
    }

}
@@ -1,53 +0,0 @@
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
  -->

<configuration>

    <appender name="FILE" class="ch.qos.logback.core.FileAppender">
        <file>logs/application.log</file>
        <encoder>
            <pattern>%date - [%level] - from %logger in %thread
                %n%message%n%xException%n</pattern>
        </encoder>
    </appender>

    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern> %logger{15} - %message%n%xException{5}
            </pattern>
        </encoder>
    </appender>

    <logger name="org.apache.catalina.core" level="DEBUG" />
    <logger name="org.springframework" level="DEBUG" />
    <logger name="org.datavec" level="DEBUG" />
    <logger name="org.nd4j" level="INFO" />
    <logger name="opennlp.uima.util" level="OFF" />
    <logger name="org.apache.uima" level="OFF" />
    <logger name="org.cleartk" level="OFF" />
    <logger name="org.apache.spark" level="WARN" />


    <root level="ERROR">
        <appender-ref ref="STDOUT" />
        <appender-ref ref="FILE" />
    </root>

</configuration>
@@ -1,56 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
  -->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-data</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-geo</artifactId>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
        </dependency>
        <dependency>
            <groupId>com.maxmind.geoip2</groupId>
            <artifactId>geoip2</artifactId>
            <version>${geoip2.version}</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@@ -1,25 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.api.transform.geo;

public enum LocationType {
    CITY, CITY_ID, CONTINENT, CONTINENT_ID, COUNTRY, COUNTRY_ID, COORDINATES, POSTAL_CODE, SUBDIVISIONS, SUBDIVISIONS_ID
}
@@ -1,194 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.api.transform.reduce.geo;

import lombok.Getter;
import org.datavec.api.transform.ReduceOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.StringMetaData;
import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.reduce.AggregableColumnReduction;
import org.datavec.api.transform.reduce.AggregableReductionUtils;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.nd4j.common.function.Supplier;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class CoordinatesReduction implements AggregableColumnReduction {
    public static final String DEFAULT_COLUMN_NAME = "CoordinatesReduction";

    public final static String DEFAULT_DELIMITER = ":";
    protected String delimiter = DEFAULT_DELIMITER;

    private final List<String> columnNamesPostReduce;

    private final Supplier<IAggregableReduceOp<Writable, List<Writable>>> multiOp(final List<ReduceOp> ops) {
        return new Supplier<IAggregableReduceOp<Writable, List<Writable>>>() {
            @Override
            public IAggregableReduceOp<Writable, List<Writable>> get() {
                return AggregableReductionUtils.reduceDoubleColumn(ops, false, null);
            }
        };
    }

    public CoordinatesReduction(String columnNamePostReduce, ReduceOp op) {
        this(columnNamePostReduce, op, DEFAULT_DELIMITER);
    }

    public CoordinatesReduction(List<String> columnNamePostReduce, List<ReduceOp> op) {
        this(columnNamePostReduce, op, DEFAULT_DELIMITER);
    }

    public CoordinatesReduction(String columnNamePostReduce, ReduceOp op, String delimiter) {
        this(Collections.singletonList(columnNamePostReduce), Collections.singletonList(op), delimiter);
    }

    public CoordinatesReduction(List<String> columnNamesPostReduce, List<ReduceOp> ops, String delimiter) {
        this.columnNamesPostReduce = columnNamesPostReduce;
        this.reducer = new CoordinateAggregableReduceOp(ops.size(), multiOp(ops), delimiter);
    }

    @Override
    public List<String> getColumnsOutputName(String columnInputName) {
        return columnNamesPostReduce;
    }

    @Override
    public List<ColumnMetaData> getColumnOutputMetaData(List<String> newColumnName, ColumnMetaData columnInputMeta) {
        List<ColumnMetaData> res = new ArrayList<>(newColumnName.size());
        for (String cn : newColumnName)
            res.add(new StringMetaData((cn)));
        return res;
    }

    @Override
    public Schema transform(Schema inputSchema) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void setInputSchema(Schema inputSchema) {
        throw new UnsupportedOperationException();
    }

    @Override
    public Schema getInputSchema() {
        throw new UnsupportedOperationException();
    }

    @Override
    public String outputColumnName() {
        throw new UnsupportedOperationException();
    }

    @Override
    public String[] outputColumnNames() {
        throw new UnsupportedOperationException();
    }

    @Override
    public String[] columnNames() {
        throw new UnsupportedOperationException();
    }

    @Override
    public String columnName() {
        throw new UnsupportedOperationException();
    }

    private IAggregableReduceOp<Writable, List<Writable>> reducer;

    @Override
    public IAggregableReduceOp<Writable, List<Writable>> reduceOp() {
        return reducer;
    }


    public static class CoordinateAggregableReduceOp implements IAggregableReduceOp<Writable, List<Writable>> {


        private int nOps;
        private Supplier<IAggregableReduceOp<Writable, List<Writable>>> initialOpValue;
        @Getter
        private ArrayList<IAggregableReduceOp<Writable, List<Writable>>> perCoordinateOps; // of size coords()
        private String delimiter;

        public CoordinateAggregableReduceOp(int n, Supplier<IAggregableReduceOp<Writable, List<Writable>>> initialOp,
                        String delim) {
            this.nOps = n;
            this.perCoordinateOps = new ArrayList<>();
            this.initialOpValue = initialOp;
            this.delimiter = delim;
        }

        @Override
        public <W extends IAggregableReduceOp<Writable, List<Writable>>> void combine(W accu) {
            if (accu instanceof CoordinateAggregableReduceOp) {
                CoordinateAggregableReduceOp accumulator = (CoordinateAggregableReduceOp) accu;
                for (int i = 0; i < Math.min(perCoordinateOps.size(), accumulator.getPerCoordinateOps().size()); i++) {
                    perCoordinateOps.get(i).combine(accumulator.getPerCoordinateOps().get(i));
                } // the rest is assumed identical
            }
        }

        @Override
        public void accept(Writable writable) {
            String[] coordinates = writable.toString().split(delimiter);
            for (int i = 0; i < coordinates.length; i++) {
                String coordinate = coordinates[i];
                while (perCoordinateOps.size() < i + 1) {
                    perCoordinateOps.add(initialOpValue.get());
                }
                perCoordinateOps.get(i).accept(new DoubleWritable(Double.parseDouble(coordinate)));
            }
        }

        @Override
        public List<Writable> get() {
            List<StringBuilder> res = new ArrayList<>(nOps);
            for (int i = 0; i < nOps; i++) {
                res.add(new StringBuilder());
            }

            for (int i = 0; i < perCoordinateOps.size(); i++) {
                List<Writable> resThisCoord = perCoordinateOps.get(i).get();
                for (int j = 0; j < nOps; j++) {
                    res.get(j).append(resThisCoord.get(j).toString());
                    if (i < perCoordinateOps.size() - 1) {
                        res.get(j).append(delimiter);
                    }
                }
            }

            List<Writable> finalRes = new ArrayList<>(nOps);
            for (StringBuilder sb : res) {
                finalRes.add(new Text(sb.toString()));
            }
            return finalRes;
        }
    }

}
@@ -1,117 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.api.transform.transform.geo;

import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.DoubleMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseColumnsMathOpTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class CoordinatesDistanceTransform extends BaseColumnsMathOpTransform {

    public final static String DEFAULT_DELIMITER = ":";
    protected String delimiter = DEFAULT_DELIMITER;

    public CoordinatesDistanceTransform(String newColumnName, String firstColumn, String secondColumn,
                    String stdevColumn) {
        this(newColumnName, firstColumn, secondColumn, stdevColumn, DEFAULT_DELIMITER);
    }

    public CoordinatesDistanceTransform(@JsonProperty("newColumnName") String newColumnName,
                    @JsonProperty("firstColumn") String firstColumn, @JsonProperty("secondColumn") String secondColumn,
                    @JsonProperty("stdevColumn") String stdevColumn, @JsonProperty("delimiter") String delimiter) {
        super(newColumnName, MathOp.Add /* dummy op */,
                        stdevColumn != null ? new String[] {firstColumn, secondColumn, stdevColumn}
                                        : new String[] {firstColumn, secondColumn});
        this.delimiter = delimiter;
    }

    @Override
    protected ColumnMetaData derivedColumnMetaData(String newColumnName, Schema inputSchema) {
        return new DoubleMetaData(newColumnName);
    }

    @Override
    protected Writable doOp(Writable... input) {
        String[] first = input[0].toString().split(delimiter);
        String[] second = input[1].toString().split(delimiter);
        String[] stdev = columns.length > 2 ? input[2].toString().split(delimiter) : null;

        double dist = 0;
        for (int i = 0; i < first.length; i++) {
            double d = Double.parseDouble(first[i]) - Double.parseDouble(second[i]);
            double s = stdev != null ? Double.parseDouble(stdev[i]) : 1;
            dist += (d * d) / (s * s);
        }
        return new DoubleWritable(Math.sqrt(dist));
    }

    @Override
    public String toString() {
        return "CoordinatesDistanceTransform(newColumnName=\"" + newColumnName + "\",columns="
                        + Arrays.toString(columns) + ",delimiter=" + delimiter + ")";
    }

    /**
     * Transform an object
     * into another object
     *
     * @param input the record to transform
     * @return the transformed writable
     */
    @Override
    public Object map(Object input) {
        List row = (List) input;
        String[] first = row.get(0).toString().split(delimiter);
        String[] second = row.get(1).toString().split(delimiter);
        String[] stdev = columns.length > 2 ? row.get(2).toString().split(delimiter) : null;

        double dist = 0;
        for (int i = 0; i < first.length; i++) {
            double d = Double.parseDouble(first[i]) - Double.parseDouble(second[i]);
            double s = stdev != null ? Double.parseDouble(stdev[i]) : 1;
            dist += (d * d) / (s * s);
        }
        return Math.sqrt(dist);
    }

    /**
     * Transform a sequence
     *
     * @param sequence
     */
    @Override
    public Object mapSequence(Object sequence) {
        List<List> seq = (List<List>) sequence;
        List<Double> ret = new ArrayList<>();
        for (Object step : seq)
            ret.add((Double) map(step));
        return ret;
    }
}
@ -1,70 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.datavec.api.transform.transform.geo;
|
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
|
||||||
import org.nd4j.common.util.ArchiveUtils;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.net.URL;
|
|
||||||
|
|
||||||
public class GeoIPFetcher {
|
|
||||||
protected static final Logger log = LoggerFactory.getLogger(GeoIPFetcher.class);
|
|
||||||
|
|
||||||
/** Default directory for <a href="http://dev.maxmind.com/geoip/geoipupdate/">http://dev.maxmind.com/geoip/geoipupdate/</a> */
|
|
||||||
public static final String GEOIP_DIR = "/usr/local/share/GeoIP/";
|
|
||||||
public static final String GEOIP_DIR2 = System.getProperty("user.home") + "/.datavec-geoip";
|
|
||||||
|
|
||||||
public static final String CITY_DB = "GeoIP2-City.mmdb";
|
|
||||||
public static final String CITY_LITE_DB = "GeoLite2-City.mmdb";
|
|
||||||
|
|
||||||
public static final String CITY_LITE_URL =
|
|
||||||
"http://geolite.maxmind.com/download/geoip/database/GeoLite2-City.mmdb.gz";
|
|
||||||
|
|
||||||
    public static synchronized File fetchCityDB() throws IOException {
        File cityFile = new File(GEOIP_DIR, CITY_DB);
        if (cityFile.isFile()) {
            return cityFile;
        }
        cityFile = new File(GEOIP_DIR, CITY_LITE_DB);
        if (cityFile.isFile()) {
            return cityFile;
        }
        cityFile = new File(GEOIP_DIR2, CITY_LITE_DB);
        if (cityFile.isFile()) {
            return cityFile;
        }

        log.info("Downloading GeoLite2 City database...");
        File archive = new File(GEOIP_DIR2, CITY_LITE_DB + ".gz");
        File dir = new File(GEOIP_DIR2);
        dir.mkdirs();
        FileUtils.copyURLToFile(new URL(CITY_LITE_URL), archive);
        ArchiveUtils.unzipFileTo(archive.getAbsolutePath(), dir.getAbsolutePath());
        Preconditions.checkState(cityFile.isFile(), "Error extracting files: expected city file does not exist after extraction");

        return cityFile;
    }
}

@@ -1,43 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.api.transform.transform.geo;

import org.datavec.api.transform.geo.LocationType;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.io.IOException;

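// Convenience subclass that always extracts coordinates: the IP address is
// mapped to "latitude<delimiter>longitude" via LocationType.COORDINATES.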
public class IPAddressToCoordinatesTransform extends IPAddressToLocationTransform {

    public IPAddressToCoordinatesTransform(@JsonProperty("columnName") String columnName) throws IOException {
        this(columnName, DEFAULT_DELIMITER);
    }

    public IPAddressToCoordinatesTransform(@JsonProperty("columnName") String columnName,
                    @JsonProperty("delimiter") String delimiter) throws IOException {
        super(columnName, LocationType.COORDINATES, delimiter);
    }

    @Override
    public String toString() {
        return "IPAddressToCoordinatesTransform";
    }
}

@@ -1,184 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.api.transform.transform.geo;

import com.maxmind.geoip2.DatabaseReader;
import com.maxmind.geoip2.exception.GeoIp2Exception;
import com.maxmind.geoip2.model.CityResponse;
import com.maxmind.geoip2.record.Location;
import com.maxmind.geoip2.record.Subdivision;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.transform.geo.LocationType;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.StringMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.InetAddress;

@Slf4j
public class IPAddressToLocationTransform extends BaseColumnTransform {
    /**
     * Name of the system property to use when configuring the GeoIP database file.<br>
     * Most users don't need to set this - typically used for testing purposes.<br>
     * Set with the full local path, like: "C:/datavec-geo/GeoIP2-City-Test.mmdb"
     */
    public static final String GEOIP_FILE_PROPERTY = "org.datavec.geoip.file";

    private static File database;
    private static DatabaseReader reader;

    public final static String DEFAULT_DELIMITER = ":";
    protected String delimiter = DEFAULT_DELIMITER;
    protected LocationType locationType;

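    // Resolve the database file and open the DatabaseReader once; both are held
    // in static fields so all instances share a single reader.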
    private static synchronized void init() throws IOException {
        // A File object pointing to your GeoIP2 or GeoLite2 database:
        // http://dev.maxmind.com/geoip/geoip2/geolite2/
        if (database == null) {
            String s = System.getProperty(GEOIP_FILE_PROPERTY);
            if (s != null && !s.isEmpty()) {
                //Use user-specified GEOIP file - mainly for testing purposes
                File f = new File(s);
                if (f.exists() && f.isFile()) {
                    database = f;
                } else {
                    log.warn("GeoIP file (system property {}) is set to \"{}\" but this is not a valid file, using default database", GEOIP_FILE_PROPERTY, s);
                    database = GeoIPFetcher.fetchCityDB();
                }
            } else {
                database = GeoIPFetcher.fetchCityDB();
            }
        }

        // This creates the DatabaseReader object, which should be reused across lookups.
        if (reader == null) {
            reader = new DatabaseReader.Builder(database).build();
        }
    }

    public IPAddressToLocationTransform(String columnName) throws IOException {
        this(columnName, LocationType.CITY);
    }

    public IPAddressToLocationTransform(String columnName, LocationType locationType) throws IOException {
        this(columnName, locationType, DEFAULT_DELIMITER);
    }

    public IPAddressToLocationTransform(@JsonProperty("columnName") String columnName,
                    @JsonProperty("locationType") LocationType locationType, @JsonProperty("delimiter") String delimiter)
                    throws IOException {
        super(columnName);
        this.delimiter = delimiter;
        this.locationType = locationType;
        init();
    }

    @Override
    public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
        return new StringMetaData(newName); //Output after transform: String (Text)
    }

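    // Look up the IP address in the GeoIP database and emit the requested
    // location field; SUBDIVISIONS values are joined with the delimiter.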
    @Override
    public Writable map(Writable columnWritable) {
        try {
            InetAddress ipAddress = InetAddress.getByName(columnWritable.toString());
            CityResponse response = reader.city(ipAddress);
            String text = "";
            switch (locationType) {
                case CITY:
                    text = response.getCity().getName();
                    break;
                case CITY_ID:
                    text = response.getCity().getGeoNameId().toString();
                    break;
                case CONTINENT:
                    text = response.getContinent().getName();
                    break;
                case CONTINENT_ID:
                    text = response.getContinent().getGeoNameId().toString();
                    break;
                case COUNTRY:
                    text = response.getCountry().getName();
                    break;
                case COUNTRY_ID:
                    text = response.getCountry().getGeoNameId().toString();
                    break;
                case COORDINATES:
                    Location location = response.getLocation();
                    text = location.getLatitude() + delimiter + location.getLongitude();
                    break;
                case POSTAL_CODE:
                    text = response.getPostal().getCode();
                    break;
                case SUBDIVISIONS:
                    for (Subdivision s : response.getSubdivisions()) {
                        if (text.length() > 0) {
                            text += delimiter;
                        }
                        text += s.getName();
                    }
                    break;
                case SUBDIVISIONS_ID:
                    for (Subdivision s : response.getSubdivisions()) {
                        if (text.length() > 0) {
                            text += delimiter;
                        }
                        text += s.getGeoNameId().toString();
                    }
                    break;
                default:
                    assert false;
            }
            if (text == null)
                text = "";
            return new Text(text);
        } catch (GeoIp2Exception | IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public String toString() {
        return "IPAddressToLocationTransform";
    }

    //Custom serialization methods, because GeoIP2 doesn't allow DatabaseReader objects to be serialized :(
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        init();
    }

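    // Plain-Object (non-Writable) mapping is not supported by this transform;
    // it always returns null.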
    @Override
    public Object map(Object input) {
        return null;
    }
}

@@ -1,46 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.datavec.api.transform;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;

import java.util.*;

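// Guard test: asserts that every test class under org.datavec.api.transform
// extends BaseND4JTest (so it picks up the shared logging and timeout handling).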
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {

    @Override
    protected Set<Class<?>> getExclusions() {
        //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
        return new HashSet<>();
    }

    @Override
    protected String getPackageName() {
        return "org.datavec.api.transform";
    }

    @Override
    protected Class<?> getBaseClass() {
        return BaseND4JTest.class;
    }
}

@@ -1,80 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.api.transform.reduce;

import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.ReduceOp;
import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.reduce.geo.CoordinatesReduction;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.junit.Assert.assertEquals;

/**
 * @author saudet
 */
public class TestGeoReduction {

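    // Coordinates are encoded as "lat#long"; the custom CoordinatesReduction
    // sums each dimension independently, so the four points reduce to "10.0#26.0".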
    @Test
    public void testCustomReductions() {

        List<List<Writable>> inputs = new ArrayList<>();
        inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("1#5")));
        inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("2#6")));
        inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("3#7")));
        inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("4#8")));

        List<Writable> expected = Arrays.asList((Writable) new Text("someKey"), new Text("10.0#26.0"));

        Schema schema = new Schema.Builder().addColumnString("key").addColumnString("coord").build();

        Reducer reducer = new Reducer.Builder(ReduceOp.Count).keyColumns("key")
                        .customReduction("coord", new CoordinatesReduction("coordSum", ReduceOp.Sum, "#")).build();

        reducer.setInputSchema(schema);

        IAggregableReduceOp<List<Writable>, List<Writable>> aggregableReduceOp = reducer.aggregableReducer();
        for (List<Writable> l : inputs)
            aggregableReduceOp.accept(l);
        List<Writable> out = aggregableReduceOp.get();

        assertEquals(2, out.size());
        assertEquals(expected, out);

        //Check schema:
        String[] expNames = new String[] {"key", "coordSum"};
        ColumnType[] expTypes = new ColumnType[] {ColumnType.String, ColumnType.String};
        Schema outSchema = reducer.transform(schema);

        assertEquals(2, outSchema.numColumns());
        for (int i = 0; i < 2; i++) {
            assertEquals(expNames[i], outSchema.getName(i));
            assertEquals(expTypes[i], outSchema.getType(i));
        }
    }
}

@@ -1,153 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.api.transform.transform;

import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.geo.LocationType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform;
import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform;
import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.common.io.ClassPathResource;

import java.io.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.junit.Assert.assertEquals;

/**
 * @author saudet
 */
public class TestGeoTransforms {

    @BeforeClass
    public static void beforeClass() throws Exception {
        //Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
        File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
        System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath());
    }

    @AfterClass
    public static void afterClass() {
        System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, "");
    }

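    // For the second case: ((50-10)/10)^2 + ((40-(-20))/5)^2 = 16 + 144 = 160,
    // so the expected stdev-normalized distance is sqrt(160).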
    @Test
    public void testCoordinatesDistanceTransform() throws Exception {
        Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
                        .build();

        Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
        transform.setInputSchema(schema);

        Schema out = transform.transform(schema);
        assertEquals(4, out.numColumns());
        assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
        assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
                        out.getColumnTypes());

        assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
                        transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
        assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
                        new DoubleWritable(Math.sqrt(160))),
                        transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
                                        new Text("10|5"))));
    }

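    // In the bundled test database, 81.2.69.160 resolves to London
    // (latitude 51.5142, longitude -0.0931).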
    @Test
    public void testIPAddressToCoordinatesTransform() throws Exception {
        Schema schema = new Schema.Builder().addColumnString("column").build();

        Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER");
        transform.setInputSchema(schema);

        Schema out = transform.transform(schema);

        assertEquals(1, out.getColumnMetaData().size());
        assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());

        String in = "81.2.69.160";
        double latitude = 51.5142;
        double longitude = -0.0931;

        List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
        assertEquals(1, writables.size());
        String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
        assertEquals(2, coordinates.length);
        assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
        assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);

        //Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ObjectOutputStream oos = new ObjectOutputStream(baos);
        oos.writeObject(transform);

        byte[] bytes = baos.toByteArray();

        ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
        ObjectInputStream ois = new ObjectInputStream(bais);

        Transform deserialized = (Transform) ois.readObject();
        writables = deserialized.map(Collections.singletonList((Writable) new Text(in)));
        assertEquals(1, writables.size());
        coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
        //System.out.println(Arrays.toString(coordinates));
        assertEquals(2, coordinates.length);
        assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
        assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
    }

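    // One expected value per LocationType, in declaration order; the empty string
    // corresponds to POSTAL_CODE (no postcode for this record in the test DB).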
    @Test
    public void testIPAddressToLocationTransform() throws Exception {
        Schema schema = new Schema.Builder().addColumnString("column").build();
        LocationType[] locationTypes = LocationType.values();
        String in = "81.2.69.160";
        String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167",
                        "51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record

        for (int i = 0; i < locationTypes.length; i++) {
            LocationType locationType = locationTypes[i];
            String location = locations[i];

            Transform transform = new IPAddressToLocationTransform("column", locationType);
            transform.setInputSchema(schema);

            Schema out = transform.transform(schema);

            assertEquals(1, out.getColumnMetaData().size());
            assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());

            List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
            assertEquals(1, writables.size());
            assertEquals(location, writables.get(0).toString());
            //System.out.println(location);
        }
    }
}

@@ -37,11 +37,7 @@
     <name>datavec-data</name>

     <modules>
-        <module>datavec-data-audio</module>
-        <module>datavec-data-codec</module>
         <module>datavec-data-image</module>
-        <module>datavec-data-nlp</module>
-        <module>datavec-geo</module>
     </modules>

     <dependencyManagement>