Remove datavec audio/nlp

master
agibsonccc 2021-03-06 08:52:27 +09:00
parent ee06fdd16f
commit 8e8a5ec369
98 changed files with 0 additions and 10355 deletions

View File

@ -1,77 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.datavec</groupId>
<artifactId>datavec-data</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>datavec-data-audio</artifactId>
<name>datavec-data-audio</name>
<dependencies>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>javacpp</artifactId>
<version>${javacpp.version}</version>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>javacv</artifactId>
<version>${javacv.version}</version>
</dependency>
<dependency>
<groupId>com.github.wendykierp</groupId>
<artifactId>JTransforms</artifactId>
<version>${jtransforms.version}</version>
<classifier>with-dependencies</classifier>
</dependency>
<!-- Do not depend on FFmpeg by default due to licensing concerns. -->
<!--
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>ffmpeg-platform</artifactId>
<version>${ffmpeg.version}-${javacpp-presets.version}</version>
</dependency>
-->
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

View File

@ -1,329 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio;
import org.datavec.audio.extension.NormalizedSampleAmplitudes;
import org.datavec.audio.extension.Spectrogram;
import org.datavec.audio.fingerprint.FingerprintManager;
import org.datavec.audio.fingerprint.FingerprintSimilarity;
import org.datavec.audio.fingerprint.FingerprintSimilarityComputer;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
public class Wave implements Serializable {
private static final long serialVersionUID = 1L;
private WaveHeader waveHeader;
private byte[] data; // little endian
private byte[] fingerprint;
/**
* Constructor
*
*/
public Wave() {
this.waveHeader = new WaveHeader();
this.data = new byte[0];
}
/**
* Constructor
*
* @param filename
* Wave file
*/
public Wave(String filename) {
try {
InputStream inputStream = new FileInputStream(filename);
initWaveWithInputStream(inputStream);
inputStream.close();
} catch (IOException e) {
System.err.println(e.toString());
}
}
/**
* Constructor
*
* @param inputStream
* Wave file input stream
*/
public Wave(InputStream inputStream) {
initWaveWithInputStream(inputStream);
}
/**
* Constructor
*
* @param waveHeader
* @param data
*/
public Wave(WaveHeader waveHeader, byte[] data) {
this.waveHeader = waveHeader;
this.data = data;
}
private void initWaveWithInputStream(InputStream inputStream) {
// read the first 44 bytes of the header
waveHeader = new WaveHeader(inputStream);
if (waveHeader.isValid()) {
// load data
try {
data = new byte[inputStream.available()];
inputStream.read(data);
} catch (IOException e) {
System.err.println(e.toString());
}
// end load data
} else {
System.err.println("Invalid Wave Header");
}
}
/**
* Trim the wave data
*
* @param leftTrimNumberOfSample
* Number of samples trimmed from the beginning
* @param rightTrimNumberOfSample
* Number of samples trimmed from the end
*/
public void trim(int leftTrimNumberOfSample, int rightTrimNumberOfSample) {
long chunkSize = waveHeader.getChunkSize();
long subChunk2Size = waveHeader.getSubChunk2Size();
long totalTrimmed = leftTrimNumberOfSample + rightTrimNumberOfSample;
if (totalTrimmed > subChunk2Size) {
// clamp the trim to the available data so the size update below cannot go negative
leftTrimNumberOfSample = (int) subChunk2Size;
rightTrimNumberOfSample = 0;
totalTrimmed = subChunk2Size;
}
// update wav info
chunkSize -= totalTrimmed;
subChunk2Size -= totalTrimmed;
if (chunkSize >= 0 && subChunk2Size >= 0) {
waveHeader.setChunkSize(chunkSize);
waveHeader.setSubChunk2Size(subChunk2Size);
byte[] trimmedData = new byte[(int) subChunk2Size];
System.arraycopy(data, (int) leftTrimNumberOfSample, trimmedData, 0, (int) subChunk2Size);
data = trimmedData;
} else {
System.err.println("Trim error: Negative length");
}
}
/**
* Trim the wave data from the beginning
*
* @param numberOfSample
* Number of samples trimmed from the beginning
*/
public void leftTrim(int numberOfSample) {
trim(numberOfSample, 0);
}
/**
* Trim the wave data from the end
*
* @param numberOfSample
* Number of samples trimmed from the end
*/
public void rightTrim(int numberOfSample) {
trim(0, numberOfSample);
}
/**
* Trim the wave data
*
* @param leftTrimSecond
* Seconds trimmed from the beginning
* @param rightTrimSecond
* Seconds trimmed from the end
*/
public void trim(double leftTrimSecond, double rightTrimSecond) {
int sampleRate = waveHeader.getSampleRate();
int bitsPerSample = waveHeader.getBitsPerSample();
int channels = waveHeader.getChannels();
int leftTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * leftTrimSecond);
int rightTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * rightTrimSecond);
trim(leftTrimNumberOfSample, rightTrimNumberOfSample);
}
/**
* Trim the wave data from the beginning
*
* @param second
* Seconds trimmed from the beginning
*/
public void leftTrim(double second) {
trim(second, 0);
}
/**
* Trim the wave data from the end
*
* @param second
* Seconds trimmed from the end
*/
public void rightTrim(double second) {
trim(0, second);
}
/**
* Get the wave header
*
* @return waveHeader
*/
public WaveHeader getWaveHeader() {
return waveHeader;
}
/**
* Get the wave spectrogram
*
* @return spectrogram
*/
public Spectrogram getSpectrogram() {
return new Spectrogram(this);
}
/**
* Get the wave spectrogram
*
* @param fftSampleSize number of samples per FFT; the value must be a power of 2
* @param overlapFactor 1/overlapFactor overlap, e.g. 4 gives 1/4 = 25% overlap; 0 for no overlap
*
* @return spectrogram
*/
public Spectrogram getSpectrogram(int fftSampleSize, int overlapFactor) {
return new Spectrogram(this, fftSampleSize, overlapFactor);
}
/**
* Get the wave data in bytes
*
* @return wave data
*/
public byte[] getBytes() {
return data;
}
/**
* Size of the wave data in bytes, excluding the header
*
* @return byte size of the wave
*/
public int size() {
return data.length;
}
/**
* Length of the wave in seconds
*
* @return length in seconds
*/
public float length() {
return (float) waveHeader.getSubChunk2Size() / waveHeader.getByteRate();
}
/**
* Length of the wave formatted as a timestamp (h:m:s)
*
* @return timestamp
*/
public String timestamp() {
float totalSeconds = this.length();
float second = totalSeconds % 60;
int minute = (int) totalSeconds / 60 % 60;
int hour = (int) (totalSeconds / 3600);
StringBuilder sb = new StringBuilder();
if (hour > 0) {
sb.append(hour + ":");
}
if (minute > 0) {
sb.append(minute + ":");
}
sb.append(second);
return sb.toString();
}
/**
* Get the amplitudes of the wave samples (depends on the header)
*
* @return amplitudes array (signed 16-bit)
*/
public short[] getSampleAmplitudes() {
int bytePerSample = waveHeader.getBitsPerSample() / 8;
int numSamples = data.length / bytePerSample;
short[] amplitudes = new short[numSamples];
int pointer = 0;
for (int i = 0; i < numSamples; i++) {
short amplitude = 0;
for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) {
// little endian
amplitude |= (short) ((data[pointer++] & 0xFF) << (byteNumber * 8));
}
amplitudes[i] = amplitude;
}
return amplitudes;
}
public String toString() {
StringBuilder sb = new StringBuilder(waveHeader.toString());
sb.append("\n");
sb.append("length: " + timestamp());
return sb.toString();
}
public double[] getNormalizedAmplitudes() {
NormalizedSampleAmplitudes amplitudes = new NormalizedSampleAmplitudes(this);
return amplitudes.getNormalizedAmplitudes();
}
public byte[] getFingerprint() {
if (fingerprint == null) {
FingerprintManager fingerprintManager = new FingerprintManager();
fingerprint = fingerprintManager.extractFingerprint(this);
}
return fingerprint;
}
public FingerprintSimilarity getFingerprintSimilarity(Wave wave) {
FingerprintSimilarityComputer fingerprintSimilarityComputer =
new FingerprintSimilarityComputer(this.getFingerprint(), wave.getFingerprint());
return fingerprintSimilarityComputer.getFingerprintsSimilarity();
}
}
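
For reference, a minimal usage sketch of the Wave API above; the file path is hypothetical and the input must be a PCM WAV that WaveHeader accepts (format 1, 8- or 16-bit):

import org.datavec.audio.Wave;

public class WaveExample {
    public static void main(String[] args) {
        Wave wave = new Wave("/tmp/example.wav"); // hypothetical input file
        System.out.println(wave.getWaveHeader());
        System.out.println("length in seconds: " + wave.length());
        short[] amplitudes = wave.getSampleAmplitudes(); // samples decoded from the little-endian data chunk
        System.out.println("number of samples: " + amplitudes.length);
        wave.leftTrim(0.5); // drop the first half second
        System.out.println("self similarity: " + wave.getFingerprintSimilarity(wave).getSimilarity());
    }
}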

View File

@ -1,98 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.FileOutputStream;
import java.io.IOException;
@Slf4j
public class WaveFileManager {
private Wave wave;
public WaveFileManager() {
wave = new Wave();
}
public WaveFileManager(Wave wave) {
setWave(wave);
}
/**
* Save the wave file
*
* @param filename
* filename to save the wave to
*/
public void saveWaveAsFile(String filename) {
WaveHeader waveHeader = wave.getWaveHeader();
int byteRate = waveHeader.getByteRate();
int audioFormat = waveHeader.getAudioFormat();
int sampleRate = waveHeader.getSampleRate();
int bitsPerSample = waveHeader.getBitsPerSample();
int channels = waveHeader.getChannels();
long chunkSize = waveHeader.getChunkSize();
long subChunk1Size = waveHeader.getSubChunk1Size();
long subChunk2Size = waveHeader.getSubChunk2Size();
int blockAlign = waveHeader.getBlockAlign();
try {
FileOutputStream fos = new FileOutputStream(filename);
fos.write(WaveHeader.RIFF_HEADER.getBytes());
// little endian
fos.write(new byte[] {(byte) (chunkSize), (byte) (chunkSize >> 8), (byte) (chunkSize >> 16),
(byte) (chunkSize >> 24)});
fos.write(WaveHeader.WAVE_HEADER.getBytes());
fos.write(WaveHeader.FMT_HEADER.getBytes());
fos.write(new byte[] {(byte) (subChunk1Size), (byte) (subChunk1Size >> 8), (byte) (subChunk1Size >> 16),
(byte) (subChunk1Size >> 24)});
fos.write(new byte[] {(byte) (audioFormat), (byte) (audioFormat >> 8)});
fos.write(new byte[] {(byte) (channels), (byte) (channels >> 8)});
fos.write(new byte[] {(byte) (sampleRate), (byte) (sampleRate >> 8), (byte) (sampleRate >> 16),
(byte) (sampleRate >> 24)});
fos.write(new byte[] {(byte) (byteRate), (byte) (byteRate >> 8), (byte) (byteRate >> 16),
(byte) (byteRate >> 24)});
fos.write(new byte[] {(byte) (blockAlign), (byte) (blockAlign >> 8)});
fos.write(new byte[] {(byte) (bitsPerSample), (byte) (bitsPerSample >> 8)});
fos.write(WaveHeader.DATA_HEADER.getBytes());
fos.write(new byte[] {(byte) (subChunk2Size), (byte) (subChunk2Size >> 8), (byte) (subChunk2Size >> 16),
(byte) (subChunk2Size >> 24)});
fos.write(wave.getBytes());
fos.close();
} catch (IOException e) {
log.error("",e);
}
}
public Wave getWave() {
return wave;
}
public void setWave(Wave wave) {
this.wave = wave;
}
}
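
A short sketch of how this class is typically driven, with hypothetical file names: trim a wave and write the result back out.

import org.datavec.audio.Wave;
import org.datavec.audio.WaveFileManager;

public class SaveExample {
    public static void main(String[] args) {
        Wave wave = new Wave("/tmp/in.wav"); // hypothetical input
        wave.rightTrim(1.0);                 // drop the last second
        // saveWaveAsFile re-emits the 44-byte header little-endian, then the data chunk
        new WaveFileManager(wave).saveWaveAsFile("/tmp/out.wav");
    }
}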

View File

@ -1,281 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.IOException;
import java.io.InputStream;
@Slf4j
public class WaveHeader {
public static final String RIFF_HEADER = "RIFF";
public static final String WAVE_HEADER = "WAVE";
public static final String FMT_HEADER = "fmt ";
public static final String DATA_HEADER = "data";
public static final int HEADER_BYTE_LENGTH = 44; // 44 bytes for header
private boolean valid;
private String chunkId; // 4 bytes
private long chunkSize; // unsigned 4 bytes, little endian
private String format; // 4 bytes
private String subChunk1Id; // 4 bytes
private long subChunk1Size; // unsigned 4 bytes, little endian
private int audioFormat; // unsigned 2 bytes, little endian
private int channels; // unsigned 2 bytes, little endian
private long sampleRate; // unsigned 4 bytes, little endian
private long byteRate; // unsigned 4 bytes, little endian
private int blockAlign; // unsigned 2 bytes, little endian
private int bitsPerSample; // unsigned 2 bytes, little endian
private String subChunk2Id; // 4 bytes
private long subChunk2Size; // unsigned 4 bytes, little endian
public WaveHeader() {
// initialize an 8 kHz, 16-bit mono WAV header
chunkSize = 36;
subChunk1Size = 16;
audioFormat = 1;
channels = 1;
sampleRate = 8000;
byteRate = 16000;
blockAlign = 2;
bitsPerSample = 16;
subChunk2Size = 0;
valid = true;
}
public WaveHeader(InputStream inputStream) {
valid = loadHeader(inputStream);
}
private boolean loadHeader(InputStream inputStream) {
byte[] headerBuffer = new byte[HEADER_BYTE_LENGTH];
try {
inputStream.read(headerBuffer);
// read header
int pointer = 0;
chunkId = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++],
headerBuffer[pointer++]});
// little endian
chunkSize = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
| (long) (headerBuffer[pointer++] & 0xff) << 16
| (long) (headerBuffer[pointer++] & 0xff) << 24; // mask must be applied before the shift
format = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++],
headerBuffer[pointer++]});
subChunk1Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++],
headerBuffer[pointer++], headerBuffer[pointer++]});
subChunk1Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
| (long) (headerBuffer[pointer++] & 0xff) << 16
| (long) (headerBuffer[pointer++] & 0xff) << 24;
audioFormat = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
channels = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
sampleRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
| (long) (headerBuffer[pointer++] & 0xff) << 16
| (long) (headerBuffer[pointer++] & 0xff) << 24;
byteRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
| (long) (headerBuffer[pointer++] & 0xff) << 16
| (long) (headerBuffer[pointer++] & 0xff) << 24;
blockAlign = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
bitsPerSample = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
subChunk2Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++],
headerBuffer[pointer++], headerBuffer[pointer++]});
subChunk2Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
| (long) (headerBuffer[pointer++] & 0xff) << 16
| (long) (headerBuffer[pointer++] & 0xff) << 24;
// end read header
// the inputStream should be closed outside this method
// dis.close();
} catch (IOException e) {
log.error("",e);
return false;
}
if (bitsPerSample != 8 && bitsPerSample != 16) {
System.err.println("WaveHeader: only supports bitsPerSample 8 or 16");
return false;
}
// check that the format is supported
if (chunkId.equalsIgnoreCase(RIFF_HEADER) && format.equalsIgnoreCase(WAVE_HEADER) && audioFormat == 1) {
return true;
} else {
System.err.println("WaveHeader: Unsupported header format");
}
return false;
}
public boolean isValid() {
return valid;
}
public String getChunkId() {
return chunkId;
}
public long getChunkSize() {
return chunkSize;
}
public String getFormat() {
return format;
}
public String getSubChunk1Id() {
return subChunk1Id;
}
public long getSubChunk1Size() {
return subChunk1Size;
}
public int getAudioFormat() {
return audioFormat;
}
public int getChannels() {
return channels;
}
public int getSampleRate() {
return (int) sampleRate;
}
public int getByteRate() {
return (int) byteRate;
}
public int getBlockAlign() {
return blockAlign;
}
public int getBitsPerSample() {
return bitsPerSample;
}
public String getSubChunk2Id() {
return subChunk2Id;
}
public long getSubChunk2Size() {
return subChunk2Size;
}
public void setSampleRate(int sampleRate) {
int newSubChunk2Size = (int) (this.subChunk2Size * sampleRate / this.sampleRate);
// if the number of bytes per sample is even, newSubChunk2Size must also be even
if ((bitsPerSample / 8) % 2 == 0) {
if (newSubChunk2Size % 2 != 0) {
newSubChunk2Size++;
}
}
this.sampleRate = sampleRate;
this.byteRate = sampleRate * bitsPerSample / 8;
this.chunkSize = newSubChunk2Size + 36;
this.subChunk2Size = newSubChunk2Size;
}
public void setChunkId(String chunkId) {
this.chunkId = chunkId;
}
public void setChunkSize(long chunkSize) {
this.chunkSize = chunkSize;
}
public void setFormat(String format) {
this.format = format;
}
public void setSubChunk1Id(String subChunk1Id) {
this.subChunk1Id = subChunk1Id;
}
public void setSubChunk1Size(long subChunk1Size) {
this.subChunk1Size = subChunk1Size;
}
public void setAudioFormat(int audioFormat) {
this.audioFormat = audioFormat;
}
public void setChannels(int channels) {
this.channels = channels;
}
public void setByteRate(long byteRate) {
this.byteRate = byteRate;
}
public void setBlockAlign(int blockAlign) {
this.blockAlign = blockAlign;
}
public void setBitsPerSample(int bitsPerSample) {
this.bitsPerSample = bitsPerSample;
}
public void setSubChunk2Id(String subChunk2Id) {
this.subChunk2Id = subChunk2Id;
}
public void setSubChunk2Size(long subChunk2Size) {
this.subChunk2Size = subChunk2Size;
}
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("chunkId: " + chunkId);
sb.append("\n");
sb.append("chunkSize: " + chunkSize);
sb.append("\n");
sb.append("format: " + format);
sb.append("\n");
sb.append("subChunk1Id: " + subChunk1Id);
sb.append("\n");
sb.append("subChunk1Size: " + subChunk1Size);
sb.append("\n");
sb.append("audioFormat: " + audioFormat);
sb.append("\n");
sb.append("channels: " + channels);
sb.append("\n");
sb.append("sampleRate: " + sampleRate);
sb.append("\n");
sb.append("byteRate: " + byteRate);
sb.append("\n");
sb.append("blockAlign: " + blockAlign);
sb.append("\n");
sb.append("bitsPerSample: " + bitsPerSample);
sb.append("\n");
sb.append("subChunk2Id: " + subChunk2Id);
sb.append("\n");
sb.append("subChunk2Size: " + subChunk2Size);
return sb.toString();
}
}
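
The manual shift-and-mask decoding above is byte-for-byte equivalent to reading the header through java.nio.ByteBuffer in little-endian order. A sketch of that alternative, for comparison only (the class above does not use it):

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class HeaderDecodeSketch {
    static void decode(InputStream in) throws IOException {
        byte[] header = new byte[44];
        if (in.read(header) != 44) throw new IOException("short header");
        ByteBuffer buf = ByteBuffer.wrap(header).order(ByteOrder.LITTLE_ENDIAN);
        byte[] tag = new byte[4];
        buf.get(tag);                                // "RIFF"
        long chunkSize = buf.getInt() & 0xFFFFFFFFL; // unsigned 4 bytes, little endian
        buf.get(tag);                                // "WAVE"
        buf.get(tag);                                // "fmt "
        long subChunk1Size = buf.getInt() & 0xFFFFFFFFL;
        int audioFormat = buf.getShort() & 0xFFFF;   // unsigned 2 bytes
        int channels = buf.getShort() & 0xFFFF;
        long sampleRate = buf.getInt() & 0xFFFFFFFFL;
        System.out.println("format=" + audioFormat + " ch=" + channels + " rate=" + sampleRate);
    }
}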

View File

@ -1,81 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.dsp;
import org.jtransforms.fft.DoubleFFT_1D;
public class FastFourierTransform {
/**
* Get the frequency intensities
*
* @param amplitudes amplitudes of the signal. Format depends on value of complex
* @param complex if true, amplitudes are assumed to be complex interleaved (re = even, im = odd); if false, amplitudes
* are assumed to be real valued.
* @return intensities of each frequency unit: mag[frequency_unit]=intensity
*/
public double[] getMagnitudes(double[] amplitudes, boolean complex) {
final int sampleSize = amplitudes.length;
final int nrofFrequencyBins = sampleSize / 2;
// call the fft and transform the complex numbers
if (complex) {
DoubleFFT_1D fft = new DoubleFFT_1D(nrofFrequencyBins);
fft.complexForward(amplitudes);
} else {
DoubleFFT_1D fft = new DoubleFFT_1D(sampleSize);
fft.realForward(amplitudes);
// amplitudes[1] contains re[sampleSize/2] or im[(sampleSize-1) / 2] (depending on whether sampleSize is odd or even)
// Discard it as it is useless without the other part
// im part dc bin is always 0 for real input
amplitudes[1] = 0;
}
// end call the fft and transform the complex numbers
// even indexes (0,2,4,6,...) are real parts
// odd indexes (1,3,5,7,...) are imaginary parts
double[] mag = new double[nrofFrequencyBins];
for (int i = 0; i < nrofFrequencyBins; i++) {
final int f = 2 * i;
mag[i] = Math.sqrt(amplitudes[f] * amplitudes[f] + amplitudes[f + 1] * amplitudes[f + 1]);
}
return mag;
}
/**
* Get the frequency intensities. Backwards compatible with previous versions w.r.t. the number of frequency bins.
* Use getMagnitudes(amplitudes, true) to get all bins.
*
* @param amplitudes complex-valued signal to transform. Even indexes are real and odd indexes are imaginary
* @return intensities of each frequency unit: mag[frequency_unit]=intensity
*/
public double[] getMagnitudes(double[] amplitudes) {
double[] magnitudes = getMagnitudes(amplitudes, true);
double[] halfOfMagnitudes = new double[magnitudes.length/2];
System.arraycopy(magnitudes, 0,halfOfMagnitudes, 0, halfOfMagnitudes.length);
return halfOfMagnitudes;
}
}
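
A small sanity-check sketch, assuming JTransforms is on the classpath as declared in the pom above: feed a pure 1 kHz tone sampled at 8 kHz into getMagnitudes(..., false) and confirm the peak lands in the expected bin; for a real input of sampleSize samples, bin i corresponds to i * sampleRate / sampleSize Hz.

import org.datavec.audio.dsp.FastFourierTransform;

public class FftExample {
    public static void main(String[] args) {
        int sampleRate = 8000, fftSize = 1024;
        double[] signal = new double[fftSize];
        for (int n = 0; n < fftSize; n++) {
            signal[n] = Math.sin(2 * Math.PI * 1000 * n / sampleRate); // 1 kHz tone
        }
        double[] mag = new FastFourierTransform().getMagnitudes(signal, false);
        int peak = 0;
        for (int i = 1; i < mag.length; i++) {
            if (mag[i] > mag[peak]) peak = i;
        }
        // expected peak: 1000 Hz -> bin 1000 * 1024 / 8000 = 128
        System.out.println("peak bin " + peak + " = " + (peak * (double) sampleRate / fftSize) + " Hz");
    }
}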

View File

@ -1,66 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.dsp;
public class LinearInterpolation {
public LinearInterpolation() {
}
/**
* Interpolate the samples from the original sample rate to the destination sample rate
*
* @param oldSampleRate sample rate of the original samples
* @param newSampleRate sample rate of the interpolated samples
* @param samples original samples
* @return interpolated samples
*/
public short[] interpolate(int oldSampleRate, int newSampleRate, short[] samples) {
if (oldSampleRate == newSampleRate) {
return samples;
}
int newLength = Math.round(((float) samples.length / oldSampleRate * newSampleRate));
float lengthMultiplier = (float) newLength / samples.length;
short[] interpolatedSamples = new short[newLength];
// interpolate the value by the linear equation y=mx+c
for (int i = 0; i < newLength; i++) {
// get the nearest positions for the interpolated point
float currentPosition = i / lengthMultiplier;
int nearestLeftPosition = (int) currentPosition;
int nearestRightPosition = nearestLeftPosition + 1;
if (nearestRightPosition >= samples.length) {
nearestRightPosition = samples.length - 1;
}
float slope = samples[nearestRightPosition] - samples[nearestLeftPosition]; // delta x is 1
float positionFromLeft = currentPosition - nearestLeftPosition;
interpolatedSamples[i] = (short) (slope * positionFromLeft + samples[nearestLeftPosition]); // y=mx+c
}
return interpolatedSamples;
}
}
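
A tiny worked example of the method above: doubling the sample rate of a three-sample ramp gives newLength = round(3f / 8000 * 16000) = 6, and every other output sample falls halfway between two inputs, so the slope term contributes half the difference.

import org.datavec.audio.dsp.LinearInterpolation;

public class InterpolationExample {
    public static void main(String[] args) {
        short[] ramp = {0, 100, 200};
        short[] doubled = new LinearInterpolation().interpolate(8000, 16000, ramp);
        // expected: 0 50 100 150 200 200 - the last samples clamp to the array edge
        for (short s : doubled) System.out.print(s + " ");
    }
}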

View File

@ -1,84 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.dsp;
public class Resampler {
public Resampler() {}
/**
* Resample the data. Amplitudes are stored as shorts, so the maximum supported bitsPerSample is 16 (i.e. 2 bytes per sample)
*
* @param sourceData The source data in bytes
* @param bitsPerSample How many bits represents one sample (currently supports max. bitsPerSample=16)
* @param sourceRate Sample rate of the source data
* @param targetRate Sample rate of the target data
* @return re-sampled data
*/
public byte[] reSample(byte[] sourceData, int bitsPerSample, int sourceRate, int targetRate) {
// make the bytes to amplitudes first
int bytePerSample = bitsPerSample / 8;
int numSamples = sourceData.length / bytePerSample;
short[] amplitudes = new short[numSamples]; // 16 bit, use a short to store
int pointer = 0;
for (int i = 0; i < numSamples; i++) {
short amplitude = 0;
for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) {
// little endian
amplitude |= (short) ((sourceData[pointer++] & 0xFF) << (byteNumber * 8));
}
amplitudes[i] = amplitude;
}
// end make the amplitudes
// do interpolation
LinearInterpolation reSample = new LinearInterpolation();
short[] targetSample = reSample.interpolate(sourceRate, targetRate, amplitudes);
int targetLength = targetSample.length;
// end do interpolation
// TODO: Remove the high frequency signals with a digital filter, leaving a signal containing only half-sample-rated frequency information, but still sampled at a rate of target sample rate. Usually FIR is used
// end resample the amplitudes
// convert the amplitude to bytes
byte[] bytes;
if (bytePerSample == 1) {
bytes = new byte[targetLength];
for (int i = 0; i < targetLength; i++) {
bytes[i] = (byte) targetSample[i];
}
} else {
// suppose bytePerSample==2
bytes = new byte[targetLength * 2];
for (int i = 0; i < targetSample.length; i++) {
// little endian
bytes[i * 2] = (byte) (targetSample[i] & 0xff);
bytes[i * 2 + 1] = (byte) ((targetSample[i] >> 8) & 0xff);
}
}
// end convert the amplitude to bytes
return bytes;
}
}
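
A usage sketch of the byte-level interface, with illustrative values: pack two 16-bit samples little-endian and halve the rate.

import org.datavec.audio.dsp.Resampler;

public class ResampleExample {
    public static void main(String[] args) {
        // two 16-bit little-endian samples: 0x0100 = 256 and 0x0200 = 512
        byte[] source = {0x00, 0x01, 0x00, 0x02};
        byte[] out = new Resampler().reSample(source, 16, 16000, 8000);
        // halving the rate keeps round(2f / 16000 * 8000) = 1 sample, i.e. 2 bytes
        System.out.println("resampled bytes: " + out.length);
    }
}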

View File

@ -1,95 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.dsp;
public class WindowFunction {
public static final int RECTANGULAR = 0;
public static final int BARTLETT = 1;
public static final int HANNING = 2;
public static final int HAMMING = 3;
public static final int BLACKMAN = 4;
int windowType = 0; // defaults to rectangular window
public WindowFunction() {}
public void setWindowType(int wt) {
windowType = wt;
}
public void setWindowType(String w) {
if (w.equalsIgnoreCase("RECTANGULAR"))
windowType = RECTANGULAR;
else if (w.equalsIgnoreCase("BARTLETT"))
windowType = BARTLETT;
else if (w.equalsIgnoreCase("HANNING"))
windowType = HANNING;
else if (w.equalsIgnoreCase("HAMMING"))
windowType = HAMMING;
else if (w.equalsIgnoreCase("BLACKMAN"))
windowType = BLACKMAN;
}
public int getWindowType() {
return windowType;
}
/**
* Generate a window
*
* @param nSamples size of the window
* @return window in array
*/
public double[] generate(int nSamples) {
// generate nSamples window function values
// for index values 0 .. nSamples - 1
int m = nSamples / 2;
double r;
double pi = Math.PI;
double[] w = new double[nSamples];
switch (windowType) {
case BARTLETT: // Bartlett (triangular) window
for (int n = 0; n < nSamples; n++)
w[n] = 1.0 - Math.abs(n - m) / (double) m; // cast to avoid integer division
break;
case HANNING: // Hanning window
r = pi / (m + 1);
for (int n = -m; n < m; n++)
w[m + n] = 0.5f + 0.5f * Math.cos(n * r);
break;
case HAMMING: // Hamming window
r = pi / m;
for (int n = -m; n < m; n++)
w[m + n] = 0.54f + 0.46f * Math.cos(n * r);
break;
case BLACKMAN: // Blackman window
r = pi / m;
for (int n = -m; n < m; n++)
w[m + n] = 0.42f + 0.5f * Math.cos(n * r) + 0.08f * Math.cos(2 * n * r);
break;
default: // Rectangular window function
for (int n = 0; n < nSamples; n++)
w[n] = 1.0f;
}
return w;
}
}
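
As a concrete example, the HAMMING branch evaluates w[m + n] = 0.54 + 0.46 * cos(n * PI / m) for n in [-m, m). A sketch printing an 8-point window (m = 4), whose ends sit near 0.08 and whose centre is 1.0:

import org.datavec.audio.dsp.WindowFunction;

public class WindowExample {
    public static void main(String[] args) {
        WindowFunction window = new WindowFunction();
        window.setWindowType("Hamming");
        // expected: 0.080 0.215 0.540 0.865 1.000 0.865 0.540 0.215
        for (double v : window.generate(8)) {
            System.out.printf("%.3f ", v);
        }
    }
}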

View File

@ -1,21 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.dsp;

View File

@ -1,67 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.extension;
import org.datavec.audio.Wave;
public class NormalizedSampleAmplitudes {
private Wave wave;
private double[] normalizedAmplitudes; // normalizedAmplitudes[sampleNumber]=normalizedAmplitudeInTheFrame
public NormalizedSampleAmplitudes(Wave wave) {
this.wave = wave;
}
/**
*
* Get normalized amplitude of each frame
*
* @return array of normalized amplitudes (signed 16-bit): normalizedAmplitudes[frame]=amplitude
*/
public double[] getNormalizedAmplitudes() {
if (normalizedAmplitudes == null) {
boolean signed = true;
// 8-bit samples are usually unsigned
if (wave.getWaveHeader().getBitsPerSample() == 8) {
signed = false;
}
short[] amplitudes = wave.getSampleAmplitudes();
int numSamples = amplitudes.length;
int maxAmplitude = 1 << (wave.getWaveHeader().getBitsPerSample() - 1);
if (!signed) { // one more bit for unsigned value
maxAmplitude <<= 1;
}
normalizedAmplitudes = new double[numSamples];
for (int i = 0; i < numSamples; i++) {
normalizedAmplitudes[i] = (double) amplitudes[i] / maxAmplitude;
}
}
return normalizedAmplitudes;
}
}
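
The normalization itself is a division by the full-scale amplitude: for signed 16-bit audio the divisor is 2^15 = 32768, so a raw sample of 16384 maps to 0.5. A minimal sketch of that arithmetic:

public class NormalizeSketch {
    public static void main(String[] args) {
        int bitsPerSample = 16;
        int maxAmplitude = 1 << (bitsPerSample - 1); // 32768 for signed 16-bit
        short raw = 16384;
        System.out.println((double) raw / maxAmplitude); // prints 0.5
    }
}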

View File

@ -1,214 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.extension;
import org.datavec.audio.Wave;
import org.datavec.audio.dsp.FastFourierTransform;
import org.datavec.audio.dsp.WindowFunction;
public class Spectrogram {
public static final int SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE = 1024;
public static final int SPECTROGRAM_DEFAULT_OVERLAP_FACTOR = 0; // 0 for no overlapping
private Wave wave;
private double[][] spectrogram; // relative spectrogram
private double[][] absoluteSpectrogram; // absolute spectrogram
private int fftSampleSize; // number of samples per FFT; must be a power of 2
private int overlapFactor; // 1/overlapFactor overlapping, e.g. 1/4=25% overlapping
private int numFrames; // number of frames of the spectrogram
private int framesPerSecond; // frame per second of the spectrogram
private int numFrequencyUnit; // number of y-axis unit
private double unitFrequency; // frequency per y-axis unit
/**
* Constructor
*
* @param wave
*/
public Spectrogram(Wave wave) {
this.wave = wave;
// default
this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE;
this.overlapFactor = SPECTROGRAM_DEFAULT_OVERLAP_FACTOR;
buildSpectrogram();
}
/**
* Constructor
*
* @param wave
* @param fftSampleSize number of samples per FFT; the value must be a power of 2
* @param overlapFactor 1/overlapFactor overlap, e.g. 4 gives 1/4 = 25% overlap; 0 for no overlap
*/
public Spectrogram(Wave wave, int fftSampleSize, int overlapFactor) {
this.wave = wave;
if (Integer.bitCount(fftSampleSize) == 1) {
this.fftSampleSize = fftSampleSize;
} else {
System.err.print("The input number must be a power of 2");
this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE;
}
this.overlapFactor = overlapFactor;
buildSpectrogram();
}
/**
* Build spectrogram
*/
private void buildSpectrogram() {
short[] amplitudes = wave.getSampleAmplitudes();
int numSamples = amplitudes.length;
int pointer = 0;
// overlapping
if (overlapFactor > 1) {
int numOverlappedSamples = numSamples * overlapFactor;
int backSamples = fftSampleSize * (overlapFactor - 1) / overlapFactor;
short[] overlapAmp = new short[numOverlappedSamples];
pointer = 0;
for (int i = 0; i < amplitudes.length; i++) {
overlapAmp[pointer++] = amplitudes[i];
if (pointer % fftSampleSize == 0) {
// overlap
i -= backSamples;
}
}
numSamples = numOverlappedSamples;
amplitudes = overlapAmp;
}
// end overlapping
numFrames = numSamples / fftSampleSize;
framesPerSecond = (int) (numFrames / wave.length());
// set signals for fft
WindowFunction window = new WindowFunction();
window.setWindowType("Hamming");
double[] win = window.generate(fftSampleSize);
double[][] signals = new double[numFrames][];
for (int f = 0; f < numFrames; f++) {
signals[f] = new double[fftSampleSize];
int startSample = f * fftSampleSize;
for (int n = 0; n < fftSampleSize; n++) {
signals[f][n] = amplitudes[startSample + n] * win[n];
}
}
// end set signals for fft
absoluteSpectrogram = new double[numFrames][];
// for each frame in signals, do fft on it
FastFourierTransform fft = new FastFourierTransform();
for (int i = 0; i < numFrames; i++) {
absoluteSpectrogram[i] = fft.getMagnitudes(signals[i], false);
}
if (absoluteSpectrogram.length > 0) {
numFrequencyUnit = absoluteSpectrogram[0].length;
unitFrequency = (double) wave.getWaveHeader().getSampleRate() / 2 / numFrequencyUnit; // by the Nyquist theorem, only frequencies up to half the sample rate can be represented
// normalization of absoluteSpectrogram
spectrogram = new double[numFrames][numFrequencyUnit];
// set max and min amplitudes
double maxAmp = Double.MIN_VALUE;
double minAmp = Double.MAX_VALUE;
for (int i = 0; i < numFrames; i++) {
for (int j = 0; j < numFrequencyUnit; j++) {
if (absoluteSpectrogram[i][j] > maxAmp) {
maxAmp = absoluteSpectrogram[i][j];
} else if (absoluteSpectrogram[i][j] < minAmp) {
minAmp = absoluteSpectrogram[i][j];
}
}
}
// end set max and min amplitudes
// normalization
// avoid division by zero
double minValidAmp = 0.00000000001F;
if (minAmp == 0) {
minAmp = minValidAmp;
}
double diff = Math.log10(maxAmp / minAmp); // perceptual difference
for (int i = 0; i < numFrames; i++) {
for (int j = 0; j < numFrequencyUnit; j++) {
if (absoluteSpectrogram[i][j] < minValidAmp) {
spectrogram[i][j] = 0;
} else {
spectrogram[i][j] = (Math.log10(absoluteSpectrogram[i][j] / minAmp)) / diff;
}
}
}
// end normalization
}
}
/**
* Get spectrogram: spectrogram[time][frequency]=intensity
*
* @return logarithm normalized spectrogram
*/
public double[][] getNormalizedSpectrogramData() {
return spectrogram;
}
/**
* Get spectrogram: spectrogram[time][frequency]=intensity
*
* @return absolute spectrogram
*/
public double[][] getAbsoluteSpectrogramData() {
return absoluteSpectrogram;
}
public int getNumFrames() {
return numFrames;
}
public int getFramesPerSecond() {
return framesPerSecond;
}
public int getNumFrequencyUnit() {
return numFrequencyUnit;
}
public double getUnitFrequency() {
return unitFrequency;
}
public int getFftSampleSize() {
return fftSampleSize;
}
public int getOverlapFactor() {
return overlapFactor;
}
}
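
A usage sketch with a hypothetical file name: build a spectrogram with 1024-sample frames and 1/4 overlap, then read off the normalized intensities, where spectrogram[time][frequency] lies in [0, 1].

import org.datavec.audio.Wave;
import org.datavec.audio.extension.Spectrogram;

public class SpectrogramExample {
    public static void main(String[] args) {
        Wave wave = new Wave("/tmp/example.wav");        // hypothetical input
        Spectrogram spec = wave.getSpectrogram(1024, 4); // fftSampleSize=1024, 25% overlap
        double[][] data = spec.getNormalizedSpectrogramData();
        System.out.println("frames: " + spec.getNumFrames()
                + ", bins: " + spec.getNumFrequencyUnit()
                + ", Hz per bin: " + spec.getUnitFrequency());
        System.out.println("intensity at frame 0, bin 0: " + data[0][0]);
    }
}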

View File

@ -1,272 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import lombok.extern.slf4j.Slf4j;
import org.datavec.audio.Wave;
import org.datavec.audio.WaveHeader;
import org.datavec.audio.dsp.Resampler;
import org.datavec.audio.extension.Spectrogram;
import org.datavec.audio.processor.TopManyPointsProcessorChain;
import org.datavec.audio.properties.FingerprintProperties;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
@Slf4j
public class FingerprintManager {
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
private int sampleSizePerFrame = fingerprintProperties.getSampleSizePerFrame();
private int overlapFactor = fingerprintProperties.getOverlapFactor();
private int numRobustPointsPerFrame = fingerprintProperties.getNumRobustPointsPerFrame();
private int numFilterBanks = fingerprintProperties.getNumFilterBanks();
/**
* Constructor
*/
public FingerprintManager() {
}
/**
* Extract fingerprint from Wave object
*
* @param wave Wave object from which to extract the fingerprint
* @return fingerprint in bytes
*/
public byte[] extractFingerprint(Wave wave) {
int[][] coordinates; // coordinates[x][0..3]=y0..y3
byte[] fingerprint = new byte[0];
// resample to target rate
Resampler resampler = new Resampler();
int sourceRate = wave.getWaveHeader().getSampleRate();
int targetRate = fingerprintProperties.getSampleRate();
byte[] resampledWaveData = resampler.reSample(wave.getBytes(), wave.getWaveHeader().getBitsPerSample(),
sourceRate, targetRate);
// update the wave header
WaveHeader resampledWaveHeader = wave.getWaveHeader();
resampledWaveHeader.setSampleRate(targetRate);
// make resampled wave
Wave resampledWave = new Wave(resampledWaveHeader, resampledWaveData);
// end resample to target rate
// get spectrogram's data
Spectrogram spectrogram = resampledWave.getSpectrogram(sampleSizePerFrame, overlapFactor);
double[][] spectrogramData = spectrogram.getNormalizedSpectrogramData();
List<Integer>[] pointsLists = getRobustPointList(spectrogramData);
int numFrames = pointsLists.length;
// prepare fingerprint bytes
coordinates = new int[numFrames][numRobustPointsPerFrame];
for (int x = 0; x < numFrames; x++) {
if (pointsLists[x].size() == numRobustPointsPerFrame) {
Iterator<Integer> pointsListsIterator = pointsLists[x].iterator();
for (int y = 0; y < numRobustPointsPerFrame; y++) {
coordinates[x][y] = pointsListsIterator.next();
}
} else {
// use -1 to fill the empty slots
for (int y = 0; y < numRobustPointsPerFrame; y++) {
coordinates[x][y] = -1;
}
}
}
// end make fingerprint
// for each valid coordinate, append with its intensity
List<Byte> byteList = new LinkedList<Byte>();
for (int i = 0; i < numFrames; i++) {
for (int j = 0; j < numRobustPointsPerFrame; j++) {
if (coordinates[i][j] != -1) {
// first 2 bytes are x
byteList.add((byte) (i >> 8));
byteList.add((byte) i);
// next 2 bytes are y
int y = coordinates[i][j];
byteList.add((byte) (y >> 8));
byteList.add((byte) y);
// next 4 bytes are the intensity
int intensity = (int) (spectrogramData[i][y] * Integer.MAX_VALUE); // spectrogramData ranges from 0 to 1
byteList.add((byte) (intensity >> 24));
byteList.add((byte) (intensity >> 16));
byteList.add((byte) (intensity >> 8));
byteList.add((byte) intensity);
}
}
}
// end for each valid coordinate, append with its intensity
fingerprint = new byte[byteList.size()];
Iterator<Byte> byteListIterator = byteList.iterator();
int pointer = 0;
while (byteListIterator.hasNext()) {
fingerprint[pointer++] = byteListIterator.next();
}
return fingerprint;
}
/**
* Get bytes from fingerprint file
*
* @param fingerprintFile fingerprint filename
* @return fingerprint in bytes
*/
public byte[] getFingerprintFromFile(String fingerprintFile) {
byte[] fingerprint = null;
try {
InputStream fis = new FileInputStream(fingerprintFile);
fingerprint = getFingerprintFromInputStream(fis);
fis.close();
} catch (IOException e) {
log.error("",e);
}
return fingerprint;
}
/**
* Get bytes from fingerprint inputstream
*
* @param inputStream fingerprint inputstream
* @return fingerprint in bytes
*/
public byte[] getFingerprintFromInputStream(InputStream inputStream) {
byte[] fingerprint = null;
try {
fingerprint = new byte[inputStream.available()];
inputStream.read(fingerprint);
} catch (IOException e) {
log.error("",e);
}
return fingerprint;
}
/**
* Save fingerprint to a file
*
* @param fingerprint fingerprint bytes
* @param filename fingerprint filename
*/
public void saveFingerprintAsFile(byte[] fingerprint, String filename) {
FileOutputStream fileOutputStream;
try {
fileOutputStream = new FileOutputStream(filename);
fileOutputStream.write(fingerprint);
fileOutputStream.close();
} catch (IOException e) {
log.error("",e);
}
}
// robustLists[x]=y1,y2,y3,...
private List<Integer>[] getRobustPointList(double[][] spectrogramData) {
int numX = spectrogramData.length;
int numY = spectrogramData[0].length;
double[][] allBanksIntensities = new double[numX][numY];
int bandwidthPerBank = numY / numFilterBanks;
for (int b = 0; b < numFilterBanks; b++) {
double[][] bankIntensities = new double[numX][bandwidthPerBank];
for (int i = 0; i < numX; i++) {
System.arraycopy(spectrogramData[i], b * bandwidthPerBank, bankIntensities[i], 0, bandwidthPerBank);
}
// get the most robust point in each filter bank
TopManyPointsProcessorChain processorChain = new TopManyPointsProcessorChain(bankIntensities, 1);
double[][] processedIntensities = processorChain.getIntensities();
for (int i = 0; i < numX; i++) {
System.arraycopy(processedIntensities[i], 0, allBanksIntensities[i], b * bandwidthPerBank,
bandwidthPerBank);
}
}
List<int[]> robustPointList = new LinkedList<int[]>();
// find robust points
for (int i = 0; i < allBanksIntensities.length; i++) {
for (int j = 0; j < allBanksIntensities[i].length; j++) {
if (allBanksIntensities[i][j] > 0) {
int[] point = new int[] {i, j};
//System.out.println(i+","+frequency);
robustPointList.add(point);
}
}
}
// end find robust points
List<Integer>[] robustLists = new LinkedList[spectrogramData.length];
for (int i = 0; i < robustLists.length; i++) {
robustLists[i] = new LinkedList<>();
}
// robustLists[x]=y1,y2,y3,...
for (int[] coor : robustPointList) {
robustLists[coor[0]].add(coor[1]);
}
// return the list per frame
return robustLists;
}
/**
* Number of frames in a fingerprint
* Each frame is 8 bytes long
* Usually there is more than one point per frame, so the frame count cannot be obtained by simply dividing the byte length by 8
* The last 8 bytes of the fingerprint encode the last frame of the wave
* The first 2 of those 8 bytes are the x position, i.e. (number_of_frames - 1) of this wave
*
* @param fingerprint fingerprint bytes
* @return number of frames of the fingerprint
*/
public static int getNumFrames(byte[] fingerprint) {
if (fingerprint.length < 8) {
return 0;
}
// read the last x-coordinate from bytes (length-8) and (length-7) of the fingerprint
return ((fingerprint[fingerprint.length - 8] & 0xff) << 8 | (fingerprint[fingerprint.length - 7] & 0xff)) + 1;
}
}
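
extractFingerprint above serializes each robust point as 8 bytes: 2 for the frame index x, 2 for the frequency bin y, and 4 for the scaled intensity. A sketch that walks that layout back out, mirroring the encoding (this decoder is not an API the class provides):

public class FingerprintDecodeSketch {
    // decode the 8-byte records produced by FingerprintManager.extractFingerprint
    static void dump(byte[] fingerprint) {
        for (int p = 0; p + 8 <= fingerprint.length; p += 8) {
            int x = (fingerprint[p] & 0xff) << 8 | (fingerprint[p + 1] & 0xff);
            int y = (fingerprint[p + 2] & 0xff) << 8 | (fingerprint[p + 3] & 0xff);
            int intensity = (fingerprint[p + 4] & 0xff) << 24 | (fingerprint[p + 5] & 0xff) << 16
                    | (fingerprint[p + 6] & 0xff) << 8 | (fingerprint[p + 7] & 0xff);
            System.out.println("frame " + x + ", bin " + y + ", intensity " + intensity);
        }
    }
}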

View File

@ -1,107 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import org.datavec.audio.properties.FingerprintProperties;
public class FingerprintSimilarity {
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
private int mostSimilarFramePosition;
private float score;
private float similarity;
/**
* Constructor
*/
public FingerprintSimilarity() {
mostSimilarFramePosition = Integer.MIN_VALUE;
score = -1;
similarity = -1;
}
/**
* Get the most similar position in terms of frame number
*
* @return most similar frame position
*/
public int getMostSimilarFramePosition() {
return mostSimilarFramePosition;
}
/**
* Set the most similar position in terms of frame number
*
* @param mostSimilarFramePosition
*/
public void setMostSimilarFramePosition(int mostSimilarFramePosition) {
this.mostSimilarFramePosition = mostSimilarFramePosition;
}
/**
* Get the similarity of the fingerprints
* Similarity ranges from 0 to 1, where 0 means no similar features were found and 1 means that on average there is at least one match in every frame
*
* @return fingerprints similarity
*/
public float getSimilarity() {
return similarity;
}
/**
* Set the similarity of the fingerprints
*
* @param similarity similarity
*/
public void setSimilarity(float similarity) {
this.similarity = similarity;
}
/**
* Get the similarity score of the fingerprints
* The score is the number of matched features per frame
*
* @return fingerprints similarity score
*/
public float getScore() {
return score;
}
/**
* Set the similarity score of the fingerprints
*
* @param score
*/
public void setScore(float score) {
this.score = score;
}
/**
* Get the most similar position in terms of time in seconds
*
* @return most similar starting time
*/
public float getMostSimilarTimePosition() {
return (float) mostSimilarFramePosition / fingerprintProperties.getNumRobustPointsPerFrame()
/ fingerprintProperties.getFps();
}
}

View File

@ -1,134 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import java.util.HashMap;
import java.util.List;
public class FingerprintSimilarityComputer {
private FingerprintSimilarity fingerprintSimilarity;
byte[] fingerprint1, fingerprint2;
/**
* Constructor, ready to compute the similarity of two fingerprints
*
* @param fingerprint1
* @param fingerprint2
*/
public FingerprintSimilarityComputer(byte[] fingerprint1, byte[] fingerprint2) {
this.fingerprint1 = fingerprint1;
this.fingerprint2 = fingerprint2;
fingerprintSimilarity = new FingerprintSimilarity();
}
/**
* Get fingerprint similarity of the input fingerprints
*
* @return fingerprint similarity object
*/
public FingerprintSimilarity getFingerprintsSimilarity() {
HashMap<Integer, Integer> offset_Score_Table = new HashMap<>(); // offset_Score_Table<offset,count>
int numFrames;
float score = 0;
int mostSimilarFramePosition = Integer.MIN_VALUE;
// one frame may contain several points; use the shorter fingerprint as the denominator
if (fingerprint1.length > fingerprint2.length) {
numFrames = FingerprintManager.getNumFrames(fingerprint2);
} else {
numFrames = FingerprintManager.getNumFrames(fingerprint1);
}
// get the pairs
PairManager pairManager = new PairManager();
HashMap<Integer, List<Integer>> this_Pair_PositionList_Table =
pairManager.getPair_PositionList_Table(fingerprint1);
HashMap<Integer, List<Integer>> compareWave_Pair_PositionList_Table =
pairManager.getPair_PositionList_Table(fingerprint2);
for (Integer compareWaveHashNumber : compareWave_Pair_PositionList_Table.keySet()) {
// skip hash numbers that do not appear in both tables (the key comes from the compare table, so only this table needs checking)
if (!this_Pair_PositionList_Table.containsKey(compareWaveHashNumber)) {
continue;
}
// for each compare hash number, get the positions
List<Integer> wavePositionList = this_Pair_PositionList_Table.get(compareWaveHashNumber);
List<Integer> compareWavePositionList = compareWave_Pair_PositionList_Table.get(compareWaveHashNumber);
for (Integer thisPosition : wavePositionList) {
for (Integer compareWavePosition : compareWavePositionList) {
int offset = thisPosition - compareWavePosition;
if (offset_Score_Table.containsKey(offset)) {
offset_Score_Table.put(offset, offset_Score_Table.get(offset) + 1);
} else {
offset_Score_Table.put(offset, 1);
}
}
}
}
// map rank
MapRank mapRank = new MapRankInteger(offset_Score_Table, false);
// get the most similar positions and scores
List<Integer> orderedKeyList = mapRank.getOrderedKeyList(100, true);
if (orderedKeyList.size() > 0) {
int key = orderedKeyList.get(0);
// get the highest score position
mostSimilarFramePosition = key;
score = offset_Score_Table.get(key);
// accumulate the scores from neighbours
if (offset_Score_Table.containsKey(key - 1)) {
score += offset_Score_Table.get(key - 1) / 2;
}
if (offset_Score_Table.containsKey(key + 1)) {
score += offset_Score_Table.get(key + 1) / 2;
}
}
/*
Iterator<Integer> orderedKeyListIterator=orderedKeyList.iterator();
while (orderedKeyListIterator.hasNext()){
int offset=orderedKeyListIterator.next();
System.out.println(offset+": "+offset_Score_Table.get(offset));
}
*/
score /= numFrames;
float similarity = score;
// similarity > 1 means that on average there is at least one match in every frame
if (similarity > 1) {
similarity = 1;
}
fingerprintSimilarity.setMostSimilarFramePosition(mostSimilarFramePosition);
fingerprintSimilarity.setScore(score);
fingerprintSimilarity.setSimilarity(similarity);
return fingerprintSimilarity;
}
}
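
The matching idea above: every hash pair present in both fingerprints votes for the offset between its two positions, the offset with the most votes (plus half of each neighbour's votes) wins, and the winning count averaged per frame becomes the score. End-to-end usage goes through Wave, with hypothetical file names:

import org.datavec.audio.Wave;
import org.datavec.audio.fingerprint.FingerprintSimilarity;

public class SimilarityExample {
    public static void main(String[] args) {
        Wave a = new Wave("/tmp/song.wav"); // hypothetical inputs
        Wave b = new Wave("/tmp/clip.wav");
        FingerprintSimilarity sim = a.getFingerprintSimilarity(b);
        System.out.println("best offset (frames): " + sim.getMostSimilarFramePosition());
        System.out.println("score: " + sim.getScore());           // matches per frame
        System.out.println("similarity: " + sim.getSimilarity()); // capped at 1
    }
}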

View File

@ -1,27 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import java.util.List;
public interface MapRank {
public List getOrderedKeyList(int numKeys, boolean sharpLimit);
}

@@ -1,179 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import java.util.*;
import java.util.Map.Entry;
public class MapRankDouble implements MapRank {
private Map map;
private boolean ascending = true;
public MapRankDouble(Map<?, Double> map, boolean ascending) {
this.map = map;
this.ascending = ascending;
}
public List getOrderedKeyList(int numKeys, boolean sharpLimit) { // if sharpLimit is true, return exactly numKeys keys; otherwise also return any further keys whose values equal the numKeys-th value
Set mapEntrySet = map.entrySet();
List keyList = new LinkedList();
// if the numKeys is larger than map size, limit it
if (numKeys > map.size()) {
numKeys = map.size();
}
// end if the numKeys is larger than map size, limit it
if (map.size() > 0) {
double[] array = new double[map.size()];
int count = 0;
// get the pass values
Iterator<Entry> mapIterator = mapEntrySet.iterator();
while (mapIterator.hasNext()) {
Entry entry = mapIterator.next();
array[count++] = (Double) entry.getValue();
}
// end get the pass values
int targetIndex;
if (ascending) {
targetIndex = numKeys;
} else {
targetIndex = array.length - numKeys;
}
double passValue = getOrderedValue(array, targetIndex); // passValue is the value of the numKeys-th element in sort order
// get the passed keys and values
Map passedMap = new HashMap();
List<Double> valueList = new LinkedList<Double>();
mapIterator = mapEntrySet.iterator();
while (mapIterator.hasNext()) {
Entry entry = mapIterator.next();
double value = (Double) entry.getValue();
if ((ascending && value <= passValue) || (!ascending && value >= passValue)) {
passedMap.put(entry.getKey(), value);
valueList.add(value);
}
}
// end get the passed keys and values
// sort the value list
Double[] listArr = new Double[valueList.size()];
valueList.toArray(listArr);
Arrays.sort(listArr);
// end sort the value list
// get the list of keys
int resultCount = 0;
int index;
if (ascending) {
index = 0;
} else {
index = listArr.length - 1;
}
if (!sharpLimit) {
numKeys = listArr.length;
}
while (true) {
double targetValue = (Double) listArr[index];
Iterator<Entry> passedMapIterator = passedMap.entrySet().iterator();
while (passedMapIterator.hasNext()) {
Entry entry = passedMapIterator.next();
if ((Double) entry.getValue() == targetValue) {
keyList.add(entry.getKey());
passedMapIterator.remove();
resultCount++;
break;
}
}
if (ascending) {
index++;
} else {
index--;
}
if (resultCount >= numKeys) {
break;
}
}
// end get the list of keys
}
return keyList;
}
private double getOrderedValue(double[] array, int index) {
locate(array, 0, array.length - 1, index);
return array[index];
}
// quickselect: recursively partition until the element at the target index is in place
private void locate(double[] array, int left, int right, int index) {
int mid = (left + right) / 2;
if (right == left) {
return;
}
if (left < right) {
double s = array[mid];
int i = left - 1;
int j = right + 1;
while (true) {
while (array[++i] < s);
while (array[--j] > s);
if (i >= j)
break;
swap(array, i, j);
}
if (i > index) {
// the target index is in the left partition
locate(array, left, i - 1, index);
} else {
// the target index is in the right partition
locate(array, j + 1, right, index);
}
}
}
private void swap(double[] array, int i, int j) {
double t = array[i];
array[i] = array[j];
array[j] = t;
}
}

@@ -1,179 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import java.util.*;
import java.util.Map.Entry;
public class MapRankInteger implements MapRank {
private Map map;
private boolean ascending = true;
public MapRankInteger(Map<?, Integer> map, boolean ascending) {
this.map = map;
this.ascending = ascending;
}
public List getOrderedKeyList(int numKeys, boolean sharpLimit) { // if sharpLimit is true, return exactly numKeys keys; otherwise also return any further keys whose values equal the numKeys-th value
Set mapEntrySet = map.entrySet();
List keyList = new LinkedList();
// if the numKeys is larger than map size, limit it
if (numKeys > map.size()) {
numKeys = map.size();
}
// end if the numKeys is larger than map size, limit it
if (map.size() > 0) {
int[] array = new int[map.size()];
int count = 0;
// get the pass values
Iterator<Entry> mapIterator = mapEntrySet.iterator();
while (mapIterator.hasNext()) {
Entry entry = mapIterator.next();
array[count++] = (Integer) entry.getValue();
}
// end get the pass values
int targetIndex;
if (ascending) {
targetIndex = numKeys;
} else {
targetIndex = array.length - numKeys;
}
int passValue = getOrderedValue(array, targetIndex); // passValue is the value of the numKeys-th element in sort order
// get the passed keys and values
Map passedMap = new HashMap();
List<Integer> valueList = new LinkedList<Integer>();
mapIterator = mapEntrySet.iterator();
while (mapIterator.hasNext()) {
Entry entry = mapIterator.next();
int value = (Integer) entry.getValue();
if ((ascending && value <= passValue) || (!ascending && value >= passValue)) {
passedMap.put(entry.getKey(), value);
valueList.add(value);
}
}
// end get the passed keys and values
// sort the value list
Integer[] listArr = new Integer[valueList.size()];
valueList.toArray(listArr);
Arrays.sort(listArr);
// end sort the value list
// get the list of keys
int resultCount = 0;
int index;
if (ascending) {
index = 0;
} else {
index = listArr.length - 1;
}
if (!sharpLimit) {
numKeys = listArr.length;
}
while (true) {
int targetValue = (Integer) listArr[index];
Iterator<Entry> passedMapIterator = passedMap.entrySet().iterator();
while (passedMapIterator.hasNext()) {
Entry entry = passedMapIterator.next();
if ((Integer) entry.getValue() == targetValue) {
keyList.add(entry.getKey());
passedMapIterator.remove();
resultCount++;
break;
}
}
if (ascending) {
index++;
} else {
index--;
}
if (resultCount >= numKeys) {
break;
}
}
// end get the list of keys
}
return keyList;
}
private int getOrderedValue(int[] array, int index) {
locate(array, 0, array.length - 1, index);
return array[index];
}
// quickselect: recursively partition until the element at the target index is in place
private void locate(int[] array, int left, int right, int index) {
int mid = (left + right) / 2;
if (right == left) {
return;
}
if (left < right) {
int s = array[mid];
int i = left - 1;
int j = right + 1;
while (true) {
while (array[++i] < s);
while (array[--j] > s);
if (i >= j)
break;
swap(array, i, j);
}
if (i > index) {
// the target index is in the left partition
locate(array, left, i - 1, index);
} else {
// the target index is in the right partition
locate(array, j + 1, right, index);
}
}
}
private void swap(int[] array, int i, int j) {
int t = array[i];
array[i] = array[j];
array[j] = t;
}
}
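A short self-contained sketch of MapRankInteger, ranking an offset histogram the way the similarity computer above does; the counts are made up.

import org.datavec.audio.fingerprint.MapRank;
import org.datavec.audio.fingerprint.MapRankInteger;
import java.util.HashMap;
import java.util.List;

public class MapRankExample {
    public static void main(String[] args) {
        HashMap<Integer, Integer> offsetScores = new HashMap<>();
        offsetScores.put(-2, 1); // offset -2 seen once
        offsetScores.put(0, 7);  // offset 0 seen seven times
        offsetScores.put(1, 3);
        offsetScores.put(5, 7);
        MapRank rank = new MapRankInteger(offsetScores, false); // false = rank by descending count
        List keys = rank.getOrderedKeyList(2, true); // sharp limit: exactly two keys
        System.out.println(keys); // the two offsets with the highest counts, e.g. [0, 5]
    }
}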

@@ -1,230 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
import org.datavec.audio.properties.FingerprintProperties;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
public class PairManager {
FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
private int numFilterBanks = fingerprintProperties.getNumFilterBanks();
private int bandwidthPerBank = fingerprintProperties.getNumFrequencyUnits() / numFilterBanks;
private int anchorPointsIntervalLength = fingerprintProperties.getAnchorPointsIntervalLength();
private int numAnchorPointsPerInterval = fingerprintProperties.getNumAnchorPointsPerInterval();
private int maxTargetZoneDistance = fingerprintProperties.getMaxTargetZoneDistance();
private int numFrequencyUnits = fingerprintProperties.getNumFrequencyUnits();
private int maxPairs;
private boolean isReferencePairing;
private HashMap<Integer, Boolean> stopPairTable = new HashMap<>();
/**
* Constructor
*/
public PairManager() {
maxPairs = fingerprintProperties.getRefMaxActivePairs();
isReferencePairing = true;
}
/**
* Constructor. The number of active pairs of robust points depends on isReferencePairing:
* reference and sample clips use different limits because the sample may be degraded by
* environmental noise at the source.
* @param isReferencePairing true for reference songs, false for sample clips
*/
public PairManager(boolean isReferencePairing) {
if (isReferencePairing) {
maxPairs = fingerprintProperties.getRefMaxActivePairs();
} else {
maxPairs = fingerprintProperties.getSampleMaxActivePairs();
}
this.isReferencePairing = isReferencePairing;
}
/**
* Get a pair-positionList table.
* The result is a hash map whose key is the hashed pair and whose value is the list of
* positions at which that pair occurs, i.e. positions sharing the same hashed pair are
* grouped together.
*
* @param fingerprint fingerprint bytes
* @return pair-positionList HashMap
*/
public HashMap<Integer, List<Integer>> getPair_PositionList_Table(byte[] fingerprint) {
List<int[]> pairPositionList = getPairPositionList(fingerprint);
// table to store pair:pos,pos,pos,...;pair2:pos,pos,pos,....
HashMap<Integer, List<Integer>> pair_positionList_table = new HashMap<>();
// get all pair_positions from list, use a table to collect the data group by pair hashcode
for (int[] pair_position : pairPositionList) {
//System.out.println(pair_position[0]+","+pair_position[1]);
// group by pair-hashcode, i.e.: <pair,List<position>>
if (pair_positionList_table.containsKey(pair_position[0])) {
pair_positionList_table.get(pair_position[0]).add(pair_position[1]);
} else {
List<Integer> positionList = new LinkedList<>();
positionList.add(pair_position[1]);
pair_positionList_table.put(pair_position[0], positionList);
}
// end group by pair-hashcode, i.e.: <pair,List<position>>
}
// end get all pair_positions from list, use a table to collect the data group by pair hashcode
return pair_positionList_table;
}
// this return list contains: int[0]=pair_hashcode, int[1]=position
private List<int[]> getPairPositionList(byte[] fingerprint) {
int numFrames = FingerprintManager.getNumFrames(fingerprint);
// table for paired frames
byte[] pairedFrameTable = new byte[numFrames / anchorPointsIntervalLength + 1]; // limits reference pairing to numAnchorPointsPerInterval pairs per interval of anchorPointsIntervalLength frames
// end table for paired frames
List<int[]> pairList = new LinkedList<>();
List<int[]> sortedCoordinateList = getSortedCoordinateList(fingerprint);
for (int[] anchorPoint : sortedCoordinateList) {
int anchorX = anchorPoint[0];
int anchorY = anchorPoint[1];
int numPairs = 0;
for (int[] aSortedCoordinateList : sortedCoordinateList) {
if (numPairs >= maxPairs) {
break;
}
if (isReferencePairing && pairedFrameTable[anchorX
/ anchorPointsIntervalLength] >= numAnchorPointsPerInterval) {
break;
}
int targetX = aSortedCoordinateList[0];
int targetY = aSortedCoordinateList[1];
if (anchorX == targetX && anchorY == targetY) {
continue;
}
// pair up the points
int x1, y1, x2, y2; // x2 always >= x1
if (targetX >= anchorX) {
x2 = targetX;
y2 = targetY;
x1 = anchorX;
y1 = anchorY;
} else {
x2 = anchorX;
y2 = anchorY;
x1 = targetX;
y1 = targetY;
}
// check target zone
if ((x2 - x1) > maxTargetZoneDistance) {
continue;
}
// end check target zone
// check filter bank zone
if (y1 / bandwidthPerBank != y2 / bandwidthPerBank) {
continue; // both points must lie in the same filter bank
}
// end check filter bank zone
int pairHashcode = (x2 - x1) * numFrequencyUnits * numFrequencyUnits + y2 * numFrequencyUnits + y1;
// the stop list applies to sample pairing only
if (!isReferencePairing && stopPairTable.containsKey(pairHashcode)) {
numPairs++; // still counts against maxPairs, but no pair is emitted
continue; // skip this target point only
}
// end stop list (sample pairing only)
// pass all rules
pairList.add(new int[] {pairHashcode, anchorX});
pairedFrameTable[anchorX / anchorPointsIntervalLength]++;
numPairs++;
// end pair up the points
}
}
return pairList;
}
private List<int[]> getSortedCoordinateList(byte[] fingerprint) {
// each point data is 8 bytes
// first 2 bytes is x
// next 2 bytes is y
// next 4 bytes is intensity
// get all intensities
int numCoordinates = fingerprint.length / 8;
int[] intensities = new int[numCoordinates];
for (int i = 0; i < numCoordinates; i++) {
int pointer = i * 8 + 4;
int intensity = (fingerprint[pointer] & 0xff) << 24 | (fingerprint[pointer + 1] & 0xff) << 16
| (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff);
intensities[i] = intensity;
}
QuickSortIndexPreserved quicksort = new QuickSortIndexPreserved(intensities);
int[] sortIndexes = quicksort.getSortIndexes();
List<int[]> sortedCoordinateList = new LinkedList<>();
for (int i = sortIndexes.length - 1; i >= 0; i--) {
int pointer = sortIndexes[i] * 8;
int x = (fingerprint[pointer] & 0xff) << 8 | (fingerprint[pointer + 1] & 0xff);
int y = (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff);
sortedCoordinateList.add(new int[] {x, y});
}
return sortedCoordinateList;
}
/**
* Convert a hashed pair to bytes; only the low 16 bits of the hashcode are kept
*
* @param pairHashcode hashed pair
* @return two-byte array
*/
public static byte[] pairHashcodeToBytes(int pairHashcode) {
return new byte[] {(byte) (pairHashcode >> 8), (byte) pairHashcode};
}
/**
* Convert bytes to a hashed pair
*
* @param pairBytes two-byte array produced by pairHashcodeToBytes
* @return hashed pair
*/
public static int pairBytesToHashcode(byte[] pairBytes) {
return (pairBytes[0] & 0xFF) << 8 | (pairBytes[1] & 0xFF);
}
}
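A self-contained sketch of PairManager on a synthetic three-point fingerprint. The 8-byte point layout mirrors getSortedCoordinateList above; the coordinates and intensities are made up, and the expected hashcode follows from the default FingerprintProperties (221 frequency units).

import org.datavec.audio.fingerprint.PairManager;
import java.util.HashMap;
import java.util.List;

public class PairTableExample {
    // encode one robust point as 8 bytes: x (2 bytes), y (2 bytes), intensity (4 bytes)
    static void putPoint(byte[] fp, int i, int x, int y, int intensity) {
        int p = i * 8;
        fp[p] = (byte) (x >> 8);
        fp[p + 1] = (byte) x;
        fp[p + 2] = (byte) (y >> 8);
        fp[p + 3] = (byte) y;
        fp[p + 4] = (byte) (intensity >> 24);
        fp[p + 5] = (byte) (intensity >> 16);
        fp[p + 6] = (byte) (intensity >> 8);
        fp[p + 7] = (byte) intensity;
    }

    public static void main(String[] args) {
        byte[] fingerprint = new byte[3 * 8];
        putPoint(fingerprint, 0, 0, 10, 500);  // frames 0 and 1 share filter bank 0
        putPoint(fingerprint, 1, 1, 20, 400);
        putPoint(fingerprint, 2, 2, 120, 300); // filter bank 2: never paired with the others
        PairManager pairManager = new PairManager(false); // sample pairing, up to 10 pairs per anchor
        HashMap<Integer, List<Integer>> table = pairManager.getPair_PositionList_Table(fingerprint);
        // the first two points pair from both anchor positions, so a single
        // hashcode maps to both positions: {53271=[0, 1]}
        System.out.println(table);
    }
}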

@@ -1,25 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
public abstract class QuickSort {
public abstract int[] getSortIndexes();
}

@@ -1,79 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
public class QuickSortDouble extends QuickSort {
private int[] indexes;
private double[] array;
public QuickSortDouble(double[] array) {
this.array = array;
indexes = new int[array.length];
for (int i = 0; i < indexes.length; i++) {
indexes[i] = i;
}
}
public int[] getSortIndexes() {
sort();
return indexes;
}
private void sort() {
quicksort(array, indexes, 0, indexes.length - 1);
}
// quicksort a[left] to a[right]
private void quicksort(double[] a, int[] indexes, int left, int right) {
if (right <= left)
return;
int i = partition(a, indexes, left, right);
quicksort(a, indexes, left, i - 1);
quicksort(a, indexes, i + 1, right);
}
// partition a[left] to a[right], assumes left < right
private int partition(double[] a, int[] indexes, int left, int right) {
int i = left - 1;
int j = right;
while (true) {
while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
if (j == left)
break; // don't go out-of-bounds
}
if (i >= j)
break; // check if pointers cross
swap(a, indexes, i, j); // swap two elements into place
}
swap(a, indexes, i, right); // swap with partition element
return i;
}
// exchange a[i] and a[j]
private void swap(double[] a, int[] indexes, int i, int j) {
int swap = indexes[i];
indexes[i] = indexes[j];
indexes[j] = swap;
}
}

@@ -1,43 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
public class QuickSortIndexPreserved {
private QuickSort quickSort;
public QuickSortIndexPreserved(int[] array) {
quickSort = new QuickSortInteger(array);
}
public QuickSortIndexPreserved(double[] array) {
quickSort = new QuickSortDouble(array);
}
public QuickSortIndexPreserved(short[] array) {
quickSort = new QuickSortShort(array);
}
public int[] getSortIndexes() {
return quickSort.getSortIndexes();
}
}
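A short sketch of QuickSortIndexPreserved: the input array is left untouched, and the returned indexes give ascending value order. PairManager walks that order backwards to visit points from strongest to weakest.

import org.datavec.audio.fingerprint.QuickSortIndexPreserved;
import java.util.Arrays;

public class SortIndexExample {
    public static void main(String[] args) {
        double[] intensities = {0.3, 0.9, 0.1, 0.5};
        int[] order = new QuickSortIndexPreserved(intensities).getSortIndexes();
        System.out.println(Arrays.toString(order)); // [2, 0, 3, 1]: indexes in ascending value order
        for (int i = order.length - 1; i >= 0; i--) { // strongest to weakest
            System.out.println("index " + order[i] + " -> " + intensities[order[i]]);
        }
    }
}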

@@ -1,79 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
public class QuickSortInteger extends QuickSort {
private int[] indexes;
private int[] array;
public QuickSortInteger(int[] array) {
this.array = array;
indexes = new int[array.length];
for (int i = 0; i < indexes.length; i++) {
indexes[i] = i;
}
}
public int[] getSortIndexes() {
sort();
return indexes;
}
private void sort() {
quicksort(array, indexes, 0, indexes.length - 1);
}
// quicksort a[left] to a[right]
private void quicksort(int[] a, int[] indexes, int left, int right) {
if (right <= left)
return;
int i = partition(a, indexes, left, right);
quicksort(a, indexes, left, i - 1);
quicksort(a, indexes, i + 1, right);
}
// partition a[left] to a[right], assumes left < right
private int partition(int[] a, int[] indexes, int left, int right) {
int i = left - 1;
int j = right;
while (true) {
while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
if (j == left)
break; // don't go out-of-bounds
}
if (i >= j)
break; // check if pointers cross
swap(a, indexes, i, j); // swap two elements into place
}
swap(a, indexes, i, right); // swap with partition element
return i;
}
// exchange a[i] and a[j]
private void swap(int[] a, int[] indexes, int i, int j) {
int swap = indexes[i];
indexes[i] = indexes[j];
indexes[j] = swap;
}
}

@@ -1,79 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.fingerprint;
public class QuickSortShort extends QuickSort {
private int[] indexes;
private short[] array;
public QuickSortShort(short[] array) {
this.array = array;
indexes = new int[array.length];
for (int i = 0; i < indexes.length; i++) {
indexes[i] = i;
}
}
public int[] getSortIndexes() {
sort();
return indexes;
}
private void sort() {
quicksort(array, indexes, 0, indexes.length - 1);
}
// quicksort a[left] to a[right]
private void quicksort(short[] a, int[] indexes, int left, int right) {
if (right <= left)
return;
int i = partition(a, indexes, left, right);
quicksort(a, indexes, left, i - 1);
quicksort(a, indexes, i + 1, right);
}
// partition a[left] to a[right], assumes left < right
private int partition(short[] a, int[] indexes, int left, int right) {
int i = left - 1;
int j = right;
while (true) {
while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
if (j == left)
break; // don't go out-of-bounds
}
if (i >= j)
break; // check if pointers cross
swap(a, indexes, i, j); // swap two elements into place
}
swap(a, indexes, i, right); // swap with partition element
return i;
}
// exchange a[i] and a[j]
private void swap(short[] a, int[] indexes, int i, int j) {
int swap = indexes[i];
indexes[i] = indexes[j];
indexes[j] = swap;
}
}

@@ -1,51 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.formats.input;
import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.BaseInputFormat;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.audio.recordreader.WavFileRecordReader;
import java.io.IOException;
/**
*
* Wave file input format
*
* @author Adam Gibson
*/
public class WavInputFormat extends BaseInputFormat {
@Override
public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
return createReader(split);
}
@Override
public RecordReader createReader(InputSplit split) throws IOException, InterruptedException {
RecordReader waveRecordReader = new WavFileRecordReader();
waveRecordReader.initialize(split);
return waveRecordReader;
}
}
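A minimal sketch of reading a directory of wav files through WavInputFormat; the directory path is a placeholder.

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.audio.formats.input.WavInputFormat;
import java.io.File;
import java.util.List;

public class WavInputFormatExample {
    public static void main(String[] args) throws Exception {
        RecordReader reader = new WavInputFormat().createReader(new FileSplit(new File("/data/wavs")));
        while (reader.hasNext()) {
            List<Writable> record = reader.next(); // one record per file: normalized amplitudes
            System.out.println("read " + record.size() + " samples");
        }
        reader.close();
    }
}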

@@ -1,36 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.formats.output;
import org.datavec.api.conf.Configuration;
import org.datavec.api.exceptions.DataVecException;
import org.datavec.api.formats.output.OutputFormat;
import org.datavec.api.records.writer.RecordWriter;
/**
* @author Adam Gibson
*/
public class WaveOutputFormat implements OutputFormat {
@Override
public RecordWriter createWriter(Configuration conf) throws DataVecException {
// audio record writing was never implemented; fail fast rather than return null
throw new UnsupportedOperationException("Writing audio records is not supported");
}
}

@@ -1,139 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.processor;
public class ArrayRankDouble {
/**
* Get the index position of the maximum value in the given array
* @param array an array
* @return index of the max value in array
*/
public int getMaxValueIndex(double[] array) {
int index = 0;
double max = Double.NEGATIVE_INFINITY; // Integer.MIN_VALUE is not the smallest possible double
for (int i = 0; i < array.length; i++) {
if (array[i] > max) {
max = array[i];
index = i;
}
}
return index;
}
/**
* Get the index position of minimum value in the given array
* @param array an array
* @return index of the min value in array
*/
public int getMinValueIndex(double[] array) {
int index = 0;
double min = Double.POSITIVE_INFINITY; // Integer.MAX_VALUE is not the largest possible double
for (int i = 0; i < array.length; i++) {
if (array[i] < min) {
min = array[i];
index = i;
}
}
return index;
}
/**
* Get the n-th value in the array after sorted
* @param array an array
* @param n position in array
* @param ascending is ascending order or not
* @return value at nth position of array
*/
public double getNthOrderedValue(double[] array, int n, boolean ascending) {
if (n > array.length) {
n = array.length;
}
int targetIndex;
if (ascending) {
targetIndex = n - 1; // n is 1-based, so the n-th smallest value sits at sorted index n - 1
} else {
targetIndex = array.length - n; // the n-th largest value sits at sorted index length - n
}
return getOrderedValue(array, targetIndex);
}
private double getOrderedValue(double[] array, int index) {
locate(array, 0, array.length - 1, index);
return array[index];
}
// quickselect: recursively partition until the element at the target index is in place
private void locate(double[] array, int left, int right, int index) {
int mid = (left + right) / 2;
if (right == left) {
return;
}
if (left < right) {
double s = array[mid];
int i = left - 1;
int j = right + 1;
while (true) {
while (array[++i] < s);
while (array[--j] > s);
if (i >= j)
break;
swap(array, i, j);
}
if (i > index) {
// the target index is in the left partition
locate(array, left, i - 1, index);
} else {
// the target index is in the right partition
locate(array, j + 1, right, index);
}
}
}
private void swap(double[] array, int i, int j) {
double t = array[i];
array[i] = array[j];
array[j] = t;
}
}
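A short sketch of ArrayRankDouble with made-up values. Note that getNthOrderedValue partially sorts the array it is given, so pass a copy when the original order matters.

import org.datavec.audio.processor.ArrayRankDouble;

public class ArrayRankExample {
    public static void main(String[] args) {
        double[] values = {0.2, 0.9, 0.4, 0.7, 0.1};
        ArrayRankDouble rank = new ArrayRankDouble();
        System.out.println(rank.getMaxValueIndex(values)); // 1
        System.out.println(rank.getMinValueIndex(values)); // 4
        // threshold at the 2nd largest value, as RobustIntensityProcessor does;
        // clone() because getNthOrderedValue reorders its input
        double passValue = rank.getNthOrderedValue(values.clone(), 2, false);
        System.out.println(passValue); // 0.7
    }
}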

@@ -1,28 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.processor;
public interface IntensityProcessor {
public void execute();
public double[][] getIntensities();
}

@@ -1,48 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.processor;
import java.util.LinkedList;
import java.util.List;
public class ProcessorChain {
private double[][] intensities;
List<IntensityProcessor> processorList = new LinkedList<IntensityProcessor>();
public ProcessorChain(double[][] intensities) {
this.intensities = intensities;
RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, 1);
processorList.add(robustProcessor);
process();
}
private void process() {
for (IntensityProcessor processor : processorList) {
processor.execute();
intensities = processor.getIntensities();
}
}
public double[][] getIntensities() {
return intensities;
}
}

@@ -1,61 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.processor;
public class RobustIntensityProcessor implements IntensityProcessor {
private double[][] intensities;
private int numPointsPerFrame;
public RobustIntensityProcessor(double[][] intensities, int numPointsPerFrame) {
this.intensities = intensities;
this.numPointsPerFrame = numPointsPerFrame;
}
public void execute() {
int numX = intensities.length;
int numY = intensities[0].length;
double[][] processedIntensities = new double[numX][numY];
for (int i = 0; i < numX; i++) {
double[] tmpArray = new double[numY];
System.arraycopy(intensities[i], 0, tmpArray, 0, numY);
// the pass value is the numPointsPerFrame-th largest intensity in this frame
ArrayRankDouble arrayRankDouble = new ArrayRankDouble();
double passValue = arrayRankDouble.getNthOrderedValue(tmpArray, numPointsPerFrame, false);
// only intensities at or above the pass value are kept; the rest remain zero
for (int j = 0; j < numY; j++) {
if (intensities[i][j] >= passValue) {
processedIntensities[i][j] = intensities[i][j];
}
}
}
intensities = processedIntensities;
}
public double[][] getIntensities() {
return intensities;
}
}
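A small self-contained sketch of RobustIntensityProcessor on a two-frame spectrogram with made-up values, keeping the top two intensities per frame.

import org.datavec.audio.processor.IntensityProcessor;
import org.datavec.audio.processor.RobustIntensityProcessor;
import java.util.Arrays;

public class RobustIntensityExample {
    public static void main(String[] args) {
        double[][] intensities = {
                {1.0, 5.0, 3.0, 2.0},  // frame 0
                {7.0, 0.5, 6.0, 4.0}   // frame 1
        };
        IntensityProcessor processor = new RobustIntensityProcessor(intensities, 2);
        processor.execute();
        // per frame, only the top-2 intensities survive; the rest are zeroed
        System.out.println(Arrays.deepToString(processor.getIntensities()));
        // [[0.0, 5.0, 3.0, 0.0], [7.0, 0.0, 6.0, 0.0]]
    }
}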

@@ -1,49 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.processor;
import java.util.LinkedList;
import java.util.List;
public class TopManyPointsProcessorChain {
private double[][] intensities;
List<IntensityProcessor> processorList = new LinkedList<>();
public TopManyPointsProcessorChain(double[][] intensities, int numPoints) {
this.intensities = intensities;
RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, numPoints);
processorList.add(robustProcessor);
process();
}
private void process() {
for (IntensityProcessor processor : processorList) {
processor.execute();
intensities = processor.getIntensities();
}
}
public double[][] getIntensities() {
return intensities;
}
}

@@ -1,121 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.properties;
public class FingerprintProperties {
protected static volatile FingerprintProperties instance = null; // volatile is required for safe double-checked locking
private int numRobustPointsPerFrame = 4; // number of robust points kept per frame, i.e. the top 4 intensities in the fingerprint
private int sampleSizePerFrame = 2048; // number of audio samples in a frame; suggested to equal the FFT size
private int overlapFactor = 4; // a factor of 8 means each frame advances by 1/8 of sampleSizePerFrame; 1 means no overlap; powers of two (1, 2, 4, ..., 32) work best
private int numFilterBanks = 4;
private int upperBoundedFrequency = 1500; // low pass
private int lowerBoundedFrequency = 400; // high pass
private int fps = 5; // to get 5 fps with 2048 samples per frame, the wave's sample rate needs to be 10240 (sampleSizePerFrame * fps)
private int sampleRate = sampleSizePerFrame * fps; // audio must be resampled to this rate to fit sampleSizePerFrame and fps
private int numFramesInOneSecond = overlapFactor * fps; // the overlap factor raises the effective frame rate, so this is the actual number of frames per second
private int refMaxActivePairs = 1; // max. active pairs per anchor point for reference songs
private int sampleMaxActivePairs = 10; // max. active pairs per anchor point for sample clip
private int numAnchorPointsPerInterval = 10;
private int anchorPointsIntervalLength = 4; // in frames (5fps,4 overlap per second)
private int maxTargetZoneDistance = 4; // in frame (5fps,4 overlap per second)
private int numFrequencyUnits = (upperBoundedFrequency - lowerBoundedFrequency + 1) / fps + 1; // num frequency units
public static FingerprintProperties getInstance() {
if (instance == null) {
synchronized (FingerprintProperties.class) {
if (instance == null) {
instance = new FingerprintProperties();
}
}
}
return instance;
}
public int getNumRobustPointsPerFrame() {
return numRobustPointsPerFrame;
}
public int getSampleSizePerFrame() {
return sampleSizePerFrame;
}
public int getOverlapFactor() {
return overlapFactor;
}
public int getNumFilterBanks() {
return numFilterBanks;
}
public int getUpperBoundedFrequency() {
return upperBoundedFrequency;
}
public int getLowerBoundedFrequency() {
return lowerBoundedFrequency;
}
public int getFps() {
return fps;
}
public int getRefMaxActivePairs() {
return refMaxActivePairs;
}
public int getSampleMaxActivePairs() {
return sampleMaxActivePairs;
}
public int getNumAnchorPointsPerInterval() {
return numAnchorPointsPerInterval;
}
public int getAnchorPointsIntervalLength() {
return anchorPointsIntervalLength;
}
public int getMaxTargetZoneDistance() {
return maxTargetZoneDistance;
}
public int getNumFrequencyUnits() {
return numFrequencyUnits;
}
public int getMaxPossiblePairHashcode() {
return maxTargetZoneDistance * numFrequencyUnits * numFrequencyUnits + numFrequencyUnits * numFrequencyUnits
+ numFrequencyUnits;
}
public int getSampleRate() {
return sampleRate;
}
public int getNumFramesInOneSecond() {
return numFramesInOneSecond;
}
}
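The derived values follow directly from the defaults above; a short sketch that prints them.

import org.datavec.audio.properties.FingerprintProperties;

public class FingerprintPropertiesExample {
    public static void main(String[] args) {
        FingerprintProperties props = FingerprintProperties.getInstance();
        // 2048 samples per frame at 5 fps implies a 10240 Hz working sample rate
        System.out.println(props.getSampleRate());            // 10240
        // an overlap factor of 4 at 5 fps yields 20 frames per second of audio
        System.out.println(props.getNumFramesInOneSecond());  // 20
        // (1500 - 400 + 1) / 5 + 1 = 221 frequency units (integer division)
        System.out.println(props.getNumFrequencyUnits());     // 221
    }
}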

@@ -1,225 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.recordreader;
import org.apache.commons.io.FileUtils;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.BaseRecordReader;
import org.datavec.api.split.BaseInputSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
/**
* Base audio file loader
* @author Adam Gibson
*/
public abstract class BaseAudioRecordReader extends BaseRecordReader {
private Iterator<File> iter;
private List<Writable> record;
private boolean recordReturned = false; // set once the single InputStream-backed record has been consumed
private boolean appendLabel = false;
private List<String> labels = new ArrayList<>();
private Configuration conf;
protected InputSplit inputSplit;
public BaseAudioRecordReader() {}
public BaseAudioRecordReader(boolean appendLabel, List<String> labels) {
this.appendLabel = appendLabel;
this.labels = labels;
}
public BaseAudioRecordReader(List<String> labels) {
this.labels = labels;
}
public BaseAudioRecordReader(boolean appendLabel) {
this.appendLabel = appendLabel;
}
protected abstract List<Writable> loadData(File file, InputStream inputStream) throws IOException;
@Override
public void initialize(InputSplit split) throws IOException, InterruptedException {
inputSplit = split;
if (split instanceof BaseInputSplit) {
URI[] locations = split.locations();
if (locations != null && locations.length >= 1) {
if (locations.length > 1) {
List<File> allFiles = new ArrayList<>();
for (URI location : locations) {
File file = new File(location); // local file, distinct from the iter field
if (file.isDirectory()) {
Iterator<File> dirFiles = FileUtils.iterateFiles(file, null, true);
while (dirFiles.hasNext())
allFiles.add(dirFiles.next());
}
else
allFiles.add(file);
}
iter = allFiles.iterator();
} else {
File curr = new File(locations[0]);
if (curr.isDirectory())
iter = FileUtils.iterateFiles(curr, null, true);
else
iter = Collections.singletonList(curr).iterator();
}
}
}
else if (split instanceof InputStreamInputSplit) {
record = new ArrayList<>();
InputStreamInputSplit split2 = (InputStreamInputSplit) split;
InputStream is = split2.getIs();
URI[] locations = split2.locations();
if (appendLabel) {
Path path = Paths.get(locations[0]);
String parent = path.getParent().toString();
record.add(new DoubleWritable(labels.indexOf(parent)));
}
is.close();
}
}
@Override
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
this.conf = conf;
this.appendLabel = conf.getBoolean(APPEND_LABEL, false);
this.labels = new ArrayList<>(conf.getStringCollection(LABELS));
initialize(split);
}
@Override
public List<Writable> next() {
if (iter != null) {
File next = iter.next();
invokeListeners(next);
try {
return loadData(next, null);
} catch (Exception e) {
throw new RuntimeException(e);
}
} else if (record != null) {
recordReturned = true;
return record;
}
throw new IllegalStateException("Indeterminate state: a record or a file iterator must exist");
}
@Override
public boolean hasNext() {
if (iter != null) {
return iter.hasNext();
} else if (record != null) {
return !recordReturned;
}
throw new IllegalStateException("Indeterminate state: a record or a file iterator must exist");
}
@Override
public void close() throws IOException {
}
@Override
public void setConf(Configuration conf) {
this.conf = conf;
}
@Override
public Configuration getConf() {
return conf;
}
@Override
public List<String> getLabels() {
return null;
}
@Override
public void reset() {
if (inputSplit == null)
throw new UnsupportedOperationException("Cannot reset without first initializing");
try {
initialize(inputSplit);
} catch (Exception e) {
throw new RuntimeException("Error during BaseAudioRecordReader reset", e);
}
}
@Override
public boolean resetSupported(){
if(inputSplit == null){
return false;
}
return inputSplit.resetSupported();
}
@Override
public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
invokeListeners(uri);
try {
return loadData(null, dataInputStream);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public Record nextRecord() {
return new org.datavec.api.records.impl.Record(next(), null);
}
@Override
public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
throw new UnsupportedOperationException("Loading from metadata not yet implemented");
}
@Override
public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
throw new UnsupportedOperationException("Loading from metadata not yet implemented");
}
}

@@ -1,76 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.recordreader;
import org.bytedeco.javacv.FFmpegFrameGrabber;
import org.bytedeco.javacv.Frame;
import org.datavec.api.writable.FloatWritable;
import org.datavec.api.writable.Writable;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.List;
import static org.bytedeco.ffmpeg.global.avutil.AV_SAMPLE_FMT_FLT;
/**
* Native audio file loader using FFmpeg.
*
* @author saudet
*/
public class NativeAudioRecordReader extends BaseAudioRecordReader {
public NativeAudioRecordReader() {}
public NativeAudioRecordReader(boolean appendLabel, List<String> labels) {
super(appendLabel, labels);
}
public NativeAudioRecordReader(List<String> labels) {
super(labels);
}
public NativeAudioRecordReader(boolean appendLabel) {
super(appendLabel);
}
protected List<Writable> loadData(File file, InputStream inputStream) throws IOException {
List<Writable> ret = new ArrayList<>();
try (FFmpegFrameGrabber grabber = inputStream != null ? new FFmpegFrameGrabber(inputStream)
: new FFmpegFrameGrabber(file.getAbsolutePath())) {
grabber.setSampleFormat(AV_SAMPLE_FMT_FLT);
grabber.start();
Frame frame;
while ((frame = grabber.grab()) != null) {
while (frame.samples != null && frame.samples[0].hasRemaining()) {
for (int i = 0; i < frame.samples.length; i++) {
ret.add(new FloatWritable(((FloatBuffer) frame.samples[i]).get()));
}
}
}
}
return ret;
}
}
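A minimal sketch of NativeAudioRecordReader; the file path is a placeholder. Since FFmpeg is not a default dependency (note the commented-out ffmpeg-platform entry in the pom), the ffmpeg-platform artifact must be on the classpath for this to run.

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.audio.recordreader.NativeAudioRecordReader;
import java.io.File;
import java.util.List;

public class NativeAudioExample {
    public static void main(String[] args) throws Exception {
        RecordReader reader = new NativeAudioRecordReader();
        reader.initialize(new FileSplit(new File("recording.ogg"))); // any format FFmpeg decodes
        while (reader.hasNext()) {
            List<Writable> samples = reader.next(); // one FloatWritable per decoded sample
            System.out.println(samples.size() + " samples decoded");
        }
        reader.close();
    }
}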

@@ -1,57 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio.recordreader;
import org.datavec.api.util.RecordUtils;
import org.datavec.api.writable.Writable;
import org.datavec.audio.Wave;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
/**
* Wav file loader
* @author Adam Gibson
*/
public class WavFileRecordReader extends BaseAudioRecordReader {
public WavFileRecordReader() {}
public WavFileRecordReader(boolean appendLabel, List<String> labels) {
super(appendLabel, labels);
}
public WavFileRecordReader(List<String> labels) {
super(labels);
}
public WavFileRecordReader(boolean appendLabel) {
super(appendLabel);
}
protected List<Writable> loadData(File file, InputStream inputStream) throws IOException {
Wave wave = inputStream != null ? new Wave(inputStream) : new Wave(file.getAbsolutePath());
return RecordUtils.toRecord(wave.getNormalizedAmplitudes());
}
}
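A minimal sketch of WavFileRecordReader, which yields one record of normalized amplitudes per wav file; the path is a placeholder.

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.audio.recordreader.WavFileRecordReader;
import java.io.File;
import java.util.List;

public class WavReaderExample {
    public static void main(String[] args) throws Exception {
        RecordReader reader = new WavFileRecordReader();
        reader.initialize(new FileSplit(new File("speech.wav")));
        while (reader.hasNext()) {
            List<Writable> amplitudes = reader.next(); // normalized to [-1, 1]
            System.out.println("samples: " + amplitudes.size());
        }
        reader.close();
    }
}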

@@ -1,51 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
public long getTimeoutMilliseconds() {
return 60000;
}
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.audio";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

@@ -1,68 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio;
import org.bytedeco.javacv.FFmpegFrameRecorder;
import org.bytedeco.javacv.Frame;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.audio.recordreader.NativeAudioRecordReader;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
import java.nio.ShortBuffer;
import java.util.List;
import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_VORBIS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* @author saudet
*/
public class AudioReaderTest extends BaseND4JTest {
@Ignore
@Test
public void testNativeAudioReader() throws Exception {
File tempFile = File.createTempFile("testNativeAudioReader", ".ogg");
FFmpegFrameRecorder recorder = new FFmpegFrameRecorder(tempFile, 2);
recorder.setAudioCodec(AV_CODEC_ID_VORBIS);
recorder.setSampleRate(44100);
recorder.start();
Frame audioFrame = new Frame();
ShortBuffer audioBuffer = ShortBuffer.allocate(64 * 1024);
audioFrame.sampleRate = 44100;
audioFrame.audioChannels = 2;
audioFrame.samples = new ShortBuffer[] {audioBuffer};
recorder.record(audioFrame);
recorder.stop();
recorder.release();
RecordReader reader = new NativeAudioRecordReader();
reader.initialize(new FileSplit(tempFile));
assertTrue(reader.hasNext());
List<Writable> record = reader.next();
assertEquals(audioBuffer.limit(), record.size());
}
}

@@ -1,69 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.audio;
import org.datavec.audio.dsp.FastFourierTransform;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.common.tests.BaseND4JTest;
public class TestFastFourierTransform extends BaseND4JTest {
@Test
public void testFastFourierTransformComplex() {
FastFourierTransform fft = new FastFourierTransform();
double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6};
double[] frequencies = fft.getMagnitudes(amplitudes);
Assert.assertEquals(2, frequencies.length);
Assert.assertArrayEquals(new double[] {21.335, 18.513}, frequencies, 0.005);
}
@Test
public void testFastFourierTransformComplexLong() {
FastFourierTransform fft = new FastFourierTransform();
double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6};
double[] frequencies = fft.getMagnitudes(amplitudes, true);
Assert.assertEquals(4, frequencies.length);
Assert.assertArrayEquals(new double[] {21.335, 18.5132, 14.927, 7.527}, frequencies, 0.005);
}
@Test
public void testFastFourierTransformReal() {
FastFourierTransform fft = new FastFourierTransform();
double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6};
double[] frequencies = fft.getMagnitudes(amplitudes, false);
Assert.assertEquals(4, frequencies.length);
Assert.assertArrayEquals(new double[] {28.8, 2.107, 14.927, 19.874}, frequencies, 0.005);
}
@Test
public void testFastFourierTransformRealOddSize() {
FastFourierTransform fft = new FastFourierTransform();
double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5};
double[] frequencies = fft.getMagnitudes(amplitudes, false);
Assert.assertEquals(3, frequencies.length);
Assert.assertArrayEquals(new double[] {24.2, 3.861, 16.876}, frequencies, 0.005);
}
}
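Taken together, these tests pin down the sizing convention of getMagnitudes: a real-valued input of N samples yields N/2 magnitude bins (integer division), while the single-argument form treats the input as interleaved complex values and returns half as many bins again. A minimal sketch under that assumption:

import org.datavec.audio.dsp.FastFourierTransform;

public class FftSketch {
    public static void main(String[] args) {
        FastFourierTransform fft = new FastFourierTransform();
        double[] signal = {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
        // Real input: 8 samples -> 4 magnitude bins, as in testFastFourierTransformReal
        double[] real = fft.getMagnitudes(signal, false);
        // Single-arg form: 8 interleaved re/im values -> 2 bins, as in testFastFourierTransformComplex
        double[] complex = fft.getMagnitudes(signal);
        System.out.println(real.length + " " + complex.length); // 4 2
    }
}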

View File

@ -1,71 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.datavec</groupId>
<artifactId>datavec-data</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>datavec-data-codec</artifactId>
<name>datavec-data-codec</name>
<dependencies>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.jcodec</groupId>
<artifactId>jcodec</artifactId>
<version>0.1.5</version>
</dependency>
<!-- Do not depend on FFmpeg by default due to licensing concerns. -->
<!--
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>ffmpeg-platform</artifactId>
<version>${ffmpeg.version}-${javacpp-presets.version}</version>
</dependency>
-->
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

View File

@ -1,41 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.codec.format.input;
import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.BaseInputFormat;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.codec.reader.CodecRecordReader;
import java.io.IOException;
/**
* @author Adam Gibson
*/
public class CodecInputFormat extends BaseInputFormat {
@Override
public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
RecordReader reader = new CodecRecordReader();
reader.initialize(conf, split);
return reader;
}
}

View File

@ -1,144 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.codec.reader;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.FileRecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable;
import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public abstract class BaseCodecRecordReader extends FileRecordReader implements SequenceRecordReader {
protected int startFrame = 0;
protected int numFrames = -1;
protected int totalFrames = -1;
protected double framesPerSecond = -1;
protected double videoLength = -1;
protected int rows = 28, cols = 28;
protected boolean ravel = false;
public final static String NAME_SPACE = "org.datavec.codec.reader";
public final static String ROWS = NAME_SPACE + ".rows";
public final static String COLUMNS = NAME_SPACE + ".columns";
public final static String START_FRAME = NAME_SPACE + ".startframe";
public final static String TOTAL_FRAMES = NAME_SPACE + ".frames";
public final static String TIME_SLICE = NAME_SPACE + ".time";
public final static String RAVEL = NAME_SPACE + ".ravel";
public final static String VIDEO_DURATION = NAME_SPACE + ".duration";
@Override
public List<List<Writable>> sequenceRecord() {
URI next = locationsIterator.next();
try (InputStream s = streamCreatorFn.apply(next)){
return loadData(null, s);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public List<List<Writable>> sequenceRecord(URI uri, DataInputStream dataInputStream) throws IOException {
return loadData(null, dataInputStream);
}
protected abstract List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException;
@Override
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
setConf(conf);
initialize(split);
}
@Override
public List<Writable> next() {
throw new UnsupportedOperationException("next() not supported for CodecRecordReader (use: sequenceRecord)");
}
@Override
public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
throw new UnsupportedOperationException("record(URI,DataInputStream) not supported for CodecRecordReader");
}
@Override
public void setConf(Configuration conf) {
super.setConf(conf);
startFrame = conf.getInt(START_FRAME, 0);
numFrames = conf.getInt(TOTAL_FRAMES, -1);
rows = conf.getInt(ROWS, 28);
cols = conf.getInt(COLUMNS, 28);
framesPerSecond = conf.getFloat(TIME_SLICE, -1);
videoLength = conf.getFloat(VIDEO_DURATION, -1);
ravel = conf.getBoolean(RAVEL, false);
totalFrames = conf.getInt(TOTAL_FRAMES, -1);
}
@Override
public Configuration getConf() {
return super.getConf();
}
@Override
public SequenceRecord nextSequence() {
URI next = locationsIterator.next();
List<List<Writable>> list;
try (InputStream s = streamCreatorFn.apply(next)){
list = loadData(null, s);
} catch (IOException e) {
throw new RuntimeException(e);
}
return new org.datavec.api.records.impl.SequenceRecord(list,
new RecordMetaDataURI(next, CodecRecordReader.class));
}
@Override
public SequenceRecord loadSequenceFromMetaData(RecordMetaData recordMetaData) throws IOException {
return loadSequenceFromMetaData(Collections.singletonList(recordMetaData)).get(0);
}
@Override
public List<SequenceRecord> loadSequenceFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
List<SequenceRecord> out = new ArrayList<>();
for (RecordMetaData meta : recordMetaDatas) {
try (InputStream s = streamCreatorFn.apply(meta.getURI())){
List<List<Writable>> list = loadData(null, s);
out.add(new org.datavec.api.records.impl.SequenceRecord(list, meta));
}
}
return out;
}
}
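The configuration keys above are the public surface for driving frame extraction. A minimal wiring sketch, mirroring CodecReaderTest later in this diff (the video path is a placeholder):

import java.io.File;
import java.util.List;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.codec.reader.CodecRecordReader;

public class CodecConfigSketch {
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        conf.set(CodecRecordReader.RAVEL, "true");       // flatten each frame to a single row vector
        conf.set(CodecRecordReader.START_FRAME, "160");  // first frame to grab
        conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); // number of frames to grab
        conf.set(CodecRecordReader.ROWS, "80");          // resize height
        conf.set(CodecRecordReader.COLUMNS, "46");       // resize width
        SequenceRecordReader reader = new CodecRecordReader();
        reader.initialize(new FileSplit(new File("video.mp4"))); // placeholder path
        reader.setConf(conf);
        List<List<Writable>> frames = reader.sequenceRecord();   // 500 frames of 80*46*3 values each
    }
}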

View File

@ -1,138 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.codec.reader;
import org.apache.commons.compress.utils.IOUtils;
import org.datavec.api.conf.Configuration;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.ImageLoader;
import org.jcodec.api.FrameGrab;
import org.jcodec.api.JCodecException;
import org.jcodec.common.ByteBufferSeekableByteChannel;
import org.jcodec.common.NIOUtils;
import org.jcodec.common.SeekableByteChannel;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
public class CodecRecordReader extends BaseCodecRecordReader {
private ImageLoader imageLoader;
@Override
public void setConf(Configuration conf) {
super.setConf(conf);
imageLoader = new ImageLoader(rows, cols);
}
@Override
protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException {
SeekableByteChannel seekableByteChannel;
if (inputStream != null) {
//Reading video from an InputStream: JCodec needs the data in a SeekableByteChannel,
//so load the entire video into memory and wrap it in a ByteBufferSeekableByteChannel
byte[] data = IOUtils.toByteArray(inputStream);
ByteBuffer bb = ByteBuffer.wrap(data);
seekableByteChannel = new FixedByteBufferSeekableByteChannel(bb);
} else {
seekableByteChannel = NIOUtils.readableFileChannel(file);
}
List<List<Writable>> record = new ArrayList<>();
if (numFrames >= 1) {
FrameGrab fg;
try {
fg = new FrameGrab(seekableByteChannel);
if (startFrame != 0)
fg.seekToFramePrecise(startFrame);
} catch (JCodecException e) {
throw new RuntimeException(e);
}
for (int i = startFrame; i < startFrame + numFrames; i++) {
try {
BufferedImage grab = fg.getFrame();
if (ravel)
record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab)));
else
record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab)));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
} else {
if (framesPerSecond < 1)
throw new IllegalStateException("No frames or frame time intervals specified");
else {
for (double i = 0; i < videoLength; i += framesPerSecond) {
try {
BufferedImage grab = FrameGrab.getFrame(seekableByteChannel, i);
if (ravel)
record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab)));
else
record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab)));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
}
return record;
}
/** Ugly workaround to a bug in JCodec: https://github.com/jcodec/jcodec/issues/24 */
private static class FixedByteBufferSeekableByteChannel extends ByteBufferSeekableByteChannel {
private ByteBuffer backing;
public FixedByteBufferSeekableByteChannel(ByteBuffer backing) {
super(backing);
try {
Field f = this.getClass().getSuperclass().getDeclaredField("maxPos");
f.setAccessible(true);
f.set(this, backing.limit());
} catch (Exception e) {
throw new RuntimeException(e);
}
this.backing = backing;
}
@Override
public int read(ByteBuffer dst) throws IOException {
if (!backing.hasRemaining())
return -1;
return super.read(dst);
}
}
}

View File

@ -1,82 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.codec.reader;
import org.bytedeco.javacv.FFmpegFrameGrabber;
import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.api.conf.Configuration;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
public class NativeCodecRecordReader extends BaseCodecRecordReader {
private OpenCVFrameConverter.ToMat converter;
private NativeImageLoader imageLoader;
@Override
public void setConf(Configuration conf) {
super.setConf(conf);
converter = new OpenCVFrameConverter.ToMat();
imageLoader = new NativeImageLoader(rows, cols);
}
@Override
protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException {
List<List<Writable>> record = new ArrayList<>();
try (FFmpegFrameGrabber fg =
inputStream != null ? new FFmpegFrameGrabber(inputStream) : new FFmpegFrameGrabber(file)) {
if (numFrames >= 1) {
fg.start();
if (startFrame != 0)
fg.setFrameNumber(startFrame);
for (int i = startFrame; i < startFrame + numFrames; i++) {
Frame grab = fg.grabImage();
record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab))));
}
} else {
if (framesPerSecond < 1)
throw new IllegalStateException("No frames or frame time intervals specified");
else {
fg.start();
for (double i = 0; i < videoLength; i += framesPerSecond) {
fg.setTimestamp(Math.round(i * 1000000L));
Frame grab = fg.grabImage();
record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab))));
}
}
}
}
return record;
}
}

View File

@ -1,46 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.codec.reader;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.codec.reader";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -1,212 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.codec.reader;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.ArrayWritable;
import org.datavec.api.writable.Writable;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.common.io.ClassPathResource;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.util.Iterator;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* @author Adam Gibson
*/
public class CodecReaderTest {
@Test
public void testCodecReader() throws Exception {
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
SequenceRecordReader reader = new CodecRecordReader();
Configuration conf = new Configuration();
conf.set(CodecRecordReader.RAVEL, "true");
conf.set(CodecRecordReader.START_FRAME, "160");
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
reader.initialize(new FileSplit(file));
reader.setConf(conf);
assertTrue(reader.hasNext());
List<List<Writable>> record = reader.sequenceRecord();
// System.out.println(record.size());
Iterator<List<Writable>> it = record.iterator();
List<Writable> first = it.next();
// System.out.println(first);
//Expected size: 80x46x3
assertEquals(1, first.size());
assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length());
}
@Test
public void testCodecReaderMeta() throws Exception {
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
SequenceRecordReader reader = new CodecRecordReader();
Configuration conf = new Configuration();
conf.set(CodecRecordReader.RAVEL, "true");
conf.set(CodecRecordReader.START_FRAME, "160");
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
reader.initialize(new FileSplit(file));
reader.setConf(conf);
assertTrue(reader.hasNext());
List<List<Writable>> record = reader.sequenceRecord();
assertEquals(500, record.size()); //500 frames
reader.reset();
SequenceRecord seqR = reader.nextSequence();
assertEquals(record, seqR.getSequenceRecord());
RecordMetaData meta = seqR.getMetaData();
// System.out.println(meta);
assertTrue(meta.getURI().toString().endsWith(file.getName()));
SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta);
assertEquals(seqR, fromMeta);
}
@Test
public void testViaDataInputStream() throws Exception {
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
SequenceRecordReader reader = new CodecRecordReader();
Configuration conf = new Configuration();
conf.set(CodecRecordReader.RAVEL, "true");
conf.set(CodecRecordReader.START_FRAME, "160");
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
Configuration conf2 = new Configuration(conf);
reader.initialize(new FileSplit(file));
reader.setConf(conf);
assertTrue(reader.hasNext());
List<List<Writable>> expected = reader.sequenceRecord();
SequenceRecordReader reader2 = new CodecRecordReader();
reader2.setConf(conf2);
DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file));
List<List<Writable>> actual = reader2.sequenceRecord(null, dataInputStream);
assertEquals(expected, actual);
}
@Ignore
@Test
public void testNativeCodecReader() throws Exception {
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
SequenceRecordReader reader = new NativeCodecRecordReader();
Configuration conf = new Configuration();
conf.set(CodecRecordReader.RAVEL, "true");
conf.set(CodecRecordReader.START_FRAME, "160");
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
reader.initialize(new FileSplit(file));
reader.setConf(conf);
assertTrue(reader.hasNext());
List<List<Writable>> record = reader.sequenceRecord();
// System.out.println(record.size());
Iterator<List<Writable>> it = record.iterator();
List<Writable> first = it.next();
// System.out.println(first);
//Expected size: 80x46x3
assertEquals(1, first.size());
assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length());
}
@Ignore
@Test
public void testNativeCodecReaderMeta() throws Exception {
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
SequenceRecordReader reader = new NativeCodecRecordReader();
Configuration conf = new Configuration();
conf.set(CodecRecordReader.RAVEL, "true");
conf.set(CodecRecordReader.START_FRAME, "160");
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
reader.initialize(new FileSplit(file));
reader.setConf(conf);
assertTrue(reader.hasNext());
List<List<Writable>> record = reader.sequenceRecord();
assertEquals(500, record.size()); //500 frames
reader.reset();
SequenceRecord seqR = reader.nextSequence();
assertEquals(record, seqR.getSequenceRecord());
RecordMetaData meta = seqR.getMetaData();
// System.out.println(meta);
assertTrue(meta.getURI().toString().endsWith("fire_lowres.mp4"));
SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta);
assertEquals(seqR, fromMeta);
}
@Ignore
@Test
public void testNativeViaDataInputStream() throws Exception {
File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
SequenceRecordReader reader = new NativeCodecRecordReader();
Configuration conf = new Configuration();
conf.set(CodecRecordReader.RAVEL, "true");
conf.set(CodecRecordReader.START_FRAME, "160");
conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
Configuration conf2 = new Configuration(conf);
reader.initialize(new FileSplit(file));
reader.setConf(conf);
assertTrue(reader.hasNext());
List<List<Writable>> expected = reader.sequenceRecord();
SequenceRecordReader reader2 = new NativeCodecRecordReader();
reader2.setConf(conf2);
DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file));
List<List<Writable>> actual = reader2.sequenceRecord(null, dataInputStream);
assertEquals(expected, actual);
}
}

View File

@ -1,77 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.datavec</groupId>
<artifactId>datavec-data</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>datavec-data-nlp</artifactId>
<name>datavec-data-nlp</name>
<properties>
<cleartk.version>2.0.0</cleartk.version>
</properties>
<dependencies>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-local</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
<dependency>
<groupId>org.cleartk</groupId>
<artifactId>cleartk-snowball</artifactId>
<version>${cleartk.version}</version>
</dependency>
<dependency>
<groupId>org.cleartk</groupId>
<artifactId>cleartk-opennlp-tools</artifactId>
<version>${cleartk.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

View File

@ -1,237 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.annotator;
import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSTaggerME;
import opennlp.uima.postag.POSModelResource;
import opennlp.uima.postag.POSModelResourceImpl;
import opennlp.uima.util.AnnotationComboIterator;
import opennlp.uima.util.AnnotationIteratorPair;
import opennlp.uima.util.AnnotatorUtil;
import opennlp.uima.util.UimaUtil;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.Feature;
import org.apache.uima.cas.Type;
import org.apache.uima.cas.TypeSystem;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.fit.component.CasAnnotator_ImplBase;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.ExternalResourceFactory;
import org.apache.uima.resource.ResourceAccessException;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.util.Level;
import org.apache.uima.util.Logger;
import org.cleartk.token.type.Sentence;
import org.cleartk.token.type.Token;
import org.datavec.nlp.movingwindow.Util;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
public class PoStagger extends CasAnnotator_ImplBase {
static {
//UIMA logging
Util.disableLogging();
}
private POSTaggerME posTagger;
private Type sentenceType;
private Type tokenType;
private Feature posFeature;
private Feature probabilityFeature;
private UimaContext context;
private Logger logger;
/**
 * Initializes a new instance.
 *
 * Note: Use {@link #initialize(org.apache.uima.UimaContext)} to initialize this instance;
 * do not use the constructor.
 */
public PoStagger() {
// intentionally empty: all initialization happens in initialize(UimaContext)
}
/**
* Initializes the current instance with the given context.
*
* Note: Do all initialization in this method, do not use the constructor.
*/
@Override
public void initialize(UimaContext context) throws ResourceInitializationException {
super.initialize(context);
this.context = context;
this.logger = context.getLogger();
if (this.logger.isLoggable(Level.INFO)) {
this.logger.log(Level.INFO, "Initializing the OpenNLP " + "Part of Speech annotator.");
}
POSModel model;
try {
POSModelResource modelResource = (POSModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER);
model = modelResource.getModel();
} catch (ResourceAccessException e) {
throw new ResourceInitializationException(e);
}
Integer beamSize = AnnotatorUtil.getOptionalIntegerParameter(context, UimaUtil.BEAM_SIZE_PARAMETER);
if (beamSize == null)
beamSize = POSTaggerME.DEFAULT_BEAM_SIZE;
this.posTagger = new POSTaggerME(model, beamSize, 0);
}
/**
* Initializes the type system.
*/
@Override
public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException {
// sentence type
this.sentenceType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem,
UimaUtil.SENTENCE_TYPE_PARAMETER);
// token type
this.tokenType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem,
UimaUtil.TOKEN_TYPE_PARAMETER);
// pos feature
this.posFeature = AnnotatorUtil.getRequiredFeatureParameter(this.context, this.tokenType,
UimaUtil.POS_FEATURE_PARAMETER, CAS.TYPE_NAME_STRING);
this.probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(this.context, this.tokenType,
UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE);
}
/**
* Performs pos-tagging on the given tcas object.
*/
@Override
public synchronized void process(CAS tcas) {
final AnnotationComboIterator comboIterator =
new AnnotationComboIterator(tcas, this.sentenceType, this.tokenType);
for (AnnotationIteratorPair annotationIteratorPair : comboIterator) {
final List<AnnotationFS> sentenceTokenAnnotationList = new LinkedList<AnnotationFS>();
final List<String> sentenceTokenList = new LinkedList<String>();
for (AnnotationFS tokenAnnotation : annotationIteratorPair.getSubIterator()) {
sentenceTokenAnnotationList.add(tokenAnnotation);
sentenceTokenList.add(tokenAnnotation.getCoveredText());
}
final List<String> posTags = this.posTagger.tag(sentenceTokenList);
double[] posProbabilities = null;
if (this.probabilityFeature != null) {
posProbabilities = this.posTagger.probs();
}
final Iterator<String> posTagIterator = posTags.iterator();
final Iterator<AnnotationFS> sentenceTokenIterator = sentenceTokenAnnotationList.iterator();
int index = 0;
while (posTagIterator.hasNext() && sentenceTokenIterator.hasNext()) {
final String posTag = posTagIterator.next();
final AnnotationFS tokenAnnotation = sentenceTokenIterator.next();
tokenAnnotation.setStringValue(this.posFeature, posTag);
if (posProbabilities != null) {
tokenAnnotation.setDoubleValue(this.probabilityFeature, posProbabilities[index]);
}
index++;
}
// log tokens with pos
if (this.logger.isLoggable(Level.FINER)) {
final StringBuilder sentenceWithPos = new StringBuilder();
sentenceWithPos.append("\"");
for (final Iterator<AnnotationFS> it = sentenceTokenAnnotationList.iterator(); it.hasNext();) {
final AnnotationFS token = it.next();
sentenceWithPos.append(token.getCoveredText());
sentenceWithPos.append('\\');
sentenceWithPos.append(token.getStringValue(this.posFeature));
sentenceWithPos.append(' ');
}
// delete last whitespace
if (sentenceWithPos.length() > 1) // > 1, not 0, because the buffer already contains the opening " char
sentenceWithPos.setLength(sentenceWithPos.length() - 1);
sentenceWithPos.append("\"");
this.logger.log(Level.FINER, sentenceWithPos.toString());
}
}
}
/**
* Releases allocated resources.
*/
@Override
public void destroy() {
this.posTagger = null;
}
public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException {
String modelPath = String.format("/models/%s-pos-maxent.bin", languageCode);
return AnalysisEngineFactory.createEngineDescription(PoStagger.class, UimaUtil.MODEL_PARAMETER,
ExternalResourceFactory.createExternalResourceDescription(POSModelResourceImpl.class,
PoStagger.class.getResource(modelPath).toString()),
UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), UimaUtil.TOKEN_TYPE_PARAMETER,
Token.class.getName(), UimaUtil.POS_FEATURE_PARAMETER, "pos");
}
}

View File

@ -1,52 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.annotator;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.util.ParamUtil;
import org.datavec.nlp.movingwindow.Util;
public class SentenceAnnotator extends org.cleartk.opennlp.tools.SentenceAnnotator {
static {
//UIMA logging
Util.disableLogging();
}
public static AnalysisEngineDescription getDescription() throws ResourceInitializationException {
return AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.class, PARAM_SENTENCE_MODEL_PATH,
ParamUtil.getParameterValue(PARAM_SENTENCE_MODEL_PATH, "/models/en-sent.bin"),
PARAM_WINDOW_CLASS_NAMES, ParamUtil.getParameterValue(PARAM_WINDOW_CLASS_NAMES, null));
}
@Override
public synchronized void process(JCas jCas) throws AnalysisEngineProcessException {
super.process(jCas);
}
}

View File

@ -1,58 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.annotator;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.snowball.SnowballStemmer;
import org.cleartk.token.type.Token;
public class StemmerAnnotator extends SnowballStemmer<Token> {
public static AnalysisEngineDescription getDescription() throws ResourceInitializationException {
return getDescription("English");
}
public static AnalysisEngineDescription getDescription(String language) throws ResourceInitializationException {
return AnalysisEngineFactory.createEngineDescription(StemmerAnnotator.class, SnowballStemmer.PARAM_STEMMER_NAME,
language);
}
@SuppressWarnings("unchecked")
@Override
public synchronized void process(JCas jCas) throws AnalysisEngineProcessException {
super.process(jCas);
}
@Override
public void setStem(Token token, String stem) {
token.setStem(stem);
}
}

View File

@ -1,70 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.annotator;
import opennlp.uima.tokenize.TokenizerModelResourceImpl;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.ExternalResourceFactory;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.opennlp.tools.Tokenizer;
import org.cleartk.token.type.Sentence;
import org.cleartk.token.type.Token;
import org.datavec.nlp.movingwindow.Util;
import org.datavec.nlp.tokenization.tokenizer.ConcurrentTokenizer;
/**
* Overrides OpenNLP tokenizer to be thread safe
*/
public class TokenizerAnnotator extends Tokenizer {
static {
//UIMA logging
Util.disableLogging();
}
public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException {
String modelPath = String.format("/models/%s-token.bin", languageCode);
return AnalysisEngineFactory.createEngineDescription(ConcurrentTokenizer.class,
opennlp.uima.util.UimaUtil.MODEL_PARAMETER,
ExternalResourceFactory.createExternalResourceDescription(TokenizerModelResourceImpl.class,
ConcurrentTokenizer.class.getResource(modelPath).toString()),
opennlp.uima.util.UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(),
opennlp.uima.util.UimaUtil.TOKEN_TYPE_PARAMETER, Token.class.getName());
}
public static AnalysisEngineDescription getDescription() throws ResourceInitializationException {
return getDescription("en");
}
}
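The four annotators in this package are designed to compose into one uimaFIT pipeline. A hedged sketch, assuming the OpenNLP model binaries referenced by the getDescription() factories are on the classpath:

import org.apache.uima.fit.factory.JCasFactory;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.apache.uima.jcas.JCas;
import org.datavec.nlp.annotator.PoStagger;
import org.datavec.nlp.annotator.SentenceAnnotator;
import org.datavec.nlp.annotator.StemmerAnnotator;
import org.datavec.nlp.annotator.TokenizerAnnotator;

public class AnnotatorPipelineSketch {
    public static void main(String[] args) throws Exception {
        JCas jCas = JCasFactory.createJCas();
        jCas.setDocumentText("The quick brown fox jumps over the lazy dog.");
        // sentence detection -> tokenization -> POS tagging -> stemming
        SimplePipeline.runPipeline(jCas,
                SentenceAnnotator.getDescription(),
                TokenizerAnnotator.getDescription(),
                PoStagger.getDescription("en"),
                StemmerAnnotator.getDescription());
    }
}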

View File

@ -1,41 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.input;
import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.BaseInputFormat;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.nlp.reader.TfidfRecordReader;
import java.io.IOException;
/**
* @author Adam Gibson
*/
public class TextInputFormat extends BaseInputFormat {
@Override
public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
RecordReader reader = new TfidfRecordReader();
reader.initialize(conf, split);
return reader;
}
}

View File

@ -1,148 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.metadata;
import org.nd4j.common.primitives.Counter;
import org.datavec.api.conf.Configuration;
import org.datavec.nlp.vectorizer.TextVectorizer;
import org.nd4j.common.util.MathUtils;
import org.nd4j.common.util.Index;
public class DefaultVocabCache implements VocabCache {
private Counter<String> wordFrequencies = new Counter<>();
private Counter<String> docFrequencies = new Counter<>();
private int minWordFrequency;
private Index vocabWords = new Index();
private double numDocs = 0;
/**
 * Instantiate with a given minimum word frequency
 * @param minWordFrequency the minimum number of occurrences for a word to be added to the vocab
 */
public DefaultVocabCache(int minWordFrequency) {
this.minWordFrequency = minWordFrequency;
}
/*
* Constructor for use with initialize()
*/
public DefaultVocabCache() {
}
@Override
public void incrementNumDocs(double by) {
numDocs += by;
}
@Override
public double numDocs() {
return numDocs;
}
@Override
public String wordAt(int i) {
return vocabWords.get(i).toString();
}
@Override
public int wordIndex(String word) {
return vocabWords.indexOf(word);
}
@Override
public void initialize(Configuration conf) {
minWordFrequency = conf.getInt(TextVectorizer.MIN_WORD_FREQUENCY, 5);
}
@Override
public double wordFrequency(String word) {
return wordFrequencies.getCount(word);
}
@Override
public int minWordFrequency() {
return minWordFrequency;
}
@Override
public Index vocabWords() {
return vocabWords;
}
@Override
public void incrementDocCount(String word) {
incrementDocCount(word, 1.0);
}
@Override
public void incrementDocCount(String word, double by) {
docFrequencies.incrementCount(word, by);
}
@Override
public void incrementCount(String word) {
incrementCount(word, 1.0);
}
@Override
public void incrementCount(String word, double by) {
wordFrequencies.incrementCount(word, by);
if (wordFrequencies.getCount(word) >= minWordFrequency && vocabWords.indexOf(word) < 0)
vocabWords.add(word);
}
@Override
public double idf(String word) {
return docFrequencies.getCount(word);
}
@Override
public double tfidf(String word, double frequency, boolean smoothIdf) {
double tf = tf((int) frequency);
double docFreq = docFrequencies.getCount(word);
double idf = idf(numDocs, docFreq, smoothIdf);
double tfidf = MathUtils.tfidf(tf, idf);
return tfidf;
}
public double idf(double totalDocs, double numTimesWordAppearedInADocument, boolean smooth) {
if(smooth){
return Math.log((1 + totalDocs) / (1 + numTimesWordAppearedInADocument)) + 1.0;
} else {
return Math.log(totalDocs / numTimesWordAppearedInADocument) + 1.0;
}
}
public static double tf(int count) {
return count;
}
public int getMinWordFrequency() {
return minWordFrequency;
}
public void setMinWordFrequency(int minWordFrequency) {
this.minWordFrequency = minWordFrequency;
}
}
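A worked sketch of the counting and scoring flow above. The expected value follows the smoothed-idf formula in idf(double, double, boolean); the final multiplication is an assumption about MathUtils.tfidf:

import org.datavec.nlp.metadata.DefaultVocabCache;

public class TfidfSketch {
    public static void main(String[] args) {
        DefaultVocabCache cache = new DefaultVocabCache(1); // min word frequency of 1
        cache.incrementNumDocs(2);           // corpus of two documents
        cache.incrementCount("deep", 3.0);   // "deep" seen 3 times overall
        cache.incrementDocCount("deep");     // ...in exactly one document
        // smoothed idf = ln((1 + 2) / (1 + 1)) + 1 ~= 1.405, tf = 3
        double score = cache.tfidf("deep", 3.0, true); // ~= 4.22, assuming MathUtils.tfidf is tf * idf
        System.out.println(score);
    }
}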

View File

@ -1,121 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.metadata;
import org.datavec.api.conf.Configuration;
import org.nd4j.common.util.Index;
public interface VocabCache {
/**
 * Increment the number of documents
 * @param by the amount to increment by
 */
void incrementNumDocs(double by);
/**
* Number of documents
* @return the number of documents
*/
double numDocs();
/**
* Returns a word in the vocab at a particular index
* @param i the index to get
* @return the word at that index in the vocab
*/
String wordAt(int i);
int wordIndex(String word);
/**
* Configuration for initializing
* @param conf the configuration to initialize with
*/
void initialize(Configuration conf);
/**
* Get the word frequency for a word
* @param word the word to get frequency for
* @return the frequency for a given word
*/
double wordFrequency(String word);
/**
* The min word frequency
* needed to be included in the vocab
* (default 5)
* @return the min word frequency to
* be included in the vocab
*/
int minWordFrequency();
/**
* All of the vocab words (ordered)
* note that these are not all the possible tokens
* @return the list of vocab words
*/
Index vocabWords();
/**
* Increment the doc count for a word by 1
* @param word the word to increment the count for
*/
void incrementDocCount(String word);
/**
* Increment the document count for a particular word
* @param word the word to increment the count for
* @param by the amount to increment by
*/
void incrementDocCount(String word, double by);
/**
* Increment a word count by 1
* @param word the word to increment the count for
*/
void incrementCount(String word);
/**
* Increment count for a word
* @param word the word to increment the count for
* @param by the amount to increment by
*/
void incrementCount(String word, double by);
/**
 * Number of documents the word has occurred in
 * @param word the word to get the document frequency for
 * @return the number of documents the word has occurred in
 */
double idf(String word);
/**
 * Calculate the tf-idf of the word given its frequency
 * @param word the word to compute tf-idf for
 * @param frequency the word's frequency in the document
 * @param smoothIdf whether to smooth the idf (adds one to both document counts)
 * @return the tf-idf for the word
 */
double tfidf(String word, double frequency, boolean smoothIdf);
}

View File

@ -1,125 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.movingwindow;
import org.apache.commons.lang3.StringUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.collection.MultiDimensionalMap;
import org.nd4j.common.primitives.Pair;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import java.util.ArrayList;
import java.util.List;
public class ContextLabelRetriever {
private static String BEGIN_LABEL = "<([A-Za-z]+|\\d+)>";
private static String END_LABEL = "</([A-Za-z]+|\\d+)>";
private ContextLabelRetriever() {}
/**
 * Strips inline markup labels (e.g. <LOC> ... </LOC>) from a sentence
 * and records the character spans the labeled words end up at.
 *
 * @param sentence the sentence to process
 * @param tokenizerFactory the tokenizer factory used to split the sentence
 * @return a pair of the post-processed sentence with labels stripped
 * and a map from (begin, end) character spans to their labels
 */
public static Pair<String, MultiDimensionalMap<Integer, Integer, String>> stringWithLabels(String sentence,
TokenizerFactory tokenizerFactory) {
MultiDimensionalMap<Integer, Integer, String> map = MultiDimensionalMap.newHashBackedMap();
Tokenizer t = tokenizerFactory.create(sentence);
List<String> currTokens = new ArrayList<>();
String currLabel = null;
String endLabel = null;
List<Pair<String, List<String>>> tokensWithSameLabel = new ArrayList<>();
while (t.hasMoreTokens()) {
String token = t.nextToken();
if (token.matches(BEGIN_LABEL)) {
currLabel = token;
//no labels; add these as NONE and begin the new label
if (!currTokens.isEmpty()) {
tokensWithSameLabel.add(new Pair<>("NONE", (List<String>) new ArrayList<>(currTokens)));
currTokens.clear();
}
} else if (token.matches(END_LABEL)) {
if (currLabel == null)
throw new IllegalStateException("Found an ending label with no matching begin label");
endLabel = token;
} else
currTokens.add(token);
if (currLabel != null && endLabel != null) {
currLabel = currLabel.replaceAll("[<>/]", "");
endLabel = endLabel.replaceAll("[<>/]", "");
Preconditions.checkState(!currLabel.isEmpty(), "Current label is empty!");
Preconditions.checkState(!endLabel.isEmpty(), "End label is empty!");
Preconditions.checkState(currLabel.equals(endLabel), "Current label begin and end did not match for the parse. Was: %s ending with %s",
currLabel, endLabel);
tokensWithSameLabel.add(new Pair<>(currLabel, (List<String>) new ArrayList<>(currTokens)));
currTokens.clear();
//clear out the tokens
currLabel = null;
endLabel = null;
}
}
//trailing tokens with no label; add these as NONE for consistency with the earlier flush
if (!currTokens.isEmpty()) {
tokensWithSameLabel.add(new Pair<>("NONE", (List<String>) new ArrayList<>(currTokens)));
currTokens.clear();
}
//now join the output
StringBuilder strippedSentence = new StringBuilder();
for (Pair<String, List<String>> tokensWithLabel : tokensWithSameLabel) {
String joinedSentence = StringUtils.join(tokensWithLabel.getSecond(), " ");
//spaces between separate parts of the sentence
if (strippedSentence.length() > 0)
strippedSentence.append(" ");
strippedSentence.append(joinedSentence);
int begin = strippedSentence.toString().indexOf(joinedSentence);
int end = begin + joinedSentence.length();
map.put(begin, end, tokensWithLabel.getFirst());
}
return new Pair<>(strippedSentence.toString(), map);
}
}
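A hedged usage sketch of stringWithLabels. DefaultTokenizerFactory is assumed to be this module's whitespace tokenizer factory, and the spans below assume <LOC> and </LOC> survive tokenization as single tokens:

import org.datavec.nlp.movingwindow.ContextLabelRetriever;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.common.collection.MultiDimensionalMap;
import org.nd4j.common.primitives.Pair;

public class ContextLabelSketch {
    public static void main(String[] args) {
        TokenizerFactory tf = new DefaultTokenizerFactory(); // assumed whitespace tokenizer
        Pair<String, MultiDimensionalMap<Integer, Integer, String>> result =
                ContextLabelRetriever.stringWithLabels("He lives in <LOC> New York </LOC>", tf);
        System.out.println(result.getFirst());  // "He lives in New York"
        System.out.println(result.getSecond()); // spans: (0, 11) -> NONE, (12, 20) -> LOC ("New York")
    }
}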

View File

@ -1,60 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.movingwindow;
import org.nd4j.common.primitives.Counter;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
public class Util {
/**
 * Returns a thread-safe counter
 *
 * @return a new thread-safe counter
 */
public static Counter<String> parallelCounter() {
return new Counter<>();
}
public static boolean matchesAnyStopWord(List<String> stopWords, String word) {
for (String s : stopWords)
if (s.equalsIgnoreCase(word))
return true;
return false;
}
public static Level disableLogging() {
Logger logger = Logger.getLogger("org.apache.uima");
while (logger.getLevel() == null) {
logger = logger.getParent();
}
Level level = logger.getLevel();
logger.setLevel(Level.OFF);
return level;
}
}

View File

@ -1,177 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.movingwindow;
import org.apache.commons.lang3.StringUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
public class Window implements Serializable {
private static final long serialVersionUID = 6359906393699230579L;
private List<String> words;
private String label = "NONE";
private boolean beginLabel;
private boolean endLabel;
private int median;
private static String BEGIN_LABEL = "<([A-Z]+|\\d+)>";
private static String END_LABEL = "</([A-Z]+|\\d+)>";
private int begin, end;
/**
 * Creates a window with the default context size of 5
 * @param words the words in the window
 * @param begin the begin index for the window
 * @param end the end index for the window
 */
public Window(Collection<String> words, int begin, int end) {
this(words, 5, begin, end);
}
public String asTokens() {
return StringUtils.join(words, " ");
}
/**
* Initialize a window with the given size
* @param words the words to use
* @param windowSize the size of the window
* @param begin the begin index for the window
* @param end the end index for the window
*/
public Window(Collection<String> words, int windowSize, int begin, int end) {
if (words == null)
throw new IllegalArgumentException("Words must not be null");
this.words = new ArrayList<>(words);
this.begin = begin;
this.end = end;
initContext();
}
private void initContext() {
int median = words.size() / 2;
List<String> begin = words.subList(0, median);
List<String> after = words.subList(median + 1, words.size());
for (String s : begin) {
if (s.matches(BEGIN_LABEL)) {
this.label = s.replaceAll("(<|>)", "").replace("/", "");
beginLabel = true;
} else if (s.matches(END_LABEL)) {
endLabel = true;
this.label = s.replaceAll("(<|>|/)", "").replace("/", "");
}
}
for (String s1 : after) {
if (s1.matches(BEGIN_LABEL)) {
this.label = s1.replaceAll("(<|>)", "").replace("/", "");
beginLabel = true;
}
if (s1.matches(END_LABEL)) {
endLabel = true;
this.label = s1.replaceAll("(<|>)", "");
}
}
this.median = median;
}
@Override
public String toString() {
return words.toString();
}
public List<String> getWords() {
return words;
}
public void setWords(List<String> words) {
this.words = words;
}
public String getWord(int i) {
return words.get(i);
}
public String getFocusWord() {
return words.get(median);
}
public boolean isBeginLabel() {
return !label.equals("NONE") && beginLabel;
}
public boolean isEndLabel() {
return !label.equals("NONE") && endLabel;
}
public String getLabel() {
return label.replace("/", "");
}
public int getWindowSize() {
return words.size();
}
public int getMedian() {
return median;
}
public void setLabel(String label) {
this.label = label;
}
public int getBegin() {
return begin;
}
public void setBegin(int begin) {
this.begin = begin;
}
public int getEnd() {
return end;
}
public void setEnd(int end) {
this.end = end;
}
}
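
// A minimal sketch of constructing a Window directly, assuming the class
// above; the tokens and character offsets are illustrative only.
import org.datavec.nlp.movingwindow.Window;
import java.util.Arrays;

class WindowUsageSketch {
    public static void main(String[] args) {
        Window w = new Window(Arrays.asList("the", "cat", "sat", "on", "mats"), 5, 0, 19);
        System.out.println(w.getFocusWord()); // sat (median token of the window)
        System.out.println(w.getLabel());     // NONE (no <LABEL> markers present)
    }
}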

View File

@ -1,188 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.movingwindow;
import org.apache.commons.lang3.StringUtils;
import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
public class Windows {
/**
* Constructs a list of windows of size windowSize.
* Note that padding for each window is created as well.
* @param words the words to tokenize and construct windows from
* @param windowSize the window size to generate
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(InputStream words, int windowSize) {
Tokenizer tokenizer = new DefaultStreamTokenizer(words);
List<String> list = new ArrayList<>();
while (tokenizer.hasMoreTokens())
list.add(tokenizer.nextToken());
return windows(list, windowSize);
}
/**
* Constructs a list of windows of size windowSize.
* Note that padding for each window is created as well.
* @param words the words to tokenize and construct windows from
* @param tokenizerFactory tokenizer factory to use
* @param windowSize the window size to generate
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(InputStream words, TokenizerFactory tokenizerFactory, int windowSize) {
Tokenizer tokenizer = tokenizerFactory.create(words);
List<String> list = new ArrayList<>();
while (tokenizer.hasMoreTokens())
list.add(tokenizer.nextToken());
if (list.isEmpty())
throw new IllegalStateException("No tokens found for windows");
return windows(list, windowSize);
}
/**
* Constructs a list of windows of size windowSize.
* Note that padding for each window is created as well.
* @param words the words to tokenize and construct windows from
* @param windowSize the window size to generate
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(String words, int windowSize) {
StringTokenizer tokenizer = new StringTokenizer(words);
List<String> list = new ArrayList<String>();
while (tokenizer.hasMoreTokens())
list.add(tokenizer.nextToken());
return windows(list, windowSize);
}
/**
* Constructs a list of windows of size windowSize.
* Note that padding for each window is created as well.
* @param words the words to tokenize and construct windows from
* @param tokenizerFactory tokenizer factory to use
* @param windowSize the window size to generate
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(String words, TokenizerFactory tokenizerFactory, int windowSize) {
Tokenizer tokenizer = tokenizerFactory.create(words);
List<String> list = new ArrayList<>();
while (tokenizer.hasMoreTokens())
list.add(tokenizer.nextToken());
if (list.isEmpty())
throw new IllegalStateException("No tokens found for windows");
return windows(list, windowSize);
}
/**
* Constructs a list of windows with the default window size of 5.
* Note that padding for each window is created as well.
* @param words the words to tokenize and construct windows from
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(String words) {
StringTokenizer tokenizer = new StringTokenizer(words);
List<String> list = new ArrayList<String>();
while (tokenizer.hasMoreTokens())
list.add(tokenizer.nextToken());
return windows(list, 5);
}
/**
* Constructs a list of windows with the default window size of 5.
* Note that padding for each window is created as well.
* @param words the words to tokenize and construct windows from
* @param tokenizerFactory tokenizer factory to use
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(String words, TokenizerFactory tokenizerFactory) {
Tokenizer tokenizer = tokenizerFactory.create(words);
List<String> list = new ArrayList<>();
while (tokenizer.hasMoreTokens())
list.add(tokenizer.nextToken());
return windows(list, 5);
}
/**
* Creates a sliding window from text
* @param windowSize the window size to use
* @param wordPos the position of the word to center
* @param sentence the sentence to create a window for
* @return a window based on the given sentence
*/
public static Window windowForWordInPosition(int windowSize, int wordPos, List<String> sentence) {
List<String> window = new ArrayList<>();
List<String> onlyTokens = new ArrayList<>();
int contextSize = (windowSize - 1) / 2;
for (int i = wordPos - contextSize; i <= wordPos + contextSize; i++) {
if (i < 0)
window.add("<s>");
else if (i >= sentence.size())
window.add("</s>");
else {
onlyTokens.add(sentence.get(i));
window.add(sentence.get(i));
}
}
String wholeSentence = StringUtils.join(sentence);
String window2 = StringUtils.join(onlyTokens);
int begin = wholeSentence.indexOf(window2);
int end = begin + window2.length();
return new Window(window, begin, end);
}
/**
* Constructs a list of windows of size windowSize
* @param words the words to construct windows from
* @return the list of windows for the tokenized string
*/
public static List<Window> windows(List<String> words, int windowSize) {
List<Window> ret = new ArrayList<>();
for (int i = 0; i < words.size(); i++)
ret.add(windowForWordInPosition(windowSize, i, words));
return ret;
}
}
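
// A minimal sketch of the sliding-window helpers above: with windowSize 3,
// windows at the sentence edges are padded with <s> and </s>.
import org.datavec.nlp.movingwindow.Window;
import org.datavec.nlp.movingwindow.Windows;
import java.util.List;

class WindowsUsageSketch {
    public static void main(String[] args) {
        List<Window> windows = Windows.windows("the cat sat", 3);
        for (Window w : windows)
            System.out.println(w); // [<s>, the, cat] [the, cat, sat] [cat, sat, </s>]
    }
}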

View File

@ -1,189 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.reader;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.impl.FileRecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.vector.Vectorizer;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.vectorizer.TfidfVectorizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.IOException;
import java.util.*;
public class TfidfRecordReader extends FileRecordReader {
private TfidfVectorizer tfidfVectorizer;
private List<Record> records = new ArrayList<>();
private Iterator<Record> recordIter;
private int numFeatures;
private boolean initialized = false;
@Override
public void initialize(InputSplit split) throws IOException, InterruptedException {
initialize(new Configuration(), split);
}
@Override
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
super.initialize(conf, split);
//train a new one since it hasn't been specified
if (tfidfVectorizer == null) {
tfidfVectorizer = new TfidfVectorizer();
tfidfVectorizer.initialize(conf);
//clear out old strings
records.clear();
INDArray ret = tfidfVectorizer.fitTransform(this, new Vectorizer.RecordCallBack() {
@Override
public void onRecord(Record fullRecord) {
records.add(fullRecord);
}
});
//cache the number of features used for each document
numFeatures = ret.columns();
recordIter = records.iterator();
} else {
records = new ArrayList<>();
//the record reader has 2 phases: we skip the
//document frequency phase here and just use super() to get the file contents,
//passing them to the already-initialized vectorizer.
while (super.hasNext()) {
Record fileContents = super.nextRecord();
INDArray transform = tfidfVectorizer.transform(fileContents);
org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record(
new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))),
new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class));
if (appendLabel)
record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1));
records.add(record);
}
recordIter = records.iterator();
}
this.initialized = true;
}
@Override
public void reset() {
if (inputSplit == null)
throw new UnsupportedOperationException("Cannot reset without first initializing");
recordIter = records.iterator();
}
@Override
public Record nextRecord() {
if (recordIter == null)
return super.nextRecord();
return recordIter.next();
}
@Override
public List<Writable> next() {
return nextRecord().getRecord();
}
@Override
public boolean hasNext() {
//we aren't done vectorizing yet
if (recordIter == null)
return super.hasNext();
return recordIter.hasNext();
}
@Override
public void close() throws IOException {
}
@Override
public void setConf(Configuration conf) {
this.conf = conf;
}
@Override
public Configuration getConf() {
return conf;
}
public TfidfVectorizer getTfidfVectorizer() {
return tfidfVectorizer;
}
public void setTfidfVectorizer(TfidfVectorizer tfidfVectorizer) {
if (initialized) {
throw new IllegalArgumentException(
"Setting the TfidfVectorizer after TfidfRecordReader initialization has no effect");
}
this.tfidfVectorizer = tfidfVectorizer;
}
public int getNumFeatures() {
return numFeatures;
}
public void shuffle() {
this.shuffle(new Random());
}
public void shuffle(Random random) {
Collections.shuffle(this.records, random);
this.reset();
}
@Override
public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
return loadFromMetaData(Collections.singletonList(recordMetaData)).get(0);
}
@Override
public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
List<Record> out = new ArrayList<>();
for (Record fileContents : super.loadFromMetaData(recordMetaDatas)) {
INDArray transform = tfidfVectorizer.transform(fileContents);
org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record(
new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))),
new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class));
if (appendLabel)
record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1));
out.add(record);
}
return out;
}
}
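
// A hedged sketch of driving TfidfRecordReader over a directory of text
// files; FileSplit comes from datavec-api, and the directory path is made up.
import org.datavec.api.conf.Configuration;
import org.datavec.api.split.FileSplit;
import org.datavec.nlp.reader.TfidfRecordReader;
import java.io.File;

class TfidfReaderSketch {
    public static void main(String[] args) throws Exception {
        TfidfRecordReader reader = new TfidfRecordReader();
        reader.initialize(new Configuration(), new FileSplit(new File("/path/to/corpus")));
        while (reader.hasNext())
            System.out.println(reader.next()); // one tf-idf NDArrayWritable per document
        System.out.println("features: " + reader.getNumFeatures());
    }
}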

View File

@ -1,44 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.stopwords;
import org.apache.commons.io.IOUtils;
import java.io.IOException;
import java.util.List;
public class StopWords {
private static List<String> stopWords;
@SuppressWarnings("unchecked")
public static List<String> getStopWords() {
try {
if (stopWords == null)
stopWords = IOUtils.readLines(StopWords.class.getResourceAsStream("/stopwords"));
} catch (IOException e) {
throw new RuntimeException(e);
}
return stopWords;
}
}
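
// A minimal sketch: the list is read once from the /stopwords classpath
// resource and cached for subsequent calls.
import org.datavec.nlp.stopwords.StopWords;

class StopWordsSketch {
    public static void main(String[] args) {
        System.out.println(StopWords.getStopWords().size());
    }
}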

View File

@ -1,125 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
import opennlp.tools.tokenize.TokenizerME;
import opennlp.tools.tokenize.TokenizerModel;
import opennlp.tools.util.Span;
import opennlp.uima.tokenize.AbstractTokenizer;
import opennlp.uima.tokenize.TokenizerModelResource;
import opennlp.uima.util.AnnotatorUtil;
import opennlp.uima.util.UimaUtil;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.Feature;
import org.apache.uima.cas.TypeSystem;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.resource.ResourceAccessException;
import org.apache.uima.resource.ResourceInitializationException;
public class ConcurrentTokenizer extends AbstractTokenizer {
/**
* The OpenNLP tokenizer.
*/
private TokenizerME tokenizer;
private Feature probabilityFeature;
@Override
public synchronized void process(CAS cas) throws AnalysisEngineProcessException {
super.process(cas);
}
/**
* Initializes a new instance.
*
* Note: Use {@link #initialize(UimaContext)} to initialize
* this instance. Do not use the constructor.
*/
public ConcurrentTokenizer() {
super("OpenNLP Tokenizer");
// initialization belongs in initialize(UimaContext), not in the constructor
}
/**
* Initializes the current instance with the given context.
*
* Note: Do all initialization in this method, do not use the constructor.
*/
public void initialize(UimaContext context) throws ResourceInitializationException {
super.initialize(context);
TokenizerModel model;
try {
TokenizerModelResource modelResource =
(TokenizerModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER);
model = modelResource.getModel();
} catch (ResourceAccessException e) {
throw new ResourceInitializationException(e);
}
tokenizer = new TokenizerME(model);
}
/**
* Initializes the type system.
*/
public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException {
super.typeSystemInit(typeSystem);
probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(context, tokenType,
UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE);
}
@Override
protected Span[] tokenize(CAS cas, AnnotationFS sentence) {
return tokenizer.tokenizePos(sentence.getCoveredText());
}
@Override
protected void postProcessAnnotations(Span[] tokens, AnnotationFS[] tokenAnnotations) {
// if a probability feature was configured, record the per-token probabilities
if (probabilityFeature != null) {
double[] tokenProbabilities = tokenizer.getTokenProbabilities();
for (int i = 0; i < tokenAnnotations.length; i++) {
tokenAnnotations[i].setDoubleValue(probabilityFeature, tokenProbabilities[i]);
}
}
}
/**
* Releases allocated resources.
*/
public void destroy() {
// dereference model to allow garbage collection
tokenizer = null;
}
}

View File

@ -1,107 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
/**
* Tokenizer based on the {@link java.io.StreamTokenizer}
* @author Adam Gibson
*
*/
public class DefaultStreamTokenizer implements Tokenizer {
private StreamTokenizer streamTokenizer;
private TokenPreProcess tokenPreProcess;
public DefaultStreamTokenizer(InputStream is) {
Reader r = new BufferedReader(new InputStreamReader(is));
streamTokenizer = new StreamTokenizer(r);
}
@Override
public boolean hasMoreTokens() {
if (streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
try {
streamTokenizer.nextToken();
} catch (IOException e1) {
throw new RuntimeException(e1);
}
}
return streamTokenizer.ttype != StreamTokenizer.TT_EOF && streamTokenizer.ttype != -1;
}
@Override
public int countTokens() {
return getTokens().size();
}
@Override
public String nextToken() {
StringBuilder sb = new StringBuilder();
if (streamTokenizer.ttype == StreamTokenizer.TT_WORD) {
sb.append(streamTokenizer.sval);
} else if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
sb.append(streamTokenizer.nval);
} else if (streamTokenizer.ttype == StreamTokenizer.TT_EOL) {
try {
while (streamTokenizer.ttype == StreamTokenizer.TT_EOL)
streamTokenizer.nextToken();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
else if (hasMoreTokens())
return nextToken();
String ret = sb.toString();
if (tokenPreProcess != null)
ret = tokenPreProcess.preProcess(ret);
return ret;
}
@Override
public List<String> getTokens() {
List<String> tokens = new ArrayList<>();
while (hasMoreTokens()) {
tokens.add(nextToken());
}
return tokens;
}
@Override
public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
this.tokenPreProcess = tokenPreProcessor;
}
}
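
// A minimal sketch of stream-based tokenization, assuming the class above.
// Note that numeric tokens come back via StreamTokenizer.nval, so "2" would
// print as "2.0".
import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;

class StreamTokenizerSketch {
    public static void main(String[] args) {
        DefaultStreamTokenizer t = new DefaultStreamTokenizer(
                new ByteArrayInputStream("hello stream world".getBytes(StandardCharsets.UTF_8)));
        while (t.hasMoreTokens())
            System.out.println(t.nextToken());
    }
}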

View File

@ -1,75 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
/**
* Default tokenizer
* @author Adam Gibson
*/
public class DefaultTokenizer implements Tokenizer {
public DefaultTokenizer(String tokens) {
tokenizer = new StringTokenizer(tokens);
}
private StringTokenizer tokenizer;
private TokenPreProcess tokenPreProcess;
@Override
public boolean hasMoreTokens() {
return tokenizer.hasMoreTokens();
}
@Override
public int countTokens() {
return tokenizer.countTokens();
}
@Override
public String nextToken() {
String base = tokenizer.nextToken();
if (tokenPreProcess != null)
base = tokenPreProcess.preProcess(base);
return base;
}
@Override
public List<String> getTokens() {
List<String> tokens = new ArrayList<>();
while (hasMoreTokens()) {
tokens.add(nextToken());
}
return tokens;
}
@Override
public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
this.tokenPreProcess = tokenPreProcessor;
}
}
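
// A minimal sketch of whitespace tokenization; TokenPreProcess has a single
// abstract method, so a method reference works as a pre-processor here.
import org.datavec.nlp.tokenization.tokenizer.DefaultTokenizer;

class DefaultTokenizerSketch {
    public static void main(String[] args) {
        DefaultTokenizer t = new DefaultTokenizer("The Cat SAT");
        t.setTokenPreProcessor(String::toLowerCase);
        System.out.println(t.getTokens()); // [the, cat, sat]
    }
}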

View File

@ -1,136 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.cas.CAS;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.util.JCasUtil;
import org.cleartk.token.type.Sentence;
import org.cleartk.token.type.Token;
import org.datavec.nlp.annotator.PoStagger;
import org.datavec.nlp.annotator.SentenceAnnotator;
import org.datavec.nlp.annotator.StemmerAnnotator;
import org.datavec.nlp.annotator.TokenizerAnnotator;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
public class PosUimaTokenizer implements Tokenizer {
private static AnalysisEngine engine;
private List<String> tokens;
private Collection<String> allowedPosTags;
private int index;
private static CAS cas;
public PosUimaTokenizer(String tokens, AnalysisEngine engine, Collection<String> allowedPosTags) {
if (PosUimaTokenizer.engine == null)
PosUimaTokenizer.engine = engine;
this.allowedPosTags = allowedPosTags;
this.tokens = new ArrayList<>();
try {
if (cas == null)
cas = engine.newCAS();
cas.reset();
cas.setDocumentText(tokens);
PosUimaTokenizer.engine.process(cas);
for (Sentence s : JCasUtil.select(cas.getJCas(), Sentence.class)) {
for (Token t : JCasUtil.selectCovered(Token.class, s)) {
//add NONE for each invalid token
if (valid(t))
if (t.getLemma() != null)
this.tokens.add(t.getLemma());
else if (t.getStem() != null)
this.tokens.add(t.getStem());
else
this.tokens.add(t.getCoveredText());
else
this.tokens.add("NONE");
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private boolean valid(Token token) {
String check = token.getCoveredText();
if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>"))
return false;
else if (token.getPos() != null && !this.allowedPosTags.contains(token.getPos()))
return false;
return true;
}
@Override
public boolean hasMoreTokens() {
return index < tokens.size();
}
@Override
public int countTokens() {
return tokens.size();
}
@Override
public String nextToken() {
String ret = tokens.get(index);
index++;
return ret;
}
@Override
public List<String> getTokens() {
List<String> tokens = new ArrayList<String>();
while (hasMoreTokens()) {
tokens.add(nextToken());
}
return tokens;
}
public static AnalysisEngine defaultAnalysisEngine() {
try {
return AnalysisEngineFactory.createEngine(AnalysisEngineFactory.createEngineDescription(
SentenceAnnotator.getDescription(), TokenizerAnnotator.getDescription(),
PoStagger.getDescription("en"), StemmerAnnotator.getDescription("English")));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
// no-op: this tokenizer does not support token pre-processing
}
}

View File

@ -1,34 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
public interface TokenPreProcess {
/**
* Pre process a token
* @param token the token to pre process
* @return the preprocessed token
*/
String preProcess(String token);
}
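
// A minimal sketch of implementing the interface above: a pre-processor that
// strips non-alphanumeric characters (the regex is illustrative only).
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;

class StripPunctuationPreProcessor implements TokenPreProcess {
    @Override
    public String preProcess(String token) {
        return token.replaceAll("[^\\p{L}\\p{N}]", "");
    }
}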

View File

@ -1,61 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
import java.util.List;
public interface Tokenizer {
/**
* Checks whether more tokens remain
* in the iterator
* @return whether there are any more tokens
* to iterate over
*/
boolean hasMoreTokens();
/**
* The number of tokens in the tokenizer
* @return the number of tokens
*/
int countTokens();
/**
* The next token (word usually) in the string
* @return the next token in the string if any
*/
String nextToken();
/**
* Returns a list of all the tokens
* @return a list of all the tokens
*/
List<String> getTokens();
/**
* Set the token pre process
* @param tokenPreProcessor the token pre processor to set
*/
void setTokenPreProcessor(TokenPreProcess tokenPreProcessor);
}

View File

@ -1,123 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer;
import org.apache.uima.cas.CAS;
import org.apache.uima.fit.util.JCasUtil;
import org.cleartk.token.type.Token;
import org.datavec.nlp.uima.UimaResource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
/**
* Tokenizer based on the passed in analysis engine
* @author Adam Gibson
*
*/
public class UimaTokenizer implements Tokenizer {
private List<String> tokens;
private int index;
private static Logger log = LoggerFactory.getLogger(UimaTokenizer.class);
private boolean checkForLabel;
private TokenPreProcess tokenPreProcessor;
public UimaTokenizer(String tokens, UimaResource resource, boolean checkForLabel) {
this.checkForLabel = checkForLabel;
this.tokens = new ArrayList<>();
try {
CAS cas = resource.process(tokens);
Collection<Token> tokenList = JCasUtil.select(cas.getJCas(), Token.class);
for (Token t : tokenList) {
if (!checkForLabel || valid(t.getCoveredText()))
if (t.getLemma() != null)
this.tokens.add(t.getLemma());
else if (t.getStem() != null)
this.tokens.add(t.getStem());
else
this.tokens.add(t.getCoveredText());
}
resource.release(cas);
} catch (Exception e) {
log.error("",e);
throw new RuntimeException(e);
}
}
private boolean valid(String check) {
if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>"))
return false;
return true;
}
@Override
public boolean hasMoreTokens() {
return index < tokens.size();
}
@Override
public int countTokens() {
return tokens.size();
}
@Override
public String nextToken() {
String ret = tokens.get(index);
index++;
if (tokenPreProcessor != null) {
ret = tokenPreProcessor.preProcess(ret);
}
return ret;
}
@Override
public List<String> getTokens() {
List<String> tokens = new ArrayList<>();
while (hasMoreTokens()) {
tokens.add(nextToken());
}
return tokens;
}
@Override
public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
this.tokenPreProcessor = tokenPreProcessor;
}
}

View File

@ -1,47 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer.preprocessor;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
/**
* Strips common English endings:
*
* ed, ing, ly, s, and trailing periods.
* @author Adam Gibson
*/
public class EndingPreProcessor implements TokenPreProcess {
@Override
public String preProcess(String token) {
if (token.endsWith("s") && !token.endsWith("ss"))
token = token.substring(0, token.length() - 1);
if (token.endsWith("."))
token = token.substring(0, token.length() - 1);
if (token.endsWith("ed"))
token = token.substring(0, token.length() - 2);
if (token.endsWith("ing"))
token = token.substring(0, token.length() - 3);
if (token.endsWith("ly"))
token = token.substring(0, token.length() - 2);
return token;
}
}
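
// A minimal sketch: each suffix rule is applied at most once per token, so
// this is a crude ending stripper rather than a real stemmer.
import org.datavec.nlp.tokenization.tokenizer.preprocessor.EndingPreProcessor;

class EndingPreProcessorSketch {
    public static void main(String[] args) {
        EndingPreProcessor p = new EndingPreProcessor();
        System.out.println(p.preProcess("running")); // runn
        System.out.println(p.preProcess("cats"));    // cat
    }
}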

View File

@ -1,30 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizer.preprocessor;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
public class LowerCasePreProcessor implements TokenPreProcess {
@Override
public String preProcess(String token) {
return token.toLowerCase();
}
}

View File

@ -1,60 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizerfactory;
import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer;
import org.datavec.nlp.tokenization.tokenizer.DefaultTokenizer;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import java.io.InputStream;
/**
* Default tokenizer based on string tokenizer or stream tokenizer
* @author Adam Gibson
*/
public class DefaultTokenizerFactory implements TokenizerFactory {
private TokenPreProcess tokenPreProcess;
@Override
public Tokenizer create(String toTokenize) {
DefaultTokenizer t = new DefaultTokenizer(toTokenize);
t.setTokenPreProcessor(tokenPreProcess);
return t;
}
@Override
public Tokenizer create(InputStream toTokenize) {
Tokenizer t = new DefaultStreamTokenizer(toTokenize);
t.setTokenPreProcessor(tokenPreProcess);
return t;
}
@Override
public void setTokenPreProcessor(TokenPreProcess preProcessor) {
this.tokenPreProcess = preProcessor;
}
}
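
// A minimal sketch: the factory wires the same pre-processor into every
// tokenizer it creates.
import org.datavec.nlp.tokenization.tokenizer.preprocessor.LowerCasePreProcessor;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;

class TokenizerFactorySketch {
    public static void main(String[] args) {
        DefaultTokenizerFactory factory = new DefaultTokenizerFactory();
        factory.setTokenPreProcessor(new LowerCasePreProcessor());
        System.out.println(factory.create("The CAT sat").getTokens()); // [the, cat, sat]
    }
}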

View File

@ -1,85 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizerfactory;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.datavec.nlp.annotator.PoStagger;
import org.datavec.nlp.annotator.SentenceAnnotator;
import org.datavec.nlp.annotator.StemmerAnnotator;
import org.datavec.nlp.annotator.TokenizerAnnotator;
import org.datavec.nlp.tokenization.tokenizer.PosUimaTokenizer;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import java.io.InputStream;
import java.util.Collection;
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngine;
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription;
public class PosUimaTokenizerFactory implements TokenizerFactory {
private AnalysisEngine tokenizer;
private Collection<String> allowedPoSTags;
private TokenPreProcess tokenPreProcess;
public PosUimaTokenizerFactory(Collection<String> allowedPoSTags) {
this(defaultAnalysisEngine(), allowedPoSTags);
}
public PosUimaTokenizerFactory(AnalysisEngine tokenizer, Collection<String> allowedPosTags) {
this.tokenizer = tokenizer;
this.allowedPoSTags = allowedPosTags;
}
public static AnalysisEngine defaultAnalysisEngine() {
try {
return createEngine(createEngineDescription(SentenceAnnotator.getDescription(),
TokenizerAnnotator.getDescription(), PoStagger.getDescription("en"),
StemmerAnnotator.getDescription("English")));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public Tokenizer create(String toTokenize) {
PosUimaTokenizer t = new PosUimaTokenizer(toTokenize, tokenizer, allowedPoSTags);
t.setTokenPreProcessor(tokenPreProcess);
return t;
}
@Override
public Tokenizer create(InputStream toTokenize) {
throw new UnsupportedOperationException();
}
@Override
public void setTokenPreProcessor(TokenPreProcess preProcessor) {
this.tokenPreProcess = preProcessor;
}
}

View File

@ -1,62 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizerfactory;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.io.InputStream;
/**
* Generates a tokenizer for a given string
* @author Adam Gibson
*
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface TokenizerFactory {
/**
* The tokenizer to create
* @param toTokenize the string to create the tokenizer with
* @return the new tokenizer
*/
Tokenizer create(String toTokenize);
/**
* Create a tokenizer based on an input stream
* @param toTokenize the input stream to create the tokenizer from
* @return the new tokenizer
*/
Tokenizer create(InputStream toTokenize);
/**
* Sets a token pre processor to be used
* with every tokenizer
* @param preProcessor the token pre processor to use
*/
void setTokenPreProcessor(TokenPreProcess preProcessor);
}

View File

@ -1,138 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.tokenization.tokenizerfactory;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.resource.ResourceInitializationException;
import org.datavec.nlp.annotator.SentenceAnnotator;
import org.datavec.nlp.annotator.TokenizerAnnotator;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizer.UimaTokenizer;
import org.datavec.nlp.uima.UimaResource;
import java.io.InputStream;
/**
* Uses a UIMA {@link AnalysisEngine} to
* tokenize text.
*
*
* @author Adam Gibson
*
*/
public class UimaTokenizerFactory implements TokenizerFactory {
private UimaResource uimaResource;
private boolean checkForLabel;
private static AnalysisEngine defaultAnalysisEngine;
private TokenPreProcess preProcess;
public UimaTokenizerFactory() throws ResourceInitializationException {
this(defaultAnalysisEngine(), true);
}
public UimaTokenizerFactory(UimaResource resource) {
this(resource, true);
}
public UimaTokenizerFactory(AnalysisEngine tokenizer) {
this(tokenizer, true);
}
public UimaTokenizerFactory(UimaResource resource, boolean checkForLabel) {
this.uimaResource = resource;
this.checkForLabel = checkForLabel;
}
public UimaTokenizerFactory(boolean checkForLabel) throws ResourceInitializationException {
this(defaultAnalysisEngine(), checkForLabel);
}
public UimaTokenizerFactory(AnalysisEngine tokenizer, boolean checkForLabel) {
super();
this.checkForLabel = checkForLabel;
try {
this.uimaResource = new UimaResource(tokenizer);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public Tokenizer create(String toTokenize) {
if (toTokenize == null || toTokenize.isEmpty())
throw new IllegalArgumentException("Unable to proceed; on sentence to tokenize");
Tokenizer ret = new UimaTokenizer(toTokenize, uimaResource, checkForLabel);
ret.setTokenPreProcessor(preProcess);
return ret;
}
public UimaResource getUimaResource() {
return uimaResource;
}
/**
* Creates a tokenization/stemming pipeline
* @return a tokenization/stemming pipeline
*/
public static AnalysisEngine defaultAnalysisEngine() {
try {
if (defaultAnalysisEngine == null)
defaultAnalysisEngine = AnalysisEngineFactory.createEngine(
AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.getDescription(),
TokenizerAnnotator.getDescription()));
return defaultAnalysisEngine;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public Tokenizer create(InputStream toTokenize) {
throw new UnsupportedOperationException();
}
@Override
public void setTokenPreProcessor(TokenPreProcess preProcessor) {
this.preProcess = preProcessor;
}
}
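
// A hedged sketch: this needs the ClearTK/OpenNLP model resources on the
// classpath at runtime, so treat it as illustrative rather than guaranteed
// to run in a bare environment.
import org.datavec.nlp.tokenization.tokenizerfactory.UimaTokenizerFactory;

class UimaFactorySketch {
    public static void main(String[] args) throws Exception {
        UimaTokenizerFactory factory = new UimaTokenizerFactory();
        System.out.println(factory.create("Some text to tokenize.").getTokens());
    }
}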

View File

@ -1,69 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
import org.datavec.api.transform.Transform;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.util.List;
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface BagOfWordsTransform extends Transform {
/**
* The output shape of the transform (usually 1 x number of words)
* @return the output shape of the transform
*/
long[] outputShape();
/**
* The vocab words in the transform.
* This is the words that were accumulated
* when building a vocabulary.
* (This is generally associated with some form of
* mininmum words frequency scanning to build a vocab
* you then map on to a list of vocab words as a list)
* @return the vocab words for the transform
*/
List<String> vocabWords();
/**
* Transform for a list of tokens
* that are objects. This is to allow loose
* typing for tokens that are unique (non-string)
* @param tokens the token objects to transform
* @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array)
*/
INDArray transformFromObject(List<List<Object>> tokens);
/**
* Transform for a list of tokens
* that are {@link Writable} (Generally {@link org.datavec.api.writable.Text}
* @param tokens the token objects to transform
* @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array)
*/
INDArray transformFrom(List<List<Writable>> tokens);
}

View File

@ -1,24 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
public class BaseWordMapTransform {
}

View File

@ -1,146 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@Data
@EqualsAndHashCode(callSuper = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonIgnoreProperties({"gazeteer"})
public class GazeteerTransform extends BaseColumnTransform implements BagOfWordsTransform {
private String newColumnName;
private List<String> wordList;
private Set<String> gazeteer;
@JsonCreator
public GazeteerTransform(@JsonProperty("columnName") String columnName,
@JsonProperty("newColumnName")String newColumnName,
@JsonProperty("wordList") List<String> wordList) {
super(columnName);
this.newColumnName = newColumnName;
this.wordList = wordList;
this.gazeteer = new HashSet<>(wordList);
}
@Override
public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
return new NDArrayMetaData(newName,new long[]{wordList.size()});
}
@Override
public Writable map(Writable columnWritable) {
throw new UnsupportedOperationException();
}
@Override
public Object mapSequence(Object sequence) {
List<List<Object>> sequenceInput = (List<List<Object>>) sequence;
INDArray ret = Nd4j.create(DataType.FLOAT, wordList.size());
for(List<Object> list : sequenceInput) {
for(Object token : list) {
String s = token.toString();
if(gazeteer.contains(s)) {
ret.putScalar(wordList.indexOf(s),1);
}
}
}
return ret;
}
@Override
public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
INDArray arr = (INDArray) mapSequence((Object) sequence);
return Collections.singletonList(Collections.<Writable>singletonList(new NDArrayWritable(arr)));
}
@Override
public String toString() {
return newColumnName;
}
@Override
public Object map(Object input) {
return gazeteer.contains(input.toString());
}
@Override
public String outputColumnName() {
return newColumnName;
}
@Override
public String[] outputColumnNames() {
return new String[]{newColumnName};
}
@Override
public String[] columnNames() {
return new String[]{columnName()};
}
@Override
public String columnName() {
return columnName;
}
@Override
public long[] outputShape() {
return new long[]{wordList.size()};
}
@Override
public List<String> vocabWords() {
return wordList;
}
@Override
public INDArray transformFromObject(List<List<Object>> tokens) {
return (INDArray) mapSequence(tokens);
}
@Override
public INDArray transformFrom(List<List<Writable>> tokens) {
return (INDArray) mapSequence((Object) tokens);
}
}
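
// A minimal sketch, assuming the class above: a fixed word list becomes a
// membership vector over that list. Column names and tokens are made up.
import org.datavec.nlp.transforms.GazeteerTransform;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;

class GazeteerSketch {
    public static void main(String[] args) {
        GazeteerTransform t = new GazeteerTransform("text", "cityFlags",
                Arrays.asList("london", "paris", "tokyo"));
        List<List<Object>> tokens =
                Arrays.asList(Arrays.<Object>asList("i", "love", "paris"));
        INDArray out = t.transformFromObject(tokens);
        System.out.println(out); // membership vector over the word list: 0, 1, 0
    }
}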

View File

@ -1,150 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.list.NDArrayList;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Collections;
import java.util.List;
public class MultiNlpTransform extends BaseColumnTransform implements BagOfWordsTransform {
private BagOfWordsTransform[] transforms;
private String newColumnName;
private List<String> vocabWords;
/**
* @param columnName the input column to transform
* @param transforms the bag-of-words transforms to combine
* @param newColumnName the name of the output column
*/
@JsonCreator
public MultiNlpTransform(@JsonProperty("columnName") String columnName,
@JsonProperty("transforms") BagOfWordsTransform[] transforms,
@JsonProperty("newColumnName") String newColumnName) {
super(columnName);
this.transforms = transforms;
this.vocabWords = transforms[0].vocabWords();
if(transforms.length > 1) {
for(int i = 1; i < transforms.length; i++) {
if(!transforms[i].vocabWords().equals(vocabWords)) {
throw new IllegalArgumentException("Vocab words not consistent across transforms!");
}
}
}
this.newColumnName = newColumnName;
}
@Override
public Object mapSequence(Object sequence) {
NDArrayList ndArrayList = new NDArrayList();
for(BagOfWordsTransform bagofWordsTransform : transforms) {
ndArrayList.addAll(new NDArrayList(bagofWordsTransform.transformFromObject((List<List<Object>>) sequence)));
}
return ndArrayList.array();
}
@Override
public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
return Collections.singletonList(Collections.<Writable>singletonList(new NDArrayWritable(transformFrom(sequence))));
}
@Override
public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
return new NDArrayMetaData(newName,outputShape());
}
@Override
public Writable map(Writable columnWritable) {
throw new UnsupportedOperationException("Only able to add for time series");
}
@Override
public String toString() {
return newColumnName;
}
@Override
public Object map(Object input) {
throw new UnsupportedOperationException("Only able to add for time series");
}
@Override
public long[] outputShape() {
long[] ret = new long[transforms[0].outputShape().length];
int validatedRank = transforms[0].outputShape().length;
for(int i = 1; i < transforms.length; i++) {
if(transforms[i].outputShape().length != validatedRank) {
throw new IllegalArgumentException("Inconsistent shape length at transform " + i + " , should have been: " + validatedRank);
}
}
for(int i = 0; i < transforms.length; i++) {
for(int j = 0; j < validatedRank; j++)
ret[j] += transforms[i].outputShape()[j];
}
return ret;
}
@Override
public List<String> vocabWords() {
return vocabWords;
}
@Override
public INDArray transformFromObject(List<List<Object>> tokens) {
NDArrayList ndArrayList = new NDArrayList();
for(BagOfWordsTransform bagofWordsTransform : transforms) {
INDArray arr2 = bagofWordsTransform.transformFromObject(tokens);
arr2 = arr2.reshape(arr2.length());
NDArrayList newList = new NDArrayList(arr2,(int) arr2.length());
ndArrayList.addAll(newList);
}
return ndArrayList.array();
}
@Override
public INDArray transformFrom(List<List<Writable>> tokens) {
NDArrayList ndArrayList = new NDArrayList();
for(BagOfWordsTransform bagofWordsTransform : transforms) {
INDArray arr2 = bagofWordsTransform.transformFrom(tokens);
arr2 = arr2.reshape(arr2.length());
NDArrayList newList = new NDArrayList(arr2,(int) arr2.length());
ndArrayList.addAll(newList);
}
return ndArrayList.array();
}
}
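
// A minimal sketch: concatenating two bag-of-words transforms that share the
// same vocab; the output shape is the sum of the per-transform shapes.
import org.datavec.nlp.transforms.BagOfWordsTransform;
import org.datavec.nlp.transforms.GazeteerTransform;
import org.datavec.nlp.transforms.MultiNlpTransform;
import java.util.Arrays;
import java.util.List;

class MultiNlpSketch {
    public static void main(String[] args) {
        List<String> vocab = Arrays.asList("london", "paris");
        BagOfWordsTransform g1 = new GazeteerTransform("text", "a", vocab);
        BagOfWordsTransform g2 = new GazeteerTransform("text", "b", vocab);
        MultiNlpTransform multi = new MultiNlpTransform("text",
                new BagOfWordsTransform[]{g1, g2}, "combined");
        System.out.println(Arrays.toString(multi.outputShape())); // [4]
    }
}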

View File

@ -1,226 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Counter;
import org.nd4j.common.util.MathUtils;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = true, exclude = {"tokenizerFactory"})
@JsonInclude(JsonInclude.Include.NON_NULL)
public class TokenizerBagOfWordsTermSequenceIndexTransform extends BaseColumnTransform {
private String newColumnName;
private Map<String,Integer> wordIndexMap;
private Map<String,Double> weightMap;
private boolean exceptionOnUnknown;
private String tokenizerFactoryClass;
private String preprocessorClass;
private TokenizerFactory tokenizerFactory;
@JsonCreator
public TokenizerBagOfWordsTermSequenceIndexTransform(@JsonProperty("columnName") String columnName,
@JsonProperty("newColumnName") String newColumnName,
@JsonProperty("wordIndexMap") Map<String,Integer> wordIndexMap,
@JsonProperty("idfMap") Map<String,Double> idfMap,
@JsonProperty("exceptionOnUnknown") boolean exceptionOnUnknown,
@JsonProperty("tokenizerFactoryClass") String tokenizerFactoryClass,
@JsonProperty("preprocessorClass") String preprocessorClass) {
super(columnName);
this.newColumnName = newColumnName;
this.wordIndexMap = wordIndexMap;
this.exceptionOnUnknown = exceptionOnUnknown;
this.weightMap = idfMap;
this.tokenizerFactoryClass = tokenizerFactoryClass;
this.preprocessorClass = preprocessorClass;
if(this.tokenizerFactoryClass == null) {
this.tokenizerFactoryClass = DefaultTokenizerFactory.class.getName();
}
try {
tokenizerFactory = (TokenizerFactory) Class.forName(this.tokenizerFactoryClass).newInstance();
} catch (Exception e) {
throw new IllegalStateException("Unable to instantiate tokenizer factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?");
}
if(preprocessorClass != null){
try {
TokenPreProcess tpp = (TokenPreProcess) Class.forName(this.preprocessorClass).newInstance();
tokenizerFactory.setTokenPreProcessor(tpp);
} catch (Exception e){
throw new IllegalStateException("Unable to instantiate preprocessor factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?");
}
}
}
@Override
public List<Writable> map(List<Writable> writables) {
Text text = (Text) writables.get(inputSchema.getIndexOfColumn(columnName));
List<Writable> ret = new ArrayList<>(writables);
ret.set(inputSchema.getIndexOfColumn(columnName),new NDArrayWritable(convert(text.toString())));
return ret;
}
@Override
public Object map(Object input) {
return convert(input.toString());
}
@Override
public Object mapSequence(Object sequence) {
return convert(sequence.toString());
}
@Override
public Schema transform(Schema inputSchema) {
Schema.Builder newSchema = new Schema.Builder();
for(int i = 0; i < inputSchema.numColumns(); i++) {
if(inputSchema.getName(i).equals(this.columnName)) {
newSchema.addColumnNDArray(newColumnName,new long[]{1,wordIndexMap.size()});
}
else {
newSchema.addColumn(inputSchema.getMetaData(i));
}
}
return newSchema.build();
}
/**
* Convert the given text into an {@link INDArray}
* using the {@link TokenizerFactory} specified in the constructor.
* @param text the text to transform
* @return the created {@link INDArray}, with the column index
* of each word taken from {@link #wordIndexMap}
*/
public INDArray convert(String text) {
Tokenizer tokenizer = tokenizerFactory.create(text);
List<String> tokens = tokenizer.getTokens();
INDArray create = Nd4j.create(1,wordIndexMap.size());
Counter<String> tokenizedCounter = new Counter<>();
for(int i = 0; i < tokens.size(); i++) {
tokenizedCounter.incrementCount(tokens.get(i),1.0);
}
for(int i = 0; i < tokens.size(); i++) {
if(wordIndexMap.containsKey(tokens.get(i))) {
int idx = wordIndexMap.get(tokens.get(i));
int count = (int) tokenizedCounter.getCount(tokens.get(i));
double weight = tfidfWord(tokens.get(i),count,tokens.size());
create.putScalar(idx,weight);
}
}
return create;
}
/**
* Calculate the tf-idf weight for a word
* given the word, word count, and document length
* @param word the word to calculate
* @param wordCount the word frequency
* @param documentLength the number of words in the document
* @return the tf-idf weight for the given word
*/
public double tfidfWord(String word, long wordCount, long documentLength) {
double tf = tfForWord(wordCount, documentLength);
double idf = idfForWord(word);
return MathUtils.tfidf(tf, idf);
}
/**
* Calculate the term frequency for a given word. Note that this
* returns the raw count with no normalization by document length,
* matching the scikit-learn comparison in the tests.
* @param wordCount the word frequency
* @param documentLength the number of words in the document (unused)
* @return the term frequency
*/
private double tfForWord(long wordCount, long documentLength) {
return wordCount;
}
private double idfForWord(String word) {
if(weightMap.containsKey(word))
return weightMap.get(word);
return 0;
}
@Override
public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
return new NDArrayMetaData(outputColumnName(),new long[]{1,wordIndexMap.size()});
}
@Override
public String outputColumnName() {
return newColumnName;
}
@Override
public String[] outputColumnNames() {
return new String[]{newColumnName};
}
@Override
public String[] columnNames() {
return new String[]{columnName()};
}
@Override
public String columnName() {
return columnName;
}
@Override
public Writable map(Writable columnWritable) {
return new NDArrayWritable(convert(columnWritable.toString()));
}
}
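
A minimal usage sketch for the transform above; the two-word vocabulary and IDF weights are hypothetical, chosen only to show how convert(String) places tf * idf at each word's mapped column (null tokenizer factory class falls back to DefaultTokenizerFactory, per the constructor):

import java.util.HashMap;
import java.util.Map;
import org.nd4j.linalg.api.ndarray.INDArray;
public class BagOfWordsConvertSketch {
public static void main(String[] args) {
Map<String, Integer> wordIndexMap = new HashMap<>();
wordIndexMap.put("nice", 0);
wordIndexMap.put("day", 1);
Map<String, Double> idfMap = new HashMap<>();
idfMap.put("nice", 1.4); //hypothetical idf weights
idfMap.put("day", 1.0);
TokenizerBagOfWordsTermSequenceIndexTransform transform =
new TokenizerBagOfWordsTermSequenceIndexTransform(
"text", "bow", wordIndexMap, idfMap, false, null, null);
//"nice" occurs twice: tf 2 * idf 1.4 = 2.8 at column 0; "day" once: 1.0 at column 1
INDArray row = transform.convert("nice nice day");
System.out.println(row); //expected: [[2.8, 1.0]]
}
}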

View File

@ -1,107 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.uima;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.util.CasPool;
public class UimaResource {
private AnalysisEngine analysisEngine;
private CasPool casPool;
public UimaResource(AnalysisEngine analysisEngine) throws ResourceInitializationException {
this.analysisEngine = analysisEngine;
this.casPool = new CasPool(Runtime.getRuntime().availableProcessors() * 10, analysisEngine);
}
public UimaResource(AnalysisEngine analysisEngine, CasPool casPool) {
this.analysisEngine = analysisEngine;
this.casPool = casPool;
}
public AnalysisEngine getAnalysisEngine() {
return analysisEngine;
}
public void setAnalysisEngine(AnalysisEngine analysisEngine) {
this.analysisEngine = analysisEngine;
}
public CasPool getCasPool() {
return casPool;
}
public void setCasPool(CasPool casPool) {
this.casPool = casPool;
}
/**
* Process the given text with the analysis engine.
* Note that you must release the returned CAS yourself
* via {@link #release(CAS)}.
* @param text the text to process
* @return the processed CAS
*/
public CAS process(String text) {
CAS cas = retrieve();
cas.setDocumentText(text);
try {
analysisEngine.process(cas);
} catch (AnalysisEngineProcessException e) {
//Retrying the same text would fail the same way and recurse forever;
//release the CAS back to the pool and propagate the failure instead
release(cas);
throw new RuntimeException(e);
}
return cas;
}
public CAS retrieve() {
CAS ret = casPool.getCas();
try {
return ret == null ? analysisEngine.newCAS() : ret;
} catch (ResourceInitializationException e) {
throw new RuntimeException(e);
}
}
public void release(CAS cas) {
casPool.releaseCas(cas);
}
}
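
A short usage sketch for UimaResource, following the process() contract above: every CAS taken must be handed back via release(). Construction of the analysis engine is not shown and is assumed to happen elsewhere:

import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.cas.CAS;
import org.apache.uima.resource.ResourceInitializationException;
public class UimaResourceSketch {
public static void processOne(AnalysisEngine engine, String text) throws ResourceInitializationException {
UimaResource resource = new UimaResource(engine);
CAS cas = resource.process(text);
try {
//...read annotations from the CAS here...
} finally {
resource.release(cas); //always return the CAS to the pool
}
}
}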

View File

@ -1,77 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.vectorizer;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import java.util.HashSet;
import java.util.Set;
public abstract class AbstractTfidfVectorizer<VECTOR_TYPE> extends TextVectorizer<VECTOR_TYPE> {
@Override
public void doWithTokens(Tokenizer tokenizer) {
Set<String> seen = new HashSet<>();
while (tokenizer.hasMoreTokens()) {
String token = tokenizer.nextToken();
if (!stopWords.contains(token)) {
cache.incrementCount(token);
if (!seen.contains(token)) {
cache.incrementDocCount(token);
}
seen.add(token);
}
}
}
@Override
public TokenizerFactory createTokenizerFactory(Configuration conf) {
String clazz = conf.get(TOKENIZER, DefaultTokenizerFactory.class.getName());
try {
Class<? extends TokenizerFactory> tokenizerFactoryClazz =
(Class<? extends TokenizerFactory>) Class.forName(clazz);
TokenizerFactory tf = tokenizerFactoryClazz.newInstance();
String preproc = conf.get(PREPROCESSOR, null);
if(preproc != null){
TokenPreProcess tpp = (TokenPreProcess) Class.forName(preproc).newInstance();
tf.setTokenPreProcessor(tpp);
}
return tf;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public abstract VECTOR_TYPE createVector(Object[] args);
@Override
public abstract VECTOR_TYPE fitTransform(RecordReader reader);
@Override
public abstract VECTOR_TYPE transform(Record record);
}

View File

@ -1,121 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.vectorizer;
import lombok.Getter;
import org.nd4j.common.primitives.Counter;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.vector.Vectorizer;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.metadata.DefaultVocabCache;
import org.datavec.nlp.metadata.VocabCache;
import org.datavec.nlp.stopwords.StopWords;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import java.util.Collection;
public abstract class TextVectorizer<VECTOR_TYPE> implements Vectorizer<VECTOR_TYPE> {
protected TokenizerFactory tokenizerFactory;
protected int minWordFrequency = 0;
public final static String MIN_WORD_FREQUENCY = "org.nd4j.nlp.minwordfrequency";
public final static String STOP_WORDS = "org.nd4j.nlp.stopwords";
public final static String TOKENIZER = "org.datavec.nlp.tokenizerfactory";
public static final String PREPROCESSOR = "org.datavec.nlp.preprocessor";
public final static String VOCAB_CACHE = "org.datavec.nlp.vocabcache";
protected Collection<String> stopWords;
@Getter
protected VocabCache cache;
@Override
public void initialize(Configuration conf) {
tokenizerFactory = createTokenizerFactory(conf);
minWordFrequency = conf.getInt(MIN_WORD_FREQUENCY, 5);
if(conf.get(STOP_WORDS) != null)
stopWords = conf.getStringCollection(STOP_WORDS);
if (stopWords == null)
stopWords = StopWords.getStopWords();
String clazz = conf.get(VOCAB_CACHE, DefaultVocabCache.class.getName());
try {
Class<? extends VocabCache> tokenizerFactoryClazz = (Class<? extends VocabCache>) Class.forName(clazz);
cache = tokenizerFactoryClazz.newInstance();
cache.initialize(conf);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public void fit(RecordReader reader) {
fit(reader, null);
}
@Override
public void fit(RecordReader reader, RecordCallBack callBack) {
while (reader.hasNext()) {
Record record = reader.nextRecord();
String s = toString(record.getRecord());
Tokenizer tokenizer = tokenizerFactory.create(s);
doWithTokens(tokenizer);
if (callBack != null)
callBack.onRecord(record);
cache.incrementNumDocs(1);
}
}
protected Counter<String> wordFrequenciesForRecord(Collection<Writable> record) {
String s = toString(record);
Tokenizer tokenizer = tokenizerFactory.create(s);
Counter<String> ret = new Counter<>();
while (tokenizer.hasMoreTokens())
ret.incrementCount(tokenizer.nextToken(), 1.0);
return ret;
}
protected String toString(Collection<Writable> record) {
StringBuilder sb = new StringBuilder();
for(Writable w : record){
sb.append(w.toString());
}
return sb.toString();
}
/**
* Handle the tokens of a single record: increment counts,
* add to a collection, etc.
* @param tokenizer the tokenizer over the record's text
*/
public abstract void doWithTokens(Tokenizer tokenizer);
/**
* Create tokenizer factory based on the configuration
* @param conf the configuration to use
* @return the tokenizer factory based on the configuration
*/
public abstract TokenizerFactory createTokenizerFactory(Configuration conf);
}
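
The configuration keys above are ordinary Configuration properties. A minimal sketch of wiring them up, mirroring how the tests below configure TfidfVectorizer:

import org.datavec.api.conf.Configuration;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.vectorizer.TfidfVectorizer;
public class VectorizerConfigSketch {
public static void main(String[] args) {
Configuration conf = new Configuration();
conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1); //keep every word
conf.set(TfidfVectorizer.TOKENIZER, DefaultTokenizerFactory.class.getName());
conf.set(TfidfVectorizer.STOP_WORDS, ""); //disable stop-word filtering
TfidfVectorizer vectorizer = new TfidfVectorizer();
vectorizer.initialize(conf);
//vectorizer.fitTransform(recordReader) can now be called on a text RecordReader
}
}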

View File

@ -1,105 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.vectorizer;
import org.datavec.api.conf.Configuration;
import org.nd4j.common.primitives.Counter;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.writable.NDArrayWritable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class TfidfVectorizer extends AbstractTfidfVectorizer<INDArray> {
/**
* Default: True.<br>
* If true: use idf(d, t) = log [ (1 + n) / (1 + df(d, t)) ] + 1<br>
* If false: use idf(t) = log [ n / df(t) ] + 1<br>
*/
public static final String SMOOTH_IDF = "org.datavec.nlp.TfidfVectorizer.smooth_idf";
protected boolean smooth_idf;
@Override
public INDArray createVector(Object[] args) {
Counter<String> docFrequencies = (Counter<String>) args[0];
double[] vector = new double[cache.vocabWords().size()];
for (int i = 0; i < cache.vocabWords().size(); i++) {
String word = cache.wordAt(i);
double freq = docFrequencies.getCount(word);
vector[i] = cache.tfidf(word, freq, smooth_idf);
}
return Nd4j.create(vector);
}
@Override
public INDArray fitTransform(RecordReader reader) {
return fitTransform(reader, null);
}
@Override
public INDArray fitTransform(final RecordReader reader, RecordCallBack callBack) {
final List<Record> records = new ArrayList<>();
fit(reader, new RecordCallBack() {
@Override
public void onRecord(Record record) {
records.add(record);
}
});
if (records.isEmpty())
throw new IllegalStateException("No records found!");
INDArray ret = Nd4j.create(records.size(), cache.vocabWords().size());
int i = 0;
for (Record record : records) {
INDArray transformed = transform(record);
org.datavec.api.records.impl.Record transformedRecord = new org.datavec.api.records.impl.Record(
Arrays.asList(new NDArrayWritable(transformed),
record.getRecord().get(record.getRecord().size() - 1)),
new RecordMetaDataURI(record.getMetaData().getURI(), reader.getClass()));
ret.putRow(i++, transformed);
if (callBack != null) {
callBack.onRecord(transformedRecord);
}
}
return ret;
}
@Override
public INDArray transform(Record record) {
Counter<String> wordFrequencies = wordFrequenciesForRecord(record.getRecord());
return createVector(new Object[] {wordFrequencies});
}
@Override
public void initialize(Configuration conf){
super.initialize(conf);
this.smooth_idf = conf.getBoolean(SMOOTH_IDF, true);
}
}
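
The two IDF variants documented on SMOOTH_IDF can be checked by hand. A small sketch using n = 2 documents and a word appearing in one of them; the results match the "nice"/"strange" weights asserted in the transform tests below:

public class IdfSketch {
public static void main(String[] args) {
double n = 2; //number of documents in the corpus
double df = 1; //documents containing the word
double idfNoSmooth = Math.log(n / df) + 1; //= 1.6931471805599454
double idfSmooth = Math.log((1 + n) / (1 + df)) + 1; //= 1.4054651081081644
System.out.println(idfNoSmooth + " " + idfSmooth);
}
}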

View File

@ -1,194 +0,0 @@
a
----s
act
"the
"The
about
above
after
again
against
all
am
an
and
any
are
aren't
as
at
be
because
been
before
being
below
between
both
but
by
can't
cannot
could
couldn't
did
didn't
do
does
doesn't
doing
don't
down
during
each
few
for
from
further
had
hadn't
has
hasn't
have
haven't
having
he
he'd
he'll
he's
her
here
here's
hers
herself
him
himself
his
how
how's
i
i'd
i'll
i'm
i've
if
in
into
is
isn't
it
it's
its
itself
let's
me
more
most
mustn't
my
myself
no
nor
not
of
off
on
once
only
or
other
ought
our
ours
ourselves
out
over
own
put
same
shan't
she
she'd
she'll
she's
should
somebody
something
shouldn't
so
some
such
take
than
that
that's
the
their
theirs
them
themselves
then
there
there's
these
they
they'd
they'll
they're
they've
this
those
through
to
too
under
until
up
very
was
wasn't
we
we'd
we'll
we're
we've
were
weren't
what
what's
when
when's
where
where's
which
while
who
who's
whom
why
why's
will
with
without
won't
would
wouldn't
you
you'd
you'll
you're
you've
your
yours
yourself
yourselves
.
?
!
,
+
=
also
-
;
:

View File

@ -1,46 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.nlp";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -1,130 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.reader;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.vectorizer.TfidfVectorizer;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.net.URI;
import java.util.*;
import static org.junit.Assert.*;
/**
* @author Adam Gibson
*/
public class TfidfRecordReaderTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testReader() throws Exception {
TfidfVectorizer vectorizer = new TfidfVectorizer();
Configuration conf = new Configuration();
conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1);
conf.setBoolean(RecordReader.APPEND_LABEL, true);
vectorizer.initialize(conf);
TfidfRecordReader reader = new TfidfRecordReader();
File f = testDir.newFolder();
new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f);
List<URI> u = new ArrayList<>();
for(File f2 : f.listFiles()){
if(f2.isDirectory()){
for(File f3 : f2.listFiles()){
u.add(f3.toURI());
}
} else {
u.add(f2.toURI());
}
}
Collections.sort(u);
CollectionInputSplit c = new CollectionInputSplit(u);
reader.initialize(conf, c);
int count = 0;
int[] labelAssertions = new int[3];
while (reader.hasNext()) {
Collection<Writable> record = reader.next();
Iterator<Writable> recordIter = record.iterator();
NDArrayWritable writable = (NDArrayWritable) recordIter.next();
labelAssertions[count] = recordIter.next().toInt();
count++;
}
assertArrayEquals(new int[] {0, 1, 2}, labelAssertions);
assertEquals(3, reader.getLabels().size());
assertEquals(3, count);
}
@Test
public void testRecordMetaData() throws Exception {
TfidfVectorizer vectorizer = new TfidfVectorizer();
Configuration conf = new Configuration();
conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1);
conf.setBoolean(RecordReader.APPEND_LABEL, true);
vectorizer.initialize(conf);
TfidfRecordReader reader = new TfidfRecordReader();
File f = testDir.newFolder();
new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f);
reader.initialize(conf, new FileSplit(f));
while (reader.hasNext()) {
Record record = reader.nextRecord();
assertNotNull(record.getMetaData().getURI());
assertEquals(record.getMetaData().getReaderClass(), TfidfRecordReader.class);
}
}
@Test
public void testReadRecordFromMetaData() throws Exception {
TfidfVectorizer vectorizer = new TfidfVectorizer();
Configuration conf = new Configuration();
conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1);
conf.setBoolean(RecordReader.APPEND_LABEL, true);
vectorizer.initialize(conf);
TfidfRecordReader reader = new TfidfRecordReader();
File f = testDir.newFolder();
new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f);
reader.initialize(conf, new FileSplit(f));
Record record = reader.nextRecord();
Record reread = reader.loadFromMetaData(record.getMetaData());
assertEquals(record.getRecord().size(), 2);
assertEquals(reread.getRecord().size(), 2);
assertEquals(record.getRecord().get(0), reread.getRecord().get(0));
assertEquals(record.getRecord().get(1), reread.getRecord().get(1));
assertEquals(record.getMetaData(), reread.getMetaData());
}
}

View File

@ -1,96 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
public class TestGazeteerTransform {
@Test
public void testGazeteerTransform(){
String[] corpus = {
"hello I like apple".toLowerCase(),
"cherry date eggplant potato".toLowerCase()
};
//Gazeteer transform: basically 0/1 if word is present. Assumes already tokenized input
List<String> words = Arrays.asList("apple", "banana", "cherry", "date", "eggplant");
GazeteerTransform t = new GazeteerTransform("words", "out", words);
SequenceSchema schema = (SequenceSchema) new SequenceSchema.Builder()
.addColumnString("words").build();
TransformProcess tp = new TransformProcess.Builder(schema)
.transform(t)
.build();
List<List<List<Writable>>> input = new ArrayList<>();
for(String s : corpus){
String[] split = s.split(" ");
List<List<Writable>> seq = new ArrayList<>();
for(String s2 : split){
seq.add(Collections.<Writable>singletonList(new Text(s2)));
}
input.add(seq);
}
List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, tp);
INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get();
INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get();
INDArray exp0 = Nd4j.create(new float[]{1, 0, 0, 0, 0});
INDArray exp1 = Nd4j.create(new float[]{0, 0, 1, 1, 1});
assertEquals(exp0, arr0);
assertEquals(exp1, arr1);
String json = tp.toJson();
TransformProcess tp2 = TransformProcess.fromJson(json);
assertEquals(tp, tp2);
List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, tp);
INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get();
INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get();
assertEquals(exp0, arr0a);
assertEquals(exp1, arr1a);
}
}

View File

@ -1,96 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
import static org.junit.Assert.assertEquals;
public class TestMultiNLPTransform {
@Test
public void test(){
List<String> words = Arrays.asList("apple", "banana", "cherry", "date", "eggplant");
GazeteerTransform t1 = new GazeteerTransform("words", "out", words);
GazeteerTransform t2 = new GazeteerTransform("out", "out", words);
MultiNlpTransform multi = new MultiNlpTransform("text", new BagOfWordsTransform[]{t1, t2}, "out");
String[] corpus = {
"hello I like apple".toLowerCase(),
"date eggplant potato".toLowerCase()
};
List<List<List<Writable>>> input = new ArrayList<>();
for(String s : corpus){
String[] split = s.split(" ");
List<List<Writable>> seq = new ArrayList<>();
for(String s2 : split){
seq.add(Collections.<Writable>singletonList(new Text(s2)));
}
input.add(seq);
}
SequenceSchema schema = (SequenceSchema) new SequenceSchema.Builder()
.addColumnString("text").build();
TransformProcess tp = new TransformProcess.Builder(schema)
.transform(multi)
.build();
List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, tp);
INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get();
INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get();
INDArray exp0 = Nd4j.create(new float[]{1, 0, 0, 0, 0, 1, 0, 0, 0, 0});
INDArray exp1 = Nd4j.create(new float[]{0, 0, 0, 1, 1, 0, 0, 0, 1, 1});
assertEquals(exp0, arr0);
assertEquals(exp1, arr1);
String json = tp.toJson();
TransformProcess tp2 = TransformProcess.fromJson(json);
assertEquals(tp, tp2);
List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, tp);
INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get();
INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get();
assertEquals(exp0, arr0a);
assertEquals(exp1, arr1a);
}
}

View File

@ -1,414 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.nlp.transforms;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.nlp.metadata.VocabCache;
import org.datavec.nlp.tokenization.tokenizer.preprocessor.LowerCasePreProcessor;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.vectorizer.TfidfVectorizer;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Triple;
import java.util.*;
import static org.datavec.nlp.vectorizer.TextVectorizer.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
public class TokenizerBagOfWordsTermSequenceIndexTransformTest {
@Test
public void testSequenceExecution() {
//credit: https://stackoverflow.com/questions/23792781/tf-idf-feature-weights-using-sklearn-feature-extraction-text-tfidfvectorizer
String[] corpus = {
"This is very strange".toLowerCase(),
"This is very nice".toLowerCase()
};
//{'is': 1.0, 'nice': 1.4054651081081644, 'strange': 1.4054651081081644, 'this': 1.0, 'very': 1.0}
/*
## Reproduce with:
from sklearn.feature_extraction.text import TfidfVectorizer
corpus = ["This is very strange", "This is very nice"]
## SMOOTH = FALSE case:
vectorizer = TfidfVectorizer(min_df=0, norm=None, smooth_idf=False)
X = vectorizer.fit_transform(corpus)
idf = vectorizer.idf_
print(dict(zip(vectorizer.get_feature_names(), idf)))
newText = ["This is very strange", "This is very nice"]
out = vectorizer.transform(newText)
print(out)
{'is': 1.0, 'nice': 1.6931471805599454, 'strange': 1.6931471805599454, 'this': 1.0, 'very': 1.0}
(0, 4) 1.0
(0, 3) 1.0
(0, 2) 1.6931471805599454
(0, 0) 1.0
(1, 4) 1.0
(1, 3) 1.0
(1, 1) 1.6931471805599454
(1, 0) 1.0
## SMOOTH + TRUE case:
{'is': 1.0, 'nice': 1.4054651081081644, 'strange': 1.4054651081081644, 'this': 1.0, 'very': 1.0}
(0, 4) 1.0
(0, 3) 1.0
(0, 2) 1.4054651081081644
(0, 0) 1.0
(1, 4) 1.0
(1, 3) 1.0
(1, 1) 1.4054651081081644
(1, 0) 1.0
*/
List<List<List<Writable>>> input = new ArrayList<>();
input.add(Arrays.asList(Arrays.<Writable>asList(new Text(corpus[0])),Arrays.<Writable>asList(new Text(corpus[1]))));
// First: Check TfidfVectorizer vs. scikit:
Map<String,Double> idfMapNoSmooth = new HashMap<>();
idfMapNoSmooth.put("is",1.0);
idfMapNoSmooth.put("nice",1.6931471805599454);
idfMapNoSmooth.put("strange",1.6931471805599454);
idfMapNoSmooth.put("this",1.0);
idfMapNoSmooth.put("very",1.0);
Map<String,Double> idfMapSmooth = new HashMap<>();
idfMapSmooth.put("is",1.0);
idfMapSmooth.put("nice",1.4054651081081644);
idfMapSmooth.put("strange",1.4054651081081644);
idfMapSmooth.put("this",1.0);
idfMapSmooth.put("very",1.0);
TfidfVectorizer tfidfVectorizer = new TfidfVectorizer();
Configuration configuration = new Configuration();
configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName());
configuration.set(MIN_WORD_FREQUENCY,"1");
configuration.set(STOP_WORDS,"");
configuration.set(TfidfVectorizer.SMOOTH_IDF, "false");
tfidfVectorizer.initialize(configuration);
CollectionRecordReader collectionRecordReader = new CollectionRecordReader(input.get(0));
INDArray array = tfidfVectorizer.fitTransform(collectionRecordReader);
INDArray expNoSmooth = Nd4j.create(DataType.FLOAT, 2, 5);
VocabCache vc = tfidfVectorizer.getCache();
expNoSmooth.putScalar(0, vc.wordIndex("very"), 1.0);
expNoSmooth.putScalar(0, vc.wordIndex("this"), 1.0);
expNoSmooth.putScalar(0, vc.wordIndex("strange"), 1.6931471805599454);
expNoSmooth.putScalar(0, vc.wordIndex("is"), 1.0);
expNoSmooth.putScalar(1, vc.wordIndex("very"), 1.0);
expNoSmooth.putScalar(1, vc.wordIndex("this"), 1.0);
expNoSmooth.putScalar(1, vc.wordIndex("nice"), 1.6931471805599454);
expNoSmooth.putScalar(1, vc.wordIndex("is"), 1.0);
assertEquals(expNoSmooth, array);
//------------------------------------------------------------
//Smooth version:
tfidfVectorizer = new TfidfVectorizer();
configuration = new Configuration();
configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName());
configuration.set(MIN_WORD_FREQUENCY,"1");
configuration.set(STOP_WORDS,"");
configuration.set(TfidfVectorizer.SMOOTH_IDF, "true");
tfidfVectorizer.initialize(configuration);
collectionRecordReader.reset();
array = tfidfVectorizer.fitTransform(collectionRecordReader);
INDArray expSmooth = Nd4j.create(DataType.FLOAT, 2, 5);
expSmooth.putScalar(0, vc.wordIndex("very"), 1.0);
expSmooth.putScalar(0, vc.wordIndex("this"), 1.0);
expSmooth.putScalar(0, vc.wordIndex("strange"), 1.4054651081081644);
expSmooth.putScalar(0, vc.wordIndex("is"), 1.0);
expSmooth.putScalar(1, vc.wordIndex("very"), 1.0);
expSmooth.putScalar(1, vc.wordIndex("this"), 1.0);
expSmooth.putScalar(1, vc.wordIndex("nice"), 1.4054651081081644);
expSmooth.putScalar(1, vc.wordIndex("is"), 1.0);
assertEquals(expSmooth, array);
//////////////////////////////////////////////////////////
//Second: Check transform vs scikit/TfidfVectorizer
List<String> vocab = new ArrayList<>(5); //Arrays.asList("is","nice","strange","this","very");
for( int i=0; i<5; i++ ){
vocab.add(vc.wordAt(i));
}
String inputColumnName = "input";
String outputColumnName = "output";
Map<String,Integer> wordIndexMap = new HashMap<>();
for(int i = 0; i < vocab.size(); i++) {
wordIndexMap.put(vocab.get(i),i);
}
TokenizerBagOfWordsTermSequenceIndexTransform tokenizerBagOfWordsTermSequenceIndexTransform = new TokenizerBagOfWordsTermSequenceIndexTransform(
inputColumnName,
outputColumnName,
wordIndexMap,
idfMapNoSmooth,
false,
null, null);
SequenceSchema.Builder sequenceSchemaBuilder = new SequenceSchema.Builder();
sequenceSchemaBuilder.addColumnString("input");
SequenceSchema schema = sequenceSchemaBuilder.build();
assertEquals("input",schema.getName(0));
TransformProcess transformProcess = new TransformProcess.Builder(schema)
.transform(tokenizerBagOfWordsTermSequenceIndexTransform)
.build();
List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess);
//System.out.println(execute);
INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get();
INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get();
assertEquals(expNoSmooth.getRow(0, true), arr0);
assertEquals(expNoSmooth.getRow(1, true), arr1);
//--------------------------------
//Check smooth:
tokenizerBagOfWordsTermSequenceIndexTransform = new TokenizerBagOfWordsTermSequenceIndexTransform(
inputColumnName,
outputColumnName,
wordIndexMap,
idfMapSmooth,
false,
null, null);
schema = (SequenceSchema) new SequenceSchema.Builder().addColumnString("input").build();
transformProcess = new TransformProcess.Builder(schema)
.transform(tokenizerBagOfWordsTermSequenceIndexTransform)
.build();
execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess);
arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get();
arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get();
assertEquals(expSmooth.getRow(0, true), arr0);
assertEquals(expSmooth.getRow(1, true), arr1);
//Test JSON serialization:
String json = transformProcess.toJson();
TransformProcess fromJson = TransformProcess.fromJson(json);
assertEquals(transformProcess, fromJson);
List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, fromJson);
INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get();
INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get();
assertEquals(expSmooth.getRow(0, true), arr0a);
assertEquals(expSmooth.getRow(1, true), arr1a);
}
@Test
public void additionalTest(){
/*
## To reproduce:
from sklearn.feature_extraction.text import TfidfVectorizer
corpus = [
'This is the first document',
'This document is the second document',
'And this is the third one',
'Is this the first document',
]
vectorizer = TfidfVectorizer(min_df=0, norm=None, smooth_idf=False)
X = vectorizer.fit_transform(corpus)
print(vectorizer.get_feature_names())
out = vectorizer.transform(corpus)
print(out)
['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this']
(0, 8) 1.0
(0, 6) 1.0
(0, 3) 1.0
(0, 2) 1.6931471805599454
(0, 1) 1.2876820724517808
(1, 8) 1.0
(1, 6) 1.0
(1, 5) 2.386294361119891
(1, 3) 1.0
(1, 1) 2.5753641449035616
(2, 8) 1.0
(2, 7) 2.386294361119891
(2, 6) 1.0
(2, 4) 2.386294361119891
(2, 3) 1.0
(2, 0) 2.386294361119891
(3, 8) 1.0
(3, 6) 1.0
(3, 3) 1.0
(3, 2) 1.6931471805599454
(3, 1) 1.2876820724517808
{'and': 2.386294361119891, 'document': 1.2876820724517808, 'first': 1.6931471805599454, 'is': 1.0, 'one': 2.386294361119891, 'second': 2.386294361119891, 'the': 1.0, 'third': 2.386294361119891, 'this': 1.0}
*/
String[] corpus = {
"This is the first document",
"This document is the second document",
"And this is the third one",
"Is this the first document"};
TfidfVectorizer tfidfVectorizer = new TfidfVectorizer();
Configuration configuration = new Configuration();
configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName());
configuration.set(MIN_WORD_FREQUENCY,"1");
configuration.set(STOP_WORDS,"");
configuration.set(TfidfVectorizer.SMOOTH_IDF, "false");
configuration.set(PREPROCESSOR, LowerCasePreProcessor.class.getName());
tfidfVectorizer.initialize(configuration);
List<List<List<Writable>>> input = new ArrayList<>();
//input.add(Arrays.asList(Arrays.<Writable>asList(new Text(corpus[0])),Arrays.<Writable>asList(new Text(corpus[1]))));
List<List<Writable>> seq = new ArrayList<>();
for(String s : corpus){
seq.add(Collections.<Writable>singletonList(new Text(s)));
}
input.add(seq);
CollectionRecordReader crr = new CollectionRecordReader(seq);
INDArray arr = tfidfVectorizer.fitTransform(crr);
//System.out.println(arr);
assertArrayEquals(new long[]{4, 9}, arr.shape());
List<String> pyVocab = Arrays.asList("and", "document", "first", "is", "one", "second", "the", "third", "this");
List<Triple<Integer,Integer,Double>> l = new ArrayList<>();
l.add(new Triple<>(0, 8, 1.0));
l.add(new Triple<>(0, 6, 1.0));
l.add(new Triple<>(0, 3, 1.0));
l.add(new Triple<>(0, 2, 1.6931471805599454));
l.add(new Triple<>(0, 1, 1.2876820724517808));
l.add(new Triple<>(1, 8, 1.0));
l.add(new Triple<>(1, 6, 1.0));
l.add(new Triple<>(1, 5, 2.386294361119891));
l.add(new Triple<>(1, 3, 1.0));
l.add(new Triple<>(1, 1, 2.5753641449035616));
l.add(new Triple<>(2, 8, 1.0));
l.add(new Triple<>(2, 7, 2.386294361119891));
l.add(new Triple<>(2, 6, 1.0));
l.add(new Triple<>(2, 4, 2.386294361119891));
l.add(new Triple<>(2, 3, 1.0));
l.add(new Triple<>(2, 0, 2.386294361119891));
l.add(new Triple<>(3, 8, 1.0));
l.add(new Triple<>(3, 6, 1.0));
l.add(new Triple<>(3, 3, 1.0));
l.add(new Triple<>(3, 2, 1.6931471805599454));
l.add(new Triple<>(3, 1, 1.2876820724517808));
INDArray exp = Nd4j.create(DataType.FLOAT, 4, 9);
for(Triple<Integer,Integer,Double> t : l){
//Work out word index, accounting for different vocab/word orders:
int wIdx = tfidfVectorizer.getCache().wordIndex(pyVocab.get(t.getSecond()));
exp.putScalar(t.getFirst(), wIdx, t.getThird());
}
assertEquals(exp, arr);
Map<String,Double> idfWeights = new HashMap<>();
idfWeights.put("and", 2.386294361119891);
idfWeights.put("document", 1.2876820724517808);
idfWeights.put("first", 1.6931471805599454);
idfWeights.put("is", 1.0);
idfWeights.put("one", 2.386294361119891);
idfWeights.put("second", 2.386294361119891);
idfWeights.put("the", 1.0);
idfWeights.put("third", 2.386294361119891);
idfWeights.put("this", 1.0);
List<String> vocab = new ArrayList<>(9); //Arrays.asList("is","nice","strange","this","very");
for( int i=0; i<9; i++ ){
vocab.add(tfidfVectorizer.getCache().wordAt(i));
}
String inputColumnName = "input";
String outputColumnName = "output";
Map<String,Integer> wordIndexMap = new HashMap<>();
for(int i = 0; i < vocab.size(); i++) {
wordIndexMap.put(vocab.get(i),i);
}
TokenizerBagOfWordsTermSequenceIndexTransform transform = new TokenizerBagOfWordsTermSequenceIndexTransform(
inputColumnName,
outputColumnName,
wordIndexMap,
idfWeights,
false,
null, LowerCasePreProcessor.class.getName());
SequenceSchema.Builder sequenceSchemaBuilder = new SequenceSchema.Builder();
sequenceSchemaBuilder.addColumnString("input");
SequenceSchema schema = sequenceSchemaBuilder.build();
assertEquals("input",schema.getName(0));
TransformProcess transformProcess = new TransformProcess.Builder(schema)
.transform(transform)
.build();
List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess);
INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get();
INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get();
assertEquals(exp.getRow(0, true), arr0);
assertEquals(exp.getRow(1, true), arr1);
}
}

View File

@ -1,53 +0,0 @@
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<configuration>
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>logs/application.log</file>
<encoder>
<pattern>%date - [%level] - from %logger in %thread
%n%message%n%xException%n</pattern>
</encoder>
</appender>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern> %logger{15} - %message%n%xException{5}
</pattern>
</encoder>
</appender>
<logger name="org.apache.catalina.core" level="DEBUG" />
<logger name="org.springframework" level="DEBUG" />
<logger name="org.datavec" level="DEBUG" />
<logger name="org.nd4j" level="INFO" />
<logger name="opennlp.uima.util" level="OFF" />
<logger name="org.apache.uima" level="OFF" />
<logger name="org.cleartk" level="OFF" />
<logger name="org.apache.spark" level="WARN" />
<root level="ERROR">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
</configuration>

View File

@ -1,56 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.datavec</groupId>
<artifactId>datavec-data</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>datavec-geo</artifactId>
<dependencies>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
</dependency>
<dependency>
<groupId>com.maxmind.geoip2</groupId>
<artifactId>geoip2</artifactId>
<version>${geoip2.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

View File

@ -1,25 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.geo;
public enum LocationType {
CITY, CITY_ID, CONTINENT, CONTINENT_ID, COUNTRY, COUNTRY_ID, COORDINATES, POSTAL_CODE, SUBDIVISIONS, SUBDIVISIONS_ID
}

View File

@ -1,194 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.reduce.geo;
import lombok.Getter;
import org.datavec.api.transform.ReduceOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.StringMetaData;
import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.reduce.AggregableColumnReduction;
import org.datavec.api.transform.reduce.AggregableReductionUtils;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.nd4j.common.function.Supplier;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class CoordinatesReduction implements AggregableColumnReduction {
public static final String DEFAULT_COLUMN_NAME = "CoordinatesReduction";
public final static String DEFAULT_DELIMITER = ":";
protected String delimiter = DEFAULT_DELIMITER;
private final List<String> columnNamesPostReduce;
private Supplier<IAggregableReduceOp<Writable, List<Writable>>> multiOp(final List<ReduceOp> ops) {
return new Supplier<IAggregableReduceOp<Writable, List<Writable>>>() {
@Override
public IAggregableReduceOp<Writable, List<Writable>> get() {
return AggregableReductionUtils.reduceDoubleColumn(ops, false, null);
}
};
}
public CoordinatesReduction(String columnNamePostReduce, ReduceOp op) {
this(columnNamePostReduce, op, DEFAULT_DELIMITER);
}
public CoordinatesReduction(List<String> columnNamePostReduce, List<ReduceOp> op) {
this(columnNamePostReduce, op, DEFAULT_DELIMITER);
}
public CoordinatesReduction(String columnNamePostReduce, ReduceOp op, String delimiter) {
this(Collections.singletonList(columnNamePostReduce), Collections.singletonList(op), delimiter);
}
public CoordinatesReduction(List<String> columnNamesPostReduce, List<ReduceOp> ops, String delimiter) {
this.columnNamesPostReduce = columnNamesPostReduce;
this.reducer = new CoordinateAggregableReduceOp(ops.size(), multiOp(ops), delimiter);
}
@Override
public List<String> getColumnsOutputName(String columnInputName) {
return columnNamesPostReduce;
}
@Override
public List<ColumnMetaData> getColumnOutputMetaData(List<String> newColumnName, ColumnMetaData columnInputMeta) {
List<ColumnMetaData> res = new ArrayList<>(newColumnName.size());
for (String cn : newColumnName)
res.add(new StringMetaData((cn)));
return res;
}
@Override
public Schema transform(Schema inputSchema) {
throw new UnsupportedOperationException();
}
@Override
public void setInputSchema(Schema inputSchema) {
throw new UnsupportedOperationException();
}
@Override
public Schema getInputSchema() {
throw new UnsupportedOperationException();
}
@Override
public String outputColumnName() {
throw new UnsupportedOperationException();
}
@Override
public String[] outputColumnNames() {
throw new UnsupportedOperationException();
}
@Override
public String[] columnNames() {
throw new UnsupportedOperationException();
}
@Override
public String columnName() {
throw new UnsupportedOperationException();
}
private IAggregableReduceOp<Writable, List<Writable>> reducer;
@Override
public IAggregableReduceOp<Writable, List<Writable>> reduceOp() {
return reducer;
}
public static class CoordinateAggregableReduceOp implements IAggregableReduceOp<Writable, List<Writable>> {
private int nOps;
private Supplier<IAggregableReduceOp<Writable, List<Writable>>> initialOpValue;
@Getter
private ArrayList<IAggregableReduceOp<Writable, List<Writable>>> perCoordinateOps; // of size coords()
private String delimiter;
public CoordinateAggregableReduceOp(int n, Supplier<IAggregableReduceOp<Writable, List<Writable>>> initialOp,
String delim) {
this.nOps = n;
this.perCoordinateOps = new ArrayList<>();
this.initialOpValue = initialOp;
this.delimiter = delim;
}
@Override
public <W extends IAggregableReduceOp<Writable, List<Writable>>> void combine(W accu) {
if (accu instanceof CoordinateAggregableReduceOp) {
CoordinateAggregableReduceOp accumulator = (CoordinateAggregableReduceOp) accu;
for (int i = 0; i < Math.min(perCoordinateOps.size(), accumulator.getPerCoordinateOps().size()); i++) {
perCoordinateOps.get(i).combine(accumulator.getPerCoordinateOps().get(i));
} // the rest is assumed identical
}
}
@Override
public void accept(Writable writable) {
String[] coordinates = writable.toString().split(delimiter);
for (int i = 0; i < coordinates.length; i++) {
String coordinate = coordinates[i];
while (perCoordinateOps.size() < i + 1) {
perCoordinateOps.add(initialOpValue.get());
}
perCoordinateOps.get(i).accept(new DoubleWritable(Double.parseDouble(coordinate)));
}
}
@Override
public List<Writable> get() {
List<StringBuilder> res = new ArrayList<>(nOps);
for (int i = 0; i < nOps; i++) {
res.add(new StringBuilder());
}
for (int i = 0; i < perCoordinateOps.size(); i++) {
List<Writable> resThisCoord = perCoordinateOps.get(i).get();
for (int j = 0; j < nOps; j++) {
res.get(j).append(resThisCoord.get(j).toString());
if (i < perCoordinateOps.size() - 1) {
res.get(j).append(delimiter);
}
}
}
List<Writable> finalRes = new ArrayList<>(nOps);
for (StringBuilder sb : res) {
finalRes.add(new Text(sb.toString()));
}
return finalRes;
}
}
}
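
For reference, a minimal usage sketch of this reducer (hedged: it mirrors the TestGeoReduction case further down, and assumes the same Schema/Reducer/ReduceOp imports used there):

    Schema schema = new Schema.Builder().addColumnString("key").addColumnString("coord").build();
    Reducer reducer = new Reducer.Builder(ReduceOp.Count).keyColumns("key")
            .customReduction("coord", new CoordinatesReduction("coordSum", ReduceOp.Sum, "#"))
            .build();
    reducer.setInputSchema(schema);
    // Reducing the rows ("someKey", "1#5") and ("someKey", "2#6") yields ("someKey", "3.0#11.0"):
    // each "#"-separated coordinate is summed independently, then re-joined with the delimiter.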

View File

@ -1,117 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.transform.geo;
import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.DoubleMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseColumnsMathOpTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class CoordinatesDistanceTransform extends BaseColumnsMathOpTransform {
public final static String DEFAULT_DELIMITER = ":";
protected String delimiter = DEFAULT_DELIMITER;
public CoordinatesDistanceTransform(String newColumnName, String firstColumn, String secondColumn,
String stdevColumn) {
this(newColumnName, firstColumn, secondColumn, stdevColumn, DEFAULT_DELIMITER);
}
public CoordinatesDistanceTransform(@JsonProperty("newColumnName") String newColumnName,
@JsonProperty("firstColumn") String firstColumn, @JsonProperty("secondColumn") String secondColumn,
@JsonProperty("stdevColumn") String stdevColumn, @JsonProperty("delimiter") String delimiter) {
super(newColumnName, MathOp.Add /* dummy op */,
stdevColumn != null ? new String[] {firstColumn, secondColumn, stdevColumn}
: new String[] {firstColumn, secondColumn});
this.delimiter = delimiter;
}
@Override
protected ColumnMetaData derivedColumnMetaData(String newColumnName, Schema inputSchema) {
return new DoubleMetaData(newColumnName);
}
@Override
protected Writable doOp(Writable... input) {
String[] first = input[0].toString().split(delimiter);
String[] second = input[1].toString().split(delimiter);
String[] stdev = columns.length > 2 ? input[2].toString().split(delimiter) : null;
double dist = 0;
for (int i = 0; i < first.length; i++) {
double d = Double.parseDouble(first[i]) - Double.parseDouble(second[i]);
double s = stdev != null ? Double.parseDouble(stdev[i]) : 1;
dist += (d * d) / (s * s);
}
return new DoubleWritable(Math.sqrt(dist));
}
@Override
public String toString() {
return "CoordinatesDistanceTransform(newColumnName=\"" + newColumnName + "\",columns="
+ Arrays.toString(columns) + ",delimiter=" + delimiter + ")";
}
    /**
     * Transform an object into another object
     *
     * @param input the record to transform
     * @return the computed distance for the record
     */
@Override
    public Object map(Object input) {
        List<?> row = (List<?>) input;
        // Same standardized-distance computation as doOp, applied to a plain row of objects
        String[] first = row.get(0).toString().split(delimiter);
        String[] second = row.get(1).toString().split(delimiter);
        String[] stdev = columns.length > 2 ? row.get(2).toString().split(delimiter) : null;
double dist = 0;
for (int i = 0; i < first.length; i++) {
double d = Double.parseDouble(first[i]) - Double.parseDouble(second[i]);
double s = stdev != null ? Double.parseDouble(stdev[i]) : 1;
dist += (d * d) / (s * s);
}
return Math.sqrt(dist);
}
    /**
     * Transform a sequence, applying {@link #map(Object)} to each step
     *
     * @param sequence the sequence of records to transform
     */
@Override
public Object mapSequence(Object sequence) {
        List<?> seq = (List<?>) sequence;
List<Double> ret = new ArrayList<>();
for (Object step : seq)
ret.add((Double) map(step));
return ret;
}
}
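
For concreteness, a sketch of the standardized Euclidean distance this transform computes (values borrowed from the TestGeoTransforms case further down; the schema is three String columns "point", "mean", "stddev"):

    Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
    transform.setInputSchema(schema);
    // ((50 - 10) / 10)^2 + ((40 - (-20)) / 5)^2 = 4^2 + 12^2 = 160
    List<Writable> out = transform.map(Arrays.asList((Writable) new Text("50|40"),
            new Text("10|-20"), new Text("10|5")));
    // out = ["50|40", "10|-20", "10|5", DoubleWritable(sqrt(160) ≈ 12.649)]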

View File

@ -1,70 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.transform.geo;
import org.apache.commons.io.FileUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArchiveUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.net.URL;
public class GeoIPFetcher {
protected static final Logger log = LoggerFactory.getLogger(GeoIPFetcher.class);
    /** Default directory of the GeoIP databases; see <a href="http://dev.maxmind.com/geoip/geoipupdate/">http://dev.maxmind.com/geoip/geoipupdate/</a> */
public static final String GEOIP_DIR = "/usr/local/share/GeoIP/";
public static final String GEOIP_DIR2 = System.getProperty("user.home") + "/.datavec-geoip";
public static final String CITY_DB = "GeoIP2-City.mmdb";
public static final String CITY_LITE_DB = "GeoLite2-City.mmdb";
public static final String CITY_LITE_URL =
"http://geolite.maxmind.com/download/geoip/database/GeoLite2-City.mmdb.gz";
public static synchronized File fetchCityDB() throws IOException {
File cityFile = new File(GEOIP_DIR, CITY_DB);
if (cityFile.isFile()) {
return cityFile;
}
cityFile = new File(GEOIP_DIR, CITY_LITE_DB);
if (cityFile.isFile()) {
return cityFile;
}
cityFile = new File(GEOIP_DIR2, CITY_LITE_DB);
if (cityFile.isFile()) {
return cityFile;
}
log.info("Downloading GeoLite2 City database...");
File archive = new File(GEOIP_DIR2, CITY_LITE_DB + ".gz");
File dir = new File(GEOIP_DIR2);
dir.mkdirs();
FileUtils.copyURLToFile(new URL(CITY_LITE_URL), archive);
ArchiveUtils.unzipFileTo(archive.getAbsolutePath(), dir.getAbsolutePath());
Preconditions.checkState(cityFile.isFile(), "Error extracting files: expected city file does not exist after extraction");
return cityFile;
}
}
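
The resolution order above, summarized as a usage sketch (no new behavior, just the fallbacks in sequence):

    File db = GeoIPFetcher.fetchCityDB();
    // 1. /usr/local/share/GeoIP/GeoIP2-City.mmdb      (commercial DB, if installed)
    // 2. /usr/local/share/GeoIP/GeoLite2-City.mmdb
    // 3. ~/.datavec-geoip/GeoLite2-City.mmdb
    // 4. otherwise: download GeoLite2-City.mmdb.gz into ~/.datavec-geoip and extract it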

View File

@ -1,43 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.transform.geo;
import org.datavec.api.transform.geo.LocationType;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.io.IOException;
public class IPAddressToCoordinatesTransform extends IPAddressToLocationTransform {
public IPAddressToCoordinatesTransform(@JsonProperty("columnName") String columnName) throws IOException {
this(columnName, DEFAULT_DELIMITER);
}
public IPAddressToCoordinatesTransform(@JsonProperty("columnName") String columnName,
@JsonProperty("delimiter") String delimiter) throws IOException {
super(columnName, LocationType.COORDINATES, delimiter);
}
@Override
public String toString() {
return "IPAddressToCoordinatesTransform";
}
}
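
A short usage sketch (the expected output comes from MaxMind's GeoIP2 test database, the same fixture TestGeoTransforms uses below):

    Transform transform = new IPAddressToCoordinatesTransform("column", ":");
    transform.setInputSchema(new Schema.Builder().addColumnString("column").build());
    // maps the Text "81.2.69.160" to the Text "51.5142:-0.0931" (latitude:longitude)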

View File

@ -1,184 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.transform.geo;
import com.maxmind.geoip2.DatabaseReader;
import com.maxmind.geoip2.exception.GeoIp2Exception;
import com.maxmind.geoip2.model.CityResponse;
import com.maxmind.geoip2.record.Location;
import com.maxmind.geoip2.record.Subdivision;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.transform.geo.LocationType;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.StringMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.InetAddress;
@Slf4j
public class IPAddressToLocationTransform extends BaseColumnTransform {
/**
* Name of the system property to use when configuring the GeoIP database file.<br>
* Most users don't need to set this - typically used for testing purposes.<br>
* Set with the full local path, like: "C:/datavec-geo/GeoIP2-City-Test.mmdb"
*/
public static final String GEOIP_FILE_PROPERTY = "org.datavec.geoip.file";
private static File database;
private static DatabaseReader reader;
public final static String DEFAULT_DELIMITER = ":";
protected String delimiter = DEFAULT_DELIMITER;
protected LocationType locationType;
private static synchronized void init() throws IOException {
// A File object pointing to your GeoIP2 or GeoLite2 database:
// http://dev.maxmind.com/geoip/geoip2/geolite2/
if (database == null) {
String s = System.getProperty(GEOIP_FILE_PROPERTY);
if(s != null && !s.isEmpty()){
//Use user-specified GEOIP file - mainly for testing purposes
File f = new File(s);
if(f.exists() && f.isFile()){
database = f;
} else {
log.warn("GeoIP file (system property {}) is set to \"{}\" but this is not a valid file, using default database", GEOIP_FILE_PROPERTY, s);
database = GeoIPFetcher.fetchCityDB();
}
} else {
database = GeoIPFetcher.fetchCityDB();
}
}
// This creates the DatabaseReader object, which should be reused across lookups.
if (reader == null) {
reader = new DatabaseReader.Builder(database).build();
}
}
public IPAddressToLocationTransform(String columnName) throws IOException {
this(columnName, LocationType.CITY);
}
public IPAddressToLocationTransform(String columnName, LocationType locationType) throws IOException {
this(columnName, locationType, DEFAULT_DELIMITER);
}
    public IPAddressToLocationTransform(@JsonProperty("columnName") String columnName,
                    @JsonProperty("locationType") LocationType locationType, @JsonProperty("delimiter") String delimiter)
throws IOException {
super(columnName);
this.delimiter = delimiter;
this.locationType = locationType;
init();
}
@Override
public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
return new StringMetaData(newName); //Output after transform: String (Text)
}
@Override
public Writable map(Writable columnWritable) {
try {
InetAddress ipAddress = InetAddress.getByName(columnWritable.toString());
CityResponse response = reader.city(ipAddress);
String text = "";
switch (locationType) {
case CITY:
text = response.getCity().getName();
break;
case CITY_ID:
text = response.getCity().getGeoNameId().toString();
break;
case CONTINENT:
text = response.getContinent().getName();
break;
case CONTINENT_ID:
text = response.getContinent().getGeoNameId().toString();
break;
case COUNTRY:
text = response.getCountry().getName();
break;
case COUNTRY_ID:
text = response.getCountry().getGeoNameId().toString();
break;
case COORDINATES:
Location location = response.getLocation();
text = location.getLatitude() + delimiter + location.getLongitude();
break;
case POSTAL_CODE:
text = response.getPostal().getCode();
break;
case SUBDIVISIONS:
for (Subdivision s : response.getSubdivisions()) {
if (text.length() > 0) {
text += delimiter;
}
text += s.getName();
}
break;
case SUBDIVISIONS_ID:
for (Subdivision s : response.getSubdivisions()) {
if (text.length() > 0) {
text += delimiter;
}
text += s.getGeoNameId().toString();
}
break;
                default:
                    throw new IllegalStateException("Unknown location type: " + locationType);
}
if(text == null)
text = "";
return new Text(text);
} catch (GeoIp2Exception | IOException e) {
throw new RuntimeException(e);
}
}
@Override
public String toString() {
return "IPAddressToLocationTransform";
}
//Custom serialization methods, because GeoIP2 doesn't allow DatabaseReader objects to be serialized :(
private void writeObject(ObjectOutputStream out) throws IOException {
out.defaultWriteObject();
}
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
in.defaultReadObject();
init();
}
    @Override
    public Object map(Object input) {
        // Object-based mapping is not implemented for this transform
        return null;
    }
}
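
A configuration sketch for tests (the system property is defined above; the database path here is hypothetical):

    // Point the transform at a local test database instead of the fetched GeoLite2 one
    System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY,
            "/path/to/GeoIP2-City-Test.mmdb"); // hypothetical path
    Transform t = new IPAddressToLocationTransform("column", LocationType.CITY);
    t.setInputSchema(new Schema.Builder().addColumnString("column").build());
    // With MaxMind's test DB, the Text "81.2.69.160" maps to the Text "London"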

View File

@ -1,46 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.api.transform";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -1,80 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.reduce;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.ReduceOp;
import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.reduce.geo.CoordinatesReduction;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
/**
* @author saudet
*/
public class TestGeoReduction {
@Test
public void testCustomReductions() {
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("1#5")));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("2#6")));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("3#7")));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("4#8")));
List<Writable> expected = Arrays.asList((Writable) new Text("someKey"), new Text("10.0#26.0"));
Schema schema = new Schema.Builder().addColumnString("key").addColumnString("coord").build();
Reducer reducer = new Reducer.Builder(ReduceOp.Count).keyColumns("key")
.customReduction("coord", new CoordinatesReduction("coordSum", ReduceOp.Sum, "#")).build();
reducer.setInputSchema(schema);
IAggregableReduceOp<List<Writable>, List<Writable>> aggregableReduceOp = reducer.aggregableReducer();
for (List<Writable> l : inputs)
aggregableReduceOp.accept(l);
List<Writable> out = aggregableReduceOp.get();
assertEquals(2, out.size());
assertEquals(expected, out);
//Check schema:
String[] expNames = new String[] {"key", "coordSum"};
ColumnType[] expTypes = new ColumnType[] {ColumnType.String, ColumnType.String};
Schema outSchema = reducer.transform(schema);
assertEquals(2, outSchema.numColumns());
for (int i = 0; i < 2; i++) {
assertEquals(expNames[i], outSchema.getName(i));
assertEquals(expTypes[i], outSchema.getType(i));
}
}
}

View File

@ -1,153 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.transform;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.geo.LocationType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform;
import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform;
import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.common.io.ClassPathResource;
import java.io.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
/**
* @author saudet
*/
public class TestGeoTransforms {
@BeforeClass
public static void beforeClass() throws Exception {
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath());
}
@AfterClass
public static void afterClass(){
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, "");
}
@Test
public void testCoordinatesDistanceTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
.build();
Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(4, out.numColumns());
assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
out.getColumnTypes());
assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
new DoubleWritable(Math.sqrt(160))),
transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
new Text("10|5"))));
}
@Test
public void testIPAddressToCoordinatesTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("column").build();
Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER");
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(1, out.getColumnMetaData().size());
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
String in = "81.2.69.160";
double latitude = 51.5142;
double longitude = -0.0931;
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
assertEquals(2, coordinates.length);
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
//Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(transform);
byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
ObjectInputStream ois = new ObjectInputStream(bais);
Transform deserialized = (Transform) ois.readObject();
writables = deserialized.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
//System.out.println(Arrays.toString(coordinates));
assertEquals(2, coordinates.length);
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
}
@Test
public void testIPAddressToLocationTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("column").build();
LocationType[] locationTypes = LocationType.values();
String in = "81.2.69.160";
String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167",
"51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record
for (int i = 0; i < locationTypes.length; i++) {
LocationType locationType = locationTypes[i];
String location = locations[i];
Transform transform = new IPAddressToLocationTransform("column", locationType);
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(1, out.getColumnMetaData().size());
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
assertEquals(location, writables.get(0).toString());
//System.out.println(location);
}
}
}

View File

@ -37,11 +37,7 @@
 <name>datavec-data</name>
 <modules>
-<module>datavec-data-audio</module>
-<module>datavec-data-codec</module>
 <module>datavec-data-image</module>
-<module>datavec-data-nlp</module>
-<module>datavec-geo</module>
 </modules>
 <dependencyManagement>