Merge branch 'master' of https://github.com/eclipse/deeplearning4j into ag_github_workflows_1
This commit is contained in:
		
						commit
						4a06f39085
					
				| @ -1,77 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.datavec</groupId> | ||||
|         <artifactId>datavec-data</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>datavec-data-audio</artifactId> | ||||
| 
 | ||||
|     <name>datavec-data-audio</name> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>org.datavec</groupId> | ||||
|             <artifactId>datavec-api</artifactId> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.bytedeco</groupId> | ||||
|             <artifactId>javacpp</artifactId> | ||||
|             <version>${javacpp.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.bytedeco</groupId> | ||||
|             <artifactId>javacv</artifactId> | ||||
|             <version>${javacv.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.github.wendykierp</groupId> | ||||
|             <artifactId>JTransforms</artifactId> | ||||
|             <version>${jtransforms.version}</version> | ||||
|             <classifier>with-dependencies</classifier> | ||||
|         </dependency> | ||||
|         <!-- Do not depend on FFmpeg by default due to licensing concerns. --> | ||||
|         <!-- | ||||
|                 <dependency> | ||||
|                     <groupId>org.bytedeco</groupId> | ||||
|                     <artifactId>ffmpeg-platform</artifactId> | ||||
|                     <version>${ffmpeg.version}-${javacpp-presets.version}</version> | ||||
|                 </dependency> | ||||
|         --> | ||||
|     </dependencies> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
| @ -1,329 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio; | ||||
| 
 | ||||
| 
 | ||||
| import org.datavec.audio.extension.NormalizedSampleAmplitudes; | ||||
| import org.datavec.audio.extension.Spectrogram; | ||||
| import org.datavec.audio.fingerprint.FingerprintManager; | ||||
| import org.datavec.audio.fingerprint.FingerprintSimilarity; | ||||
| import org.datavec.audio.fingerprint.FingerprintSimilarityComputer; | ||||
| 
 | ||||
| import java.io.FileInputStream; | ||||
| import java.io.IOException; | ||||
| import java.io.InputStream; | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| public class Wave implements Serializable { | ||||
| 
 | ||||
|     private static final long serialVersionUID = 1L; | ||||
|     private WaveHeader waveHeader; | ||||
|     private byte[] data; // little endian | ||||
|     private byte[] fingerprint; | ||||
| 
 | ||||
|     /** | ||||
|      * Constructor | ||||
|      *  | ||||
|      */ | ||||
|     public Wave() { | ||||
|         this.waveHeader = new WaveHeader(); | ||||
|         this.data = new byte[0]; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Constructor | ||||
|      *  | ||||
|      * @param filename | ||||
|      *            Wave file | ||||
|      */ | ||||
|     public Wave(String filename) { | ||||
|         try { | ||||
|             InputStream inputStream = new FileInputStream(filename); | ||||
|             initWaveWithInputStream(inputStream); | ||||
|             inputStream.close(); | ||||
|         } catch (IOException e) { | ||||
|             System.out.println(e.toString()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Constructor | ||||
|      *  | ||||
|      * @param inputStream | ||||
|      *            Wave file input stream | ||||
|      */ | ||||
|     public Wave(InputStream inputStream) { | ||||
|         initWaveWithInputStream(inputStream); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Constructor | ||||
|      *  | ||||
|      * @param waveHeader | ||||
|      * @param data | ||||
|      */ | ||||
|     public Wave(WaveHeader waveHeader, byte[] data) { | ||||
|         this.waveHeader = waveHeader; | ||||
|         this.data = data; | ||||
|     } | ||||
| 
 | ||||
|     private void initWaveWithInputStream(InputStream inputStream) { | ||||
|         // reads the first 44 bytes for header | ||||
|         waveHeader = new WaveHeader(inputStream); | ||||
| 
 | ||||
|         if (waveHeader.isValid()) { | ||||
|             // load data | ||||
|             try { | ||||
|                 data = new byte[inputStream.available()]; | ||||
|                 inputStream.read(data); | ||||
|             } catch (IOException e) { | ||||
|                 System.err.println(e.toString()); | ||||
|             } | ||||
|             // end load data | ||||
|         } else { | ||||
|             System.err.println("Invalid Wave Header"); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Trim the wave data | ||||
|      *  | ||||
|      * @param leftTrimNumberOfSample | ||||
|      *            Number of sample trimmed from beginning | ||||
|      * @param rightTrimNumberOfSample | ||||
|      *            Number of sample trimmed from ending | ||||
|      */ | ||||
|     public void trim(int leftTrimNumberOfSample, int rightTrimNumberOfSample) { | ||||
| 
 | ||||
|         long chunkSize = waveHeader.getChunkSize(); | ||||
|         long subChunk2Size = waveHeader.getSubChunk2Size(); | ||||
| 
 | ||||
|         long totalTrimmed = leftTrimNumberOfSample + rightTrimNumberOfSample; | ||||
| 
 | ||||
|         if (totalTrimmed > subChunk2Size) { | ||||
|             leftTrimNumberOfSample = (int) subChunk2Size; | ||||
|         } | ||||
| 
 | ||||
|         // update wav info | ||||
|         chunkSize -= totalTrimmed; | ||||
|         subChunk2Size -= totalTrimmed; | ||||
| 
 | ||||
|         if (chunkSize >= 0 && subChunk2Size >= 0) { | ||||
|             waveHeader.setChunkSize(chunkSize); | ||||
|             waveHeader.setSubChunk2Size(subChunk2Size); | ||||
| 
 | ||||
|             byte[] trimmedData = new byte[(int) subChunk2Size]; | ||||
|             System.arraycopy(data, (int) leftTrimNumberOfSample, trimmedData, 0, (int) subChunk2Size); | ||||
|             data = trimmedData; | ||||
|         } else { | ||||
|             System.err.println("Trim error: Negative length"); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Trim the wave data from beginning | ||||
|      *  | ||||
|      * @param numberOfSample | ||||
|      *            numberOfSample trimmed from beginning | ||||
|      */ | ||||
|     public void leftTrim(int numberOfSample) { | ||||
|         trim(numberOfSample, 0); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Trim the wave data from ending | ||||
|      *  | ||||
|      * @param numberOfSample | ||||
|      *            numberOfSample trimmed from ending | ||||
|      */ | ||||
|     public void rightTrim(int numberOfSample) { | ||||
|         trim(0, numberOfSample); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Trim the wave data | ||||
|      *  | ||||
|      * @param leftTrimSecond | ||||
|      *            Seconds trimmed from beginning | ||||
|      * @param rightTrimSecond | ||||
|      *            Seconds trimmed from ending | ||||
|      */ | ||||
|     public void trim(double leftTrimSecond, double rightTrimSecond) { | ||||
| 
 | ||||
|         int sampleRate = waveHeader.getSampleRate(); | ||||
|         int bitsPerSample = waveHeader.getBitsPerSample(); | ||||
|         int channels = waveHeader.getChannels(); | ||||
| 
 | ||||
|         int leftTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * leftTrimSecond); | ||||
|         int rightTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * rightTrimSecond); | ||||
| 
 | ||||
|         trim(leftTrimNumberOfSample, rightTrimNumberOfSample); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Trim the wave data from beginning | ||||
|      *  | ||||
|      * @param second | ||||
|      *            Seconds trimmed from beginning | ||||
|      */ | ||||
|     public void leftTrim(double second) { | ||||
|         trim(second, 0); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Trim the wave data from ending | ||||
|      *  | ||||
|      * @param second | ||||
|      *            Seconds trimmed from ending | ||||
|      */ | ||||
|     public void rightTrim(double second) { | ||||
|         trim(0, second); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the wave header | ||||
|      *  | ||||
|      * @return waveHeader | ||||
|      */ | ||||
|     public WaveHeader getWaveHeader() { | ||||
|         return waveHeader; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the wave spectrogram | ||||
|      *  | ||||
|      * @return spectrogram | ||||
|      */ | ||||
|     public Spectrogram getSpectrogram() { | ||||
|         return new Spectrogram(this); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the wave spectrogram | ||||
|      *  | ||||
|      * @param fftSampleSize	number of sample in fft, the value needed to be a number to power of 2 | ||||
|      * @param overlapFactor	1/overlapFactor overlapping, e.g. 1/4=25% overlapping, 0 for no overlapping | ||||
|      *  | ||||
|      * @return spectrogram | ||||
|      */ | ||||
|     public Spectrogram getSpectrogram(int fftSampleSize, int overlapFactor) { | ||||
|         return new Spectrogram(this, fftSampleSize, overlapFactor); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the wave data in bytes | ||||
|      *  | ||||
|      * @return wave data | ||||
|      */ | ||||
|     public byte[] getBytes() { | ||||
|         return data; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Data byte size of the wave excluding header size | ||||
|      *  | ||||
|      * @return byte size of the wave | ||||
|      */ | ||||
|     public int size() { | ||||
|         return data.length; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Length of the wave in second | ||||
|      *  | ||||
|      * @return length in second | ||||
|      */ | ||||
|     public float length() { | ||||
|         return (float) waveHeader.getSubChunk2Size() / waveHeader.getByteRate(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Timestamp of the wave length | ||||
|      *  | ||||
|      * @return timestamp | ||||
|      */ | ||||
|     public String timestamp() { | ||||
|         float totalSeconds = this.length(); | ||||
|         float second = totalSeconds % 60; | ||||
|         int minute = (int) totalSeconds / 60 % 60; | ||||
|         int hour = (int) (totalSeconds / 3600); | ||||
| 
 | ||||
|         StringBuilder sb = new StringBuilder(); | ||||
|         if (hour > 0) { | ||||
|             sb.append(hour + ":"); | ||||
|         } | ||||
|         if (minute > 0) { | ||||
|             sb.append(minute + ":"); | ||||
|         } | ||||
|         sb.append(second); | ||||
| 
 | ||||
|         return sb.toString(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the amplitudes of the wave samples (depends on the header) | ||||
|      *  | ||||
|      * @return amplitudes array (signed 16-bit) | ||||
|      */ | ||||
|     public short[] getSampleAmplitudes() { | ||||
|         int bytePerSample = waveHeader.getBitsPerSample() / 8; | ||||
|         int numSamples = data.length / bytePerSample; | ||||
|         short[] amplitudes = new short[numSamples]; | ||||
| 
 | ||||
|         int pointer = 0; | ||||
|         for (int i = 0; i < numSamples; i++) { | ||||
|             short amplitude = 0; | ||||
|             for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) { | ||||
|                 // little endian | ||||
|                 amplitude |= (short) ((data[pointer++] & 0xFF) << (byteNumber * 8)); | ||||
|             } | ||||
|             amplitudes[i] = amplitude; | ||||
|         } | ||||
| 
 | ||||
|         return amplitudes; | ||||
|     } | ||||
| 
 | ||||
|     public String toString() { | ||||
|         StringBuilder sb = new StringBuilder(waveHeader.toString()); | ||||
|         sb.append("\n"); | ||||
|         sb.append("length: " + timestamp()); | ||||
|         return sb.toString(); | ||||
|     } | ||||
| 
 | ||||
|     public double[] getNormalizedAmplitudes() { | ||||
|         NormalizedSampleAmplitudes amplitudes = new NormalizedSampleAmplitudes(this); | ||||
|         return amplitudes.getNormalizedAmplitudes(); | ||||
|     } | ||||
| 
 | ||||
|     public byte[] getFingerprint() { | ||||
|         if (fingerprint == null) { | ||||
|             FingerprintManager fingerprintManager = new FingerprintManager(); | ||||
|             fingerprint = fingerprintManager.extractFingerprint(this); | ||||
|         } | ||||
|         return fingerprint; | ||||
|     } | ||||
| 
 | ||||
|     public FingerprintSimilarity getFingerprintSimilarity(Wave wave) { | ||||
|         FingerprintSimilarityComputer fingerprintSimilarityComputer = | ||||
|                         new FingerprintSimilarityComputer(this.getFingerprint(), wave.getFingerprint()); | ||||
|         return fingerprintSimilarityComputer.getFingerprintsSimilarity(); | ||||
|     } | ||||
| } | ||||
| @ -1,98 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio; | ||||
| 
 | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| 
 | ||||
| import java.io.FileOutputStream; | ||||
| import java.io.IOException; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class WaveFileManager { | ||||
| 
 | ||||
|     private Wave wave; | ||||
| 
 | ||||
|     public WaveFileManager() { | ||||
|         wave = new Wave(); | ||||
|     } | ||||
| 
 | ||||
|     public WaveFileManager(Wave wave) { | ||||
|         setWave(wave); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Save the wave file | ||||
|      *  | ||||
|      * @param filename | ||||
|      *            filename to be saved | ||||
|      *             | ||||
|      * @see	Wave file saved | ||||
|      */ | ||||
|     public void saveWaveAsFile(String filename) { | ||||
| 
 | ||||
|         WaveHeader waveHeader = wave.getWaveHeader(); | ||||
| 
 | ||||
|         int byteRate = waveHeader.getByteRate(); | ||||
|         int audioFormat = waveHeader.getAudioFormat(); | ||||
|         int sampleRate = waveHeader.getSampleRate(); | ||||
|         int bitsPerSample = waveHeader.getBitsPerSample(); | ||||
|         int channels = waveHeader.getChannels(); | ||||
|         long chunkSize = waveHeader.getChunkSize(); | ||||
|         long subChunk1Size = waveHeader.getSubChunk1Size(); | ||||
|         long subChunk2Size = waveHeader.getSubChunk2Size(); | ||||
|         int blockAlign = waveHeader.getBlockAlign(); | ||||
| 
 | ||||
|         try { | ||||
|             FileOutputStream fos = new FileOutputStream(filename); | ||||
|             fos.write(WaveHeader.RIFF_HEADER.getBytes()); | ||||
|             // little endian | ||||
|             fos.write(new byte[] {(byte) (chunkSize), (byte) (chunkSize >> 8), (byte) (chunkSize >> 16), | ||||
|                             (byte) (chunkSize >> 24)}); | ||||
|             fos.write(WaveHeader.WAVE_HEADER.getBytes()); | ||||
|             fos.write(WaveHeader.FMT_HEADER.getBytes()); | ||||
|             fos.write(new byte[] {(byte) (subChunk1Size), (byte) (subChunk1Size >> 8), (byte) (subChunk1Size >> 16), | ||||
|                             (byte) (subChunk1Size >> 24)}); | ||||
|             fos.write(new byte[] {(byte) (audioFormat), (byte) (audioFormat >> 8)}); | ||||
|             fos.write(new byte[] {(byte) (channels), (byte) (channels >> 8)}); | ||||
|             fos.write(new byte[] {(byte) (sampleRate), (byte) (sampleRate >> 8), (byte) (sampleRate >> 16), | ||||
|                             (byte) (sampleRate >> 24)}); | ||||
|             fos.write(new byte[] {(byte) (byteRate), (byte) (byteRate >> 8), (byte) (byteRate >> 16), | ||||
|                             (byte) (byteRate >> 24)}); | ||||
|             fos.write(new byte[] {(byte) (blockAlign), (byte) (blockAlign >> 8)}); | ||||
|             fos.write(new byte[] {(byte) (bitsPerSample), (byte) (bitsPerSample >> 8)}); | ||||
|             fos.write(WaveHeader.DATA_HEADER.getBytes()); | ||||
|             fos.write(new byte[] {(byte) (subChunk2Size), (byte) (subChunk2Size >> 8), (byte) (subChunk2Size >> 16), | ||||
|                             (byte) (subChunk2Size >> 24)}); | ||||
|             fos.write(wave.getBytes()); | ||||
|             fos.close(); | ||||
|         } catch (IOException e) { | ||||
|             log.error("",e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public Wave getWave() { | ||||
|         return wave; | ||||
|     } | ||||
| 
 | ||||
|     public void setWave(Wave wave) { | ||||
|         this.wave = wave; | ||||
|     } | ||||
| } | ||||
| @ -1,281 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio; | ||||
| 
 | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| import java.io.InputStream; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class WaveHeader { | ||||
| 
 | ||||
|     public static final String RIFF_HEADER = "RIFF"; | ||||
|     public static final String WAVE_HEADER = "WAVE"; | ||||
|     public static final String FMT_HEADER = "fmt "; | ||||
|     public static final String DATA_HEADER = "data"; | ||||
|     public static final int HEADER_BYTE_LENGTH = 44; // 44 bytes for header | ||||
| 
 | ||||
|     private boolean valid; | ||||
|     private String chunkId; // 4 bytes | ||||
|     private long chunkSize; // unsigned 4 bytes, little endian | ||||
|     private String format; // 4 bytes | ||||
|     private String subChunk1Id; // 4 bytes | ||||
|     private long subChunk1Size; // unsigned 4 bytes, little endian | ||||
|     private int audioFormat; // unsigned 2 bytes, little endian | ||||
|     private int channels; // unsigned 2 bytes, little endian | ||||
|     private long sampleRate; // unsigned 4 bytes, little endian | ||||
|     private long byteRate; // unsigned 4 bytes, little endian | ||||
|     private int blockAlign; // unsigned 2 bytes, little endian | ||||
|     private int bitsPerSample; // unsigned 2 bytes, little endian | ||||
|     private String subChunk2Id; // 4 bytes | ||||
|     private long subChunk2Size; // unsigned 4 bytes, little endian | ||||
| 
 | ||||
|     public WaveHeader() { | ||||
|         // init a 8k 16bit mono wav		 | ||||
|         chunkSize = 36; | ||||
|         subChunk1Size = 16; | ||||
|         audioFormat = 1; | ||||
|         channels = 1; | ||||
|         sampleRate = 8000; | ||||
|         byteRate = 16000; | ||||
|         blockAlign = 2; | ||||
|         bitsPerSample = 16; | ||||
|         subChunk2Size = 0; | ||||
|         valid = true; | ||||
|     } | ||||
| 
 | ||||
|     public WaveHeader(InputStream inputStream) { | ||||
|         valid = loadHeader(inputStream); | ||||
|     } | ||||
| 
 | ||||
|     private boolean loadHeader(InputStream inputStream) { | ||||
| 
 | ||||
|         byte[] headerBuffer = new byte[HEADER_BYTE_LENGTH]; | ||||
|         try { | ||||
|             inputStream.read(headerBuffer); | ||||
| 
 | ||||
|             // read header | ||||
|             int pointer = 0; | ||||
|             chunkId = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++], | ||||
|                             headerBuffer[pointer++]}); | ||||
|             // little endian | ||||
|             chunkSize = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 | ||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 16 | ||||
|                             | (long) (headerBuffer[pointer++] & 0xff << 24); | ||||
|             format = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++], | ||||
|                             headerBuffer[pointer++]}); | ||||
|             subChunk1Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], | ||||
|                             headerBuffer[pointer++], headerBuffer[pointer++]}); | ||||
|             subChunk1Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 | ||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 16 | ||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 24; | ||||
|             audioFormat = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8); | ||||
|             channels = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8); | ||||
|             sampleRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 | ||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 16 | ||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 24; | ||||
|             byteRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 | ||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 16 | ||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 24; | ||||
|             blockAlign = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8); | ||||
|             bitsPerSample = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8); | ||||
|             subChunk2Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], | ||||
|                             headerBuffer[pointer++], headerBuffer[pointer++]}); | ||||
|             subChunk2Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 | ||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 16 | ||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 24; | ||||
|             // end read header | ||||
| 
 | ||||
|             // the inputStream should be closed outside this method | ||||
| 
 | ||||
|             // dis.close(); | ||||
| 
 | ||||
|         } catch (IOException e) { | ||||
|             log.error("",e); | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         if (bitsPerSample != 8 && bitsPerSample != 16) { | ||||
|             System.err.println("WaveHeader: only supports bitsPerSample 8 or 16"); | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         // check the format is support | ||||
|         if (chunkId.toUpperCase().equals(RIFF_HEADER) && format.toUpperCase().equals(WAVE_HEADER) && audioFormat == 1) { | ||||
|             return true; | ||||
|         } else { | ||||
|             System.err.println("WaveHeader: Unsupported header format"); | ||||
|         } | ||||
| 
 | ||||
|         return false; | ||||
|     } | ||||
| 
 | ||||
|     public boolean isValid() { | ||||
|         return valid; | ||||
|     } | ||||
| 
 | ||||
|     public String getChunkId() { | ||||
|         return chunkId; | ||||
|     } | ||||
| 
 | ||||
|     public long getChunkSize() { | ||||
|         return chunkSize; | ||||
|     } | ||||
| 
 | ||||
|     public String getFormat() { | ||||
|         return format; | ||||
|     } | ||||
| 
 | ||||
|     public String getSubChunk1Id() { | ||||
|         return subChunk1Id; | ||||
|     } | ||||
| 
 | ||||
|     public long getSubChunk1Size() { | ||||
|         return subChunk1Size; | ||||
|     } | ||||
| 
 | ||||
|     public int getAudioFormat() { | ||||
|         return audioFormat; | ||||
|     } | ||||
| 
 | ||||
|     public int getChannels() { | ||||
|         return channels; | ||||
|     } | ||||
| 
 | ||||
|     public int getSampleRate() { | ||||
|         return (int) sampleRate; | ||||
|     } | ||||
| 
 | ||||
|     public int getByteRate() { | ||||
|         return (int) byteRate; | ||||
|     } | ||||
| 
 | ||||
|     public int getBlockAlign() { | ||||
|         return blockAlign; | ||||
|     } | ||||
| 
 | ||||
|     public int getBitsPerSample() { | ||||
|         return bitsPerSample; | ||||
|     } | ||||
| 
 | ||||
|     public String getSubChunk2Id() { | ||||
|         return subChunk2Id; | ||||
|     } | ||||
| 
 | ||||
|     public long getSubChunk2Size() { | ||||
|         return subChunk2Size; | ||||
|     } | ||||
| 
 | ||||
|     public void setSampleRate(int sampleRate) { | ||||
|         int newSubChunk2Size = (int) (this.subChunk2Size * sampleRate / this.sampleRate); | ||||
|         // if num bytes for each sample is even, the size of newSubChunk2Size also needed to be in even number | ||||
|         if ((bitsPerSample / 8) % 2 == 0) { | ||||
|             if (newSubChunk2Size % 2 != 0) { | ||||
|                 newSubChunk2Size++; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         this.sampleRate = sampleRate; | ||||
|         this.byteRate = sampleRate * bitsPerSample / 8; | ||||
|         this.chunkSize = newSubChunk2Size + 36; | ||||
|         this.subChunk2Size = newSubChunk2Size; | ||||
|     } | ||||
| 
 | ||||
|     public void setChunkId(String chunkId) { | ||||
|         this.chunkId = chunkId; | ||||
|     } | ||||
| 
 | ||||
|     public void setChunkSize(long chunkSize) { | ||||
|         this.chunkSize = chunkSize; | ||||
|     } | ||||
| 
 | ||||
|     public void setFormat(String format) { | ||||
|         this.format = format; | ||||
|     } | ||||
| 
 | ||||
|     public void setSubChunk1Id(String subChunk1Id) { | ||||
|         this.subChunk1Id = subChunk1Id; | ||||
|     } | ||||
| 
 | ||||
|     public void setSubChunk1Size(long subChunk1Size) { | ||||
|         this.subChunk1Size = subChunk1Size; | ||||
|     } | ||||
| 
 | ||||
|     public void setAudioFormat(int audioFormat) { | ||||
|         this.audioFormat = audioFormat; | ||||
|     } | ||||
| 
 | ||||
|     public void setChannels(int channels) { | ||||
|         this.channels = channels; | ||||
|     } | ||||
| 
 | ||||
|     public void setByteRate(long byteRate) { | ||||
|         this.byteRate = byteRate; | ||||
|     } | ||||
| 
 | ||||
|     public void setBlockAlign(int blockAlign) { | ||||
|         this.blockAlign = blockAlign; | ||||
|     } | ||||
| 
 | ||||
|     public void setBitsPerSample(int bitsPerSample) { | ||||
|         this.bitsPerSample = bitsPerSample; | ||||
|     } | ||||
| 
 | ||||
|     public void setSubChunk2Id(String subChunk2Id) { | ||||
|         this.subChunk2Id = subChunk2Id; | ||||
|     } | ||||
| 
 | ||||
|     public void setSubChunk2Size(long subChunk2Size) { | ||||
|         this.subChunk2Size = subChunk2Size; | ||||
|     } | ||||
| 
 | ||||
|     public String toString() { | ||||
| 
 | ||||
|         StringBuilder sb = new StringBuilder(); | ||||
|         sb.append("chunkId: " + chunkId); | ||||
|         sb.append("\n"); | ||||
|         sb.append("chunkSize: " + chunkSize); | ||||
|         sb.append("\n"); | ||||
|         sb.append("format: " + format); | ||||
|         sb.append("\n"); | ||||
|         sb.append("subChunk1Id: " + subChunk1Id); | ||||
|         sb.append("\n"); | ||||
|         sb.append("subChunk1Size: " + subChunk1Size); | ||||
|         sb.append("\n"); | ||||
|         sb.append("audioFormat: " + audioFormat); | ||||
|         sb.append("\n"); | ||||
|         sb.append("channels: " + channels); | ||||
|         sb.append("\n"); | ||||
|         sb.append("sampleRate: " + sampleRate); | ||||
|         sb.append("\n"); | ||||
|         sb.append("byteRate: " + byteRate); | ||||
|         sb.append("\n"); | ||||
|         sb.append("blockAlign: " + blockAlign); | ||||
|         sb.append("\n"); | ||||
|         sb.append("bitsPerSample: " + bitsPerSample); | ||||
|         sb.append("\n"); | ||||
|         sb.append("subChunk2Id: " + subChunk2Id); | ||||
|         sb.append("\n"); | ||||
|         sb.append("subChunk2Size: " + subChunk2Size); | ||||
|         return sb.toString(); | ||||
|     } | ||||
| } | ||||
| @ -1,81 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.dsp; | ||||
| 
 | ||||
| import org.jtransforms.fft.DoubleFFT_1D; | ||||
| 
 | ||||
| public class FastFourierTransform { | ||||
| 
 | ||||
|     /** | ||||
|      * Get the frequency intensities | ||||
|      * | ||||
|      * @param amplitudes amplitudes of the signal. Format depends on value of complex | ||||
|      * @param complex    if true, amplitudes is assumed to be complex interlaced (re = even, im = odd), if false amplitudes | ||||
|      *                   are assumed to be real valued. | ||||
|      * @return intensities of each frequency unit: mag[frequency_unit]=intensity | ||||
|      */ | ||||
|     public double[] getMagnitudes(double[] amplitudes, boolean complex) { | ||||
| 
 | ||||
|         final int sampleSize = amplitudes.length; | ||||
|         final int nrofFrequencyBins = sampleSize / 2; | ||||
| 
 | ||||
| 
 | ||||
|         // call the fft and transform the complex numbers | ||||
|         if (complex) { | ||||
|             DoubleFFT_1D fft = new DoubleFFT_1D(nrofFrequencyBins); | ||||
|             fft.complexForward(amplitudes); | ||||
|         } else { | ||||
|             DoubleFFT_1D fft = new DoubleFFT_1D(sampleSize); | ||||
|             fft.realForward(amplitudes); | ||||
|             // amplitudes[1] contains re[sampleSize/2] or im[(sampleSize-1) / 2] (depending on whether sampleSize is odd or even) | ||||
|             // Discard it as it is useless without the other part | ||||
|             // im part dc bin is always 0 for real input | ||||
|             amplitudes[1] = 0; | ||||
|         } | ||||
|         // end call the fft and transform the complex numbers | ||||
| 
 | ||||
|         // even indexes (0,2,4,6,...) are real parts | ||||
|         // odd indexes (1,3,5,7,...) are img parts | ||||
|         double[] mag = new double[nrofFrequencyBins]; | ||||
|         for (int i = 0; i < nrofFrequencyBins; i++) { | ||||
|             final int f = 2 * i; | ||||
|             mag[i] = Math.sqrt(amplitudes[f] * amplitudes[f] + amplitudes[f + 1] * amplitudes[f + 1]); | ||||
|         } | ||||
| 
 | ||||
|         return mag; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the frequency intensities. Backwards compatible with previous versions w.r.t to number of frequency bins. | ||||
|      * Use getMagnitudes(amplitudes, true) to get all bins. | ||||
|      * | ||||
|      * @param amplitudes complex-valued signal to transform. Even indexes are real and odd indexes are img | ||||
|      * @return intensities of each frequency unit: mag[frequency_unit]=intensity | ||||
|      */ | ||||
|     public double[] getMagnitudes(double[] amplitudes) { | ||||
|         double[] magnitudes = getMagnitudes(amplitudes, true); | ||||
| 
 | ||||
|         double[] halfOfMagnitudes = new double[magnitudes.length/2]; | ||||
|         System.arraycopy(magnitudes, 0,halfOfMagnitudes, 0, halfOfMagnitudes.length); | ||||
|         return halfOfMagnitudes; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,66 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.dsp; | ||||
| 
 | ||||
| public class LinearInterpolation { | ||||
| 
 | ||||
|     public LinearInterpolation() { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Do interpolation on the samples according to the original and destinated sample rates | ||||
|      *  | ||||
|      * @param oldSampleRate	sample rate of the original samples | ||||
|      * @param newSampleRate	sample rate of the interpolated samples | ||||
|      * @param samples	original samples | ||||
|      * @return interpolated samples | ||||
|      */ | ||||
|     public short[] interpolate(int oldSampleRate, int newSampleRate, short[] samples) { | ||||
| 
 | ||||
|         if (oldSampleRate == newSampleRate) { | ||||
|             return samples; | ||||
|         } | ||||
| 
 | ||||
|         int newLength = Math.round(((float) samples.length / oldSampleRate * newSampleRate)); | ||||
|         float lengthMultiplier = (float) newLength / samples.length; | ||||
|         short[] interpolatedSamples = new short[newLength]; | ||||
| 
 | ||||
|         // interpolate the value by the linear equation y=mx+c         | ||||
|         for (int i = 0; i < newLength; i++) { | ||||
| 
 | ||||
|             // get the nearest positions for the interpolated point | ||||
|             float currentPosition = i / lengthMultiplier; | ||||
|             int nearestLeftPosition = (int) currentPosition; | ||||
|             int nearestRightPosition = nearestLeftPosition + 1; | ||||
|             if (nearestRightPosition >= samples.length) { | ||||
|                 nearestRightPosition = samples.length - 1; | ||||
|             } | ||||
| 
 | ||||
|             float slope = samples[nearestRightPosition] - samples[nearestLeftPosition]; // delta x is 1 | ||||
|             float positionFromLeft = currentPosition - nearestLeftPosition; | ||||
| 
 | ||||
|             interpolatedSamples[i] = (short) (slope * positionFromLeft + samples[nearestLeftPosition]); // y=mx+c | ||||
|         } | ||||
| 
 | ||||
|         return interpolatedSamples; | ||||
|     } | ||||
| } | ||||
| @ -1,84 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.dsp; | ||||
| 
 | ||||
| public class Resampler { | ||||
| 
 | ||||
|     public Resampler() {} | ||||
| 
 | ||||
|     /** | ||||
|      * Do resampling. Currently the amplitude is stored by short such that maximum bitsPerSample is 16 (bytePerSample is 2) | ||||
|      *  | ||||
|      * @param sourceData	The source data in bytes | ||||
|      * @param bitsPerSample	How many bits represents one sample (currently supports max. bitsPerSample=16)  | ||||
|      * @param sourceRate	Sample rate of the source data | ||||
|      * @param targetRate	Sample rate of the target data | ||||
|      * @return re-sampled data | ||||
|      */ | ||||
|     public byte[] reSample(byte[] sourceData, int bitsPerSample, int sourceRate, int targetRate) { | ||||
| 
 | ||||
|         // make the bytes to amplitudes first | ||||
|         int bytePerSample = bitsPerSample / 8; | ||||
|         int numSamples = sourceData.length / bytePerSample; | ||||
|         short[] amplitudes = new short[numSamples]; // 16 bit, use a short to store | ||||
| 
 | ||||
|         int pointer = 0; | ||||
|         for (int i = 0; i < numSamples; i++) { | ||||
|             short amplitude = 0; | ||||
|             for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) { | ||||
|                 // little endian | ||||
|                 amplitude |= (short) ((sourceData[pointer++] & 0xFF) << (byteNumber * 8)); | ||||
|             } | ||||
|             amplitudes[i] = amplitude; | ||||
|         } | ||||
|         // end make the amplitudes | ||||
| 
 | ||||
|         // do interpolation | ||||
|         LinearInterpolation reSample = new LinearInterpolation(); | ||||
|         short[] targetSample = reSample.interpolate(sourceRate, targetRate, amplitudes); | ||||
|         int targetLength = targetSample.length; | ||||
|         // end do interpolation | ||||
| 
 | ||||
|         // TODO: Remove the high frequency signals with a digital filter, leaving a signal containing only half-sample-rated frequency information, but still sampled at a rate of target sample rate. Usually FIR is used | ||||
| 
 | ||||
|         // end resample the amplitudes | ||||
| 
 | ||||
|         // convert the amplitude to bytes | ||||
|         byte[] bytes; | ||||
|         if (bytePerSample == 1) { | ||||
|             bytes = new byte[targetLength]; | ||||
|             for (int i = 0; i < targetLength; i++) { | ||||
|                 bytes[i] = (byte) targetSample[i]; | ||||
|             } | ||||
|         } else { | ||||
|             // suppose bytePerSample==2 | ||||
|             bytes = new byte[targetLength * 2]; | ||||
|             for (int i = 0; i < targetSample.length; i++) { | ||||
|                 // little endian			 | ||||
|                 bytes[i * 2] = (byte) (targetSample[i] & 0xff); | ||||
|                 bytes[i * 2 + 1] = (byte) ((targetSample[i] >> 8) & 0xff); | ||||
|             } | ||||
|         } | ||||
|         // end convert the amplitude to bytes | ||||
| 
 | ||||
|         return bytes; | ||||
|     } | ||||
| } | ||||
| @ -1,95 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.dsp; | ||||
| 
 | ||||
| public class WindowFunction { | ||||
| 
 | ||||
|     public static final int RECTANGULAR = 0; | ||||
|     public static final int BARTLETT = 1; | ||||
|     public static final int HANNING = 2; | ||||
|     public static final int HAMMING = 3; | ||||
|     public static final int BLACKMAN = 4; | ||||
| 
 | ||||
|     int windowType = 0; // defaults to rectangular window | ||||
| 
 | ||||
|     public WindowFunction() {} | ||||
| 
 | ||||
|     public void setWindowType(int wt) { | ||||
|         windowType = wt; | ||||
|     } | ||||
| 
 | ||||
|     public void setWindowType(String w) { | ||||
|         if (w.toUpperCase().equals("RECTANGULAR")) | ||||
|             windowType = RECTANGULAR; | ||||
|         if (w.toUpperCase().equals("BARTLETT")) | ||||
|             windowType = BARTLETT; | ||||
|         if (w.toUpperCase().equals("HANNING")) | ||||
|             windowType = HANNING; | ||||
|         if (w.toUpperCase().equals("HAMMING")) | ||||
|             windowType = HAMMING; | ||||
|         if (w.toUpperCase().equals("BLACKMAN")) | ||||
|             windowType = BLACKMAN; | ||||
|     } | ||||
| 
 | ||||
|     public int getWindowType() { | ||||
|         return windowType; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Generate a window | ||||
|      *  | ||||
|      * @param nSamples	size of the window | ||||
|      * @return	window in array | ||||
|      */ | ||||
|     public double[] generate(int nSamples) { | ||||
|         // generate nSamples window function values | ||||
|         // for index values 0 .. nSamples - 1 | ||||
|         int m = nSamples / 2; | ||||
|         double r; | ||||
|         double pi = Math.PI; | ||||
|         double[] w = new double[nSamples]; | ||||
|         switch (windowType) { | ||||
|             case BARTLETT: // Bartlett (triangular) window | ||||
|                 for (int n = 0; n < nSamples; n++) | ||||
|                     w[n] = 1.0f - Math.abs(n - m) / m; | ||||
|                 break; | ||||
|             case HANNING: // Hanning window | ||||
|                 r = pi / (m + 1); | ||||
|                 for (int n = -m; n < m; n++) | ||||
|                     w[m + n] = 0.5f + 0.5f * Math.cos(n * r); | ||||
|                 break; | ||||
|             case HAMMING: // Hamming window | ||||
|                 r = pi / m; | ||||
|                 for (int n = -m; n < m; n++) | ||||
|                     w[m + n] = 0.54f + 0.46f * Math.cos(n * r); | ||||
|                 break; | ||||
|             case BLACKMAN: // Blackman window | ||||
|                 r = pi / m; | ||||
|                 for (int n = -m; n < m; n++) | ||||
|                     w[m + n] = 0.42f + 0.5f * Math.cos(n * r) + 0.08f * Math.cos(2 * n * r); | ||||
|                 break; | ||||
|             default: // Rectangular window function | ||||
|                 for (int n = 0; n < nSamples; n++) | ||||
|                     w[n] = 1.0f; | ||||
|         } | ||||
|         return w; | ||||
|     } | ||||
| } | ||||
| @ -1,21 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.dsp; | ||||
| @ -1,67 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.extension; | ||||
| 
 | ||||
| 
 | ||||
| import org.datavec.audio.Wave; | ||||
| 
 | ||||
| public class NormalizedSampleAmplitudes { | ||||
| 
 | ||||
|     private Wave wave; | ||||
|     private double[] normalizedAmplitudes; // normalizedAmplitudes[sampleNumber]=normalizedAmplitudeInTheFrame | ||||
| 
 | ||||
|     public NormalizedSampleAmplitudes(Wave wave) { | ||||
|         this.wave = wave; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      *  | ||||
|      * Get normalized amplitude of each frame | ||||
|      *  | ||||
|      * @return	array of normalized amplitudes(signed 16 bit): normalizedAmplitudes[frame]=amplitude | ||||
|      */ | ||||
|     public double[] getNormalizedAmplitudes() { | ||||
| 
 | ||||
|         if (normalizedAmplitudes == null) { | ||||
| 
 | ||||
|             boolean signed = true; | ||||
| 
 | ||||
|             // usually 8bit is unsigned | ||||
|             if (wave.getWaveHeader().getBitsPerSample() == 8) { | ||||
|                 signed = false; | ||||
|             } | ||||
| 
 | ||||
|             short[] amplitudes = wave.getSampleAmplitudes(); | ||||
|             int numSamples = amplitudes.length; | ||||
|             int maxAmplitude = 1 << (wave.getWaveHeader().getBitsPerSample() - 1); | ||||
| 
 | ||||
|             if (!signed) { // one more bit for unsigned value | ||||
|                 maxAmplitude <<= 1; | ||||
|             } | ||||
| 
 | ||||
|             normalizedAmplitudes = new double[numSamples]; | ||||
|             for (int i = 0; i < numSamples; i++) { | ||||
|                 normalizedAmplitudes[i] = (double) amplitudes[i] / maxAmplitude; | ||||
|             } | ||||
|         } | ||||
|         return normalizedAmplitudes; | ||||
|     } | ||||
| } | ||||
| @ -1,214 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.extension; | ||||
| 
 | ||||
| 
 | ||||
| import org.datavec.audio.Wave; | ||||
| import org.datavec.audio.dsp.FastFourierTransform; | ||||
| import org.datavec.audio.dsp.WindowFunction; | ||||
| 
 | ||||
| public class Spectrogram { | ||||
| 
 | ||||
|     public static final int SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE = 1024; | ||||
|     public static final int SPECTROGRAM_DEFAULT_OVERLAP_FACTOR = 0; // 0 for no overlapping | ||||
| 
 | ||||
|     private Wave wave; | ||||
|     private double[][] spectrogram; // relative spectrogram | ||||
|     private double[][] absoluteSpectrogram; // absolute spectrogram | ||||
|     private int fftSampleSize; // number of sample in fft, the value needed to be a number to power of 2 | ||||
|     private int overlapFactor; // 1/overlapFactor overlapping, e.g. 1/4=25% overlapping | ||||
|     private int numFrames; // number of frames of the spectrogram | ||||
|     private int framesPerSecond; // frame per second of the spectrogram | ||||
|     private int numFrequencyUnit; // number of y-axis unit | ||||
|     private double unitFrequency; // frequency per y-axis unit | ||||
| 
 | ||||
|     /** | ||||
|      * Constructor | ||||
|      * | ||||
|      * @param wave | ||||
|      */ | ||||
|     public Spectrogram(Wave wave) { | ||||
|         this.wave = wave; | ||||
|         // default | ||||
|         this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE; | ||||
|         this.overlapFactor = SPECTROGRAM_DEFAULT_OVERLAP_FACTOR; | ||||
|         buildSpectrogram(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Constructor | ||||
|      * | ||||
|      * @param wave | ||||
|      * @param fftSampleSize	number of sample in fft, the value needed to be a number to power of 2 | ||||
|      * @param overlapFactor	1/overlapFactor overlapping, e.g. 1/4=25% overlapping, 0 for no overlapping | ||||
|      */ | ||||
|     public Spectrogram(Wave wave, int fftSampleSize, int overlapFactor) { | ||||
|         this.wave = wave; | ||||
| 
 | ||||
|         if (Integer.bitCount(fftSampleSize) == 1) { | ||||
|             this.fftSampleSize = fftSampleSize; | ||||
|         } else { | ||||
|             System.err.print("The input number must be a power of 2"); | ||||
|             this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE; | ||||
|         } | ||||
| 
 | ||||
|         this.overlapFactor = overlapFactor; | ||||
| 
 | ||||
|         buildSpectrogram(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Build spectrogram | ||||
|      */ | ||||
|     private void buildSpectrogram() { | ||||
| 
 | ||||
|         short[] amplitudes = wave.getSampleAmplitudes(); | ||||
|         int numSamples = amplitudes.length; | ||||
| 
 | ||||
|         int pointer = 0; | ||||
|         // overlapping | ||||
|         if (overlapFactor > 1) { | ||||
|             int numOverlappedSamples = numSamples * overlapFactor; | ||||
|             int backSamples = fftSampleSize * (overlapFactor - 1) / overlapFactor; | ||||
|             short[] overlapAmp = new short[numOverlappedSamples]; | ||||
|             pointer = 0; | ||||
|             for (int i = 0; i < amplitudes.length; i++) { | ||||
|                 overlapAmp[pointer++] = amplitudes[i]; | ||||
|                 if (pointer % fftSampleSize == 0) { | ||||
|                     // overlap | ||||
|                     i -= backSamples; | ||||
|                 } | ||||
|             } | ||||
|             numSamples = numOverlappedSamples; | ||||
|             amplitudes = overlapAmp; | ||||
|         } | ||||
|         // end overlapping | ||||
| 
 | ||||
|         numFrames = numSamples / fftSampleSize; | ||||
|         framesPerSecond = (int) (numFrames / wave.length()); | ||||
| 
 | ||||
|         // set signals for fft | ||||
|         WindowFunction window = new WindowFunction(); | ||||
|         window.setWindowType("Hamming"); | ||||
|         double[] win = window.generate(fftSampleSize); | ||||
| 
 | ||||
|         double[][] signals = new double[numFrames][]; | ||||
|         for (int f = 0; f < numFrames; f++) { | ||||
|             signals[f] = new double[fftSampleSize]; | ||||
|             int startSample = f * fftSampleSize; | ||||
|             for (int n = 0; n < fftSampleSize; n++) { | ||||
|                 signals[f][n] = amplitudes[startSample + n] * win[n]; | ||||
|             } | ||||
|         } | ||||
|         // end set signals for fft | ||||
| 
 | ||||
|         absoluteSpectrogram = new double[numFrames][]; | ||||
|         // for each frame in signals, do fft on it | ||||
|         FastFourierTransform fft = new FastFourierTransform(); | ||||
|         for (int i = 0; i < numFrames; i++) { | ||||
|             absoluteSpectrogram[i] = fft.getMagnitudes(signals[i], false); | ||||
|         } | ||||
| 
 | ||||
|         if (absoluteSpectrogram.length > 0) { | ||||
| 
 | ||||
|             numFrequencyUnit = absoluteSpectrogram[0].length; | ||||
|             unitFrequency = (double) wave.getWaveHeader().getSampleRate() / 2 / numFrequencyUnit; // frequency could be caught within the half of nSamples according to Nyquist theory | ||||
| 
 | ||||
|             // normalization of absoultSpectrogram | ||||
|             spectrogram = new double[numFrames][numFrequencyUnit]; | ||||
| 
 | ||||
|             // set max and min amplitudes | ||||
|             double maxAmp = Double.MIN_VALUE; | ||||
|             double minAmp = Double.MAX_VALUE; | ||||
|             for (int i = 0; i < numFrames; i++) { | ||||
|                 for (int j = 0; j < numFrequencyUnit; j++) { | ||||
|                     if (absoluteSpectrogram[i][j] > maxAmp) { | ||||
|                         maxAmp = absoluteSpectrogram[i][j]; | ||||
|                     } else if (absoluteSpectrogram[i][j] < minAmp) { | ||||
|                         minAmp = absoluteSpectrogram[i][j]; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             // end set max and min amplitudes | ||||
| 
 | ||||
|             // normalization | ||||
|             // avoiding divided by zero  | ||||
|             double minValidAmp = 0.00000000001F; | ||||
|             if (minAmp == 0) { | ||||
|                 minAmp = minValidAmp; | ||||
|             } | ||||
| 
 | ||||
|             double diff = Math.log10(maxAmp / minAmp); // perceptual difference | ||||
|             for (int i = 0; i < numFrames; i++) { | ||||
|                 for (int j = 0; j < numFrequencyUnit; j++) { | ||||
|                     if (absoluteSpectrogram[i][j] < minValidAmp) { | ||||
|                         spectrogram[i][j] = 0; | ||||
|                     } else { | ||||
|                         spectrogram[i][j] = (Math.log10(absoluteSpectrogram[i][j] / minAmp)) / diff; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             // end normalization | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get spectrogram: spectrogram[time][frequency]=intensity | ||||
|      * | ||||
|      * @return	logarithm normalized spectrogram | ||||
|      */ | ||||
|     public double[][] getNormalizedSpectrogramData() { | ||||
|         return spectrogram; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get spectrogram: spectrogram[time][frequency]=intensity | ||||
|      * | ||||
|      * @return	absolute spectrogram | ||||
|      */ | ||||
|     public double[][] getAbsoluteSpectrogramData() { | ||||
|         return absoluteSpectrogram; | ||||
|     } | ||||
| 
 | ||||
|     public int getNumFrames() { | ||||
|         return numFrames; | ||||
|     } | ||||
| 
 | ||||
|     public int getFramesPerSecond() { | ||||
|         return framesPerSecond; | ||||
|     } | ||||
| 
 | ||||
|     public int getNumFrequencyUnit() { | ||||
|         return numFrequencyUnit; | ||||
|     } | ||||
| 
 | ||||
|     public double getUnitFrequency() { | ||||
|         return unitFrequency; | ||||
|     } | ||||
| 
 | ||||
|     public int getFftSampleSize() { | ||||
|         return fftSampleSize; | ||||
|     } | ||||
| 
 | ||||
|     public int getOverlapFactor() { | ||||
|         return overlapFactor; | ||||
|     } | ||||
| } | ||||
| @ -1,272 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.fingerprint; | ||||
| 
 | ||||
| 
 | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.datavec.audio.Wave; | ||||
| import org.datavec.audio.WaveHeader; | ||||
| import org.datavec.audio.dsp.Resampler; | ||||
| import org.datavec.audio.extension.Spectrogram; | ||||
| import org.datavec.audio.processor.TopManyPointsProcessorChain; | ||||
| import org.datavec.audio.properties.FingerprintProperties; | ||||
| 
 | ||||
| import java.io.FileInputStream; | ||||
| import java.io.FileOutputStream; | ||||
| import java.io.IOException; | ||||
| import java.io.InputStream; | ||||
| import java.util.Iterator; | ||||
| import java.util.LinkedList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class FingerprintManager { | ||||
| 
 | ||||
|     private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance(); | ||||
|     private int sampleSizePerFrame = fingerprintProperties.getSampleSizePerFrame(); | ||||
|     private int overlapFactor = fingerprintProperties.getOverlapFactor(); | ||||
|     private int numRobustPointsPerFrame = fingerprintProperties.getNumRobustPointsPerFrame(); | ||||
|     private int numFilterBanks = fingerprintProperties.getNumFilterBanks(); | ||||
| 
 | ||||
|     /** | ||||
|      * Constructor | ||||
|      */ | ||||
|     public FingerprintManager() { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Extract fingerprint from Wave object | ||||
|      *  | ||||
|      * @param wave	Wave Object to be extracted fingerprint | ||||
|      * @return fingerprint in bytes | ||||
|      */ | ||||
|     public byte[] extractFingerprint(Wave wave) { | ||||
| 
 | ||||
|         int[][] coordinates; // coordinates[x][0..3]=y0..y3 | ||||
|         byte[] fingerprint = new byte[0]; | ||||
| 
 | ||||
|         // resample to target rate | ||||
|         Resampler resampler = new Resampler(); | ||||
|         int sourceRate = wave.getWaveHeader().getSampleRate(); | ||||
|         int targetRate = fingerprintProperties.getSampleRate(); | ||||
| 
 | ||||
|         byte[] resampledWaveData = resampler.reSample(wave.getBytes(), wave.getWaveHeader().getBitsPerSample(), | ||||
|                         sourceRate, targetRate); | ||||
| 
 | ||||
|         // update the wave header | ||||
|         WaveHeader resampledWaveHeader = wave.getWaveHeader(); | ||||
|         resampledWaveHeader.setSampleRate(targetRate); | ||||
| 
 | ||||
|         // make resampled wave | ||||
|         Wave resampledWave = new Wave(resampledWaveHeader, resampledWaveData); | ||||
|         // end resample to target rate | ||||
| 
 | ||||
|         // get spectrogram's data | ||||
|         Spectrogram spectrogram = resampledWave.getSpectrogram(sampleSizePerFrame, overlapFactor); | ||||
|         double[][] spectorgramData = spectrogram.getNormalizedSpectrogramData(); | ||||
| 
 | ||||
|         List<Integer>[] pointsLists = getRobustPointList(spectorgramData); | ||||
|         int numFrames = pointsLists.length; | ||||
| 
 | ||||
|         // prepare fingerprint bytes | ||||
|         coordinates = new int[numFrames][numRobustPointsPerFrame]; | ||||
| 
 | ||||
|         for (int x = 0; x < numFrames; x++) { | ||||
|             if (pointsLists[x].size() == numRobustPointsPerFrame) { | ||||
|                 Iterator<Integer> pointsListsIterator = pointsLists[x].iterator(); | ||||
|                 for (int y = 0; y < numRobustPointsPerFrame; y++) { | ||||
|                     coordinates[x][y] = pointsListsIterator.next(); | ||||
|                 } | ||||
|             } else { | ||||
|                 // use -1 to fill the empty byte | ||||
|                 for (int y = 0; y < numRobustPointsPerFrame; y++) { | ||||
|                     coordinates[x][y] = -1; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         // end make fingerprint | ||||
| 
 | ||||
|         // for each valid coordinate, append with its intensity | ||||
|         List<Byte> byteList = new LinkedList<Byte>(); | ||||
|         for (int i = 0; i < numFrames; i++) { | ||||
|             for (int j = 0; j < numRobustPointsPerFrame; j++) { | ||||
|                 if (coordinates[i][j] != -1) { | ||||
|                     // first 2 bytes is x | ||||
|                     byteList.add((byte) (i >> 8)); | ||||
|                     byteList.add((byte) i); | ||||
| 
 | ||||
|                     // next 2 bytes is y | ||||
|                     int y = coordinates[i][j]; | ||||
|                     byteList.add((byte) (y >> 8)); | ||||
|                     byteList.add((byte) y); | ||||
| 
 | ||||
|                     // next 4 bytes is intensity | ||||
|                     int intensity = (int) (spectorgramData[i][y] * Integer.MAX_VALUE); // spectorgramData is ranged from 0~1 | ||||
|                     byteList.add((byte) (intensity >> 24)); | ||||
|                     byteList.add((byte) (intensity >> 16)); | ||||
|                     byteList.add((byte) (intensity >> 8)); | ||||
|                     byteList.add((byte) intensity); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         // end for each valid coordinate, append with its intensity | ||||
| 
 | ||||
|         fingerprint = new byte[byteList.size()]; | ||||
|         Iterator<Byte> byteListIterator = byteList.iterator(); | ||||
|         int pointer = 0; | ||||
|         while (byteListIterator.hasNext()) { | ||||
|             fingerprint[pointer++] = byteListIterator.next(); | ||||
|         } | ||||
| 
 | ||||
|         return fingerprint; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get bytes from fingerprint file | ||||
|      *  | ||||
|      * @param fingerprintFile	fingerprint filename | ||||
|      * @return fingerprint in bytes | ||||
|      */ | ||||
|     public byte[] getFingerprintFromFile(String fingerprintFile) { | ||||
|         byte[] fingerprint = null; | ||||
|         try { | ||||
|             InputStream fis = new FileInputStream(fingerprintFile); | ||||
|             fingerprint = getFingerprintFromInputStream(fis); | ||||
|             fis.close(); | ||||
|         } catch (IOException e) { | ||||
|             log.error("",e); | ||||
|         } | ||||
|         return fingerprint; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get bytes from fingerprint inputstream | ||||
|      *  | ||||
|      * @param inputStream	fingerprint inputstream | ||||
|      * @return fingerprint in bytes | ||||
|      */ | ||||
|     public byte[] getFingerprintFromInputStream(InputStream inputStream) { | ||||
|         byte[] fingerprint = null; | ||||
|         try { | ||||
|             fingerprint = new byte[inputStream.available()]; | ||||
|             inputStream.read(fingerprint); | ||||
|         } catch (IOException e) { | ||||
|             log.error("",e); | ||||
|         } | ||||
|         return fingerprint; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Save fingerprint to a file | ||||
|      *  | ||||
|      * @param fingerprint	fingerprint bytes | ||||
|      * @param filename		fingerprint filename | ||||
|      * @see	FingerprintManager file saved | ||||
|      */ | ||||
|     public void saveFingerprintAsFile(byte[] fingerprint, String filename) { | ||||
| 
 | ||||
|         FileOutputStream fileOutputStream; | ||||
|         try { | ||||
|             fileOutputStream = new FileOutputStream(filename); | ||||
|             fileOutputStream.write(fingerprint); | ||||
|             fileOutputStream.close(); | ||||
|         } catch (IOException e) { | ||||
|             log.error("",e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     // robustLists[x]=y1,y2,y3,... | ||||
|     private List<Integer>[] getRobustPointList(double[][] spectrogramData) { | ||||
| 
 | ||||
|         int numX = spectrogramData.length; | ||||
|         int numY = spectrogramData[0].length; | ||||
| 
 | ||||
|         double[][] allBanksIntensities = new double[numX][numY]; | ||||
|         int bandwidthPerBank = numY / numFilterBanks; | ||||
| 
 | ||||
|         for (int b = 0; b < numFilterBanks; b++) { | ||||
| 
 | ||||
|             double[][] bankIntensities = new double[numX][bandwidthPerBank]; | ||||
| 
 | ||||
|             for (int i = 0; i < numX; i++) { | ||||
|                 System.arraycopy(spectrogramData[i], b * bandwidthPerBank, bankIntensities[i], 0, bandwidthPerBank); | ||||
|             } | ||||
| 
 | ||||
|             // get the most robust point in each filter bank | ||||
|             TopManyPointsProcessorChain processorChain = new TopManyPointsProcessorChain(bankIntensities, 1); | ||||
|             double[][] processedIntensities = processorChain.getIntensities(); | ||||
| 
 | ||||
|             for (int i = 0; i < numX; i++) { | ||||
|                 System.arraycopy(processedIntensities[i], 0, allBanksIntensities[i], b * bandwidthPerBank, | ||||
|                                 bandwidthPerBank); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         List<int[]> robustPointList = new LinkedList<int[]>(); | ||||
| 
 | ||||
|         // find robust points | ||||
|         for (int i = 0; i < allBanksIntensities.length; i++) { | ||||
|             for (int j = 0; j < allBanksIntensities[i].length; j++) { | ||||
|                 if (allBanksIntensities[i][j] > 0) { | ||||
| 
 | ||||
|                     int[] point = new int[] {i, j}; | ||||
|                     //System.out.println(i+","+frequency); | ||||
|                     robustPointList.add(point); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         // end find robust points | ||||
| 
 | ||||
|         List<Integer>[] robustLists = new LinkedList[spectrogramData.length]; | ||||
|         for (int i = 0; i < robustLists.length; i++) { | ||||
|             robustLists[i] = new LinkedList<>(); | ||||
|         } | ||||
| 
 | ||||
|         // robustLists[x]=y1,y2,y3,... | ||||
|         for (int[] coor : robustPointList) { | ||||
|             robustLists[coor[0]].add(coor[1]); | ||||
|         } | ||||
| 
 | ||||
|         // return the list per frame | ||||
|         return robustLists; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Number of frames in a fingerprint | ||||
|      * Each frame lengths 8 bytes | ||||
|      * Usually there is more than one point in each frame, so it cannot simply divide the bytes length by 8 | ||||
|      * Last 8 byte of thisFingerprint is the last frame of this wave | ||||
|      * First 2 byte of the last 8 byte is the x position of this wave, i.e. (number_of_frames-1) of this wave	  | ||||
|      *  | ||||
|      * @param fingerprint	fingerprint bytes | ||||
|      * @return number of frames of the fingerprint | ||||
|      */ | ||||
|     public static int getNumFrames(byte[] fingerprint) { | ||||
| 
 | ||||
|         if (fingerprint.length < 8) { | ||||
|             return 0; | ||||
|         } | ||||
| 
 | ||||
|         // get the last x-coordinate (length-8&length-7)bytes from fingerprint | ||||
|         return ((fingerprint[fingerprint.length - 8] & 0xff) << 8 | (fingerprint[fingerprint.length - 7] & 0xff)) + 1; | ||||
|     } | ||||
| } | ||||
| @ -1,107 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.fingerprint; | ||||
| 
 | ||||
| 
 | ||||
| import org.datavec.audio.properties.FingerprintProperties; | ||||
| 
 | ||||
| public class FingerprintSimilarity { | ||||
| 
 | ||||
|     private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance(); | ||||
|     private int mostSimilarFramePosition; | ||||
|     private float score; | ||||
|     private float similarity; | ||||
| 
 | ||||
|     /** | ||||
|      * Constructor | ||||
|      */ | ||||
|     public FingerprintSimilarity() { | ||||
|         mostSimilarFramePosition = Integer.MIN_VALUE; | ||||
|         score = -1; | ||||
|         similarity = -1; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the most similar position in terms of frame number | ||||
|      *  | ||||
|      * @return most similar frame position | ||||
|      */ | ||||
|     public int getMostSimilarFramePosition() { | ||||
|         return mostSimilarFramePosition; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Set the most similar position in terms of frame number | ||||
|      *  | ||||
|      * @param mostSimilarFramePosition | ||||
|      */ | ||||
|     public void setMostSimilarFramePosition(int mostSimilarFramePosition) { | ||||
|         this.mostSimilarFramePosition = mostSimilarFramePosition; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the similarity of the fingerprints | ||||
|      * similarity from 0~1, which 0 means no similar feature is found and 1 means in average there is at least one match in every frame | ||||
|      *  | ||||
|      * @return fingerprints similarity | ||||
|      */ | ||||
|     public float getSimilarity() { | ||||
|         return similarity; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Set the similarity of the fingerprints | ||||
|      *  | ||||
|      * @param similarity similarity | ||||
|      */ | ||||
|     public void setSimilarity(float similarity) { | ||||
|         this.similarity = similarity; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the similarity score of the fingerprints | ||||
|      * Number of features found in the fingerprints per frame  | ||||
|      *  | ||||
|      * @return fingerprints similarity score | ||||
|      */ | ||||
|     public float getScore() { | ||||
|         return score; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Set the similarity score of the fingerprints | ||||
|      *  | ||||
|      * @param score | ||||
|      */ | ||||
|     public void setScore(float score) { | ||||
|         this.score = score; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the most similar position in terms of time in second | ||||
|      *  | ||||
|      * @return most similar starting time | ||||
|      */ | ||||
|     public float getsetMostSimilarTimePosition() { | ||||
|         return (float) mostSimilarFramePosition / fingerprintProperties.getNumRobustPointsPerFrame() | ||||
|                         / fingerprintProperties.getFps(); | ||||
|     } | ||||
| } | ||||
| @ -1,134 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.fingerprint; | ||||
| 
 | ||||
| import java.util.HashMap; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class FingerprintSimilarityComputer { | ||||
| 
 | ||||
|     private FingerprintSimilarity fingerprintSimilarity; | ||||
|     byte[] fingerprint1, fingerprint2; | ||||
| 
 | ||||
|     /** | ||||
|      * Constructor, ready to compute the similarity of two fingerprints | ||||
|      *  | ||||
|      * @param fingerprint1 | ||||
|      * @param fingerprint2 | ||||
|      */ | ||||
|     public FingerprintSimilarityComputer(byte[] fingerprint1, byte[] fingerprint2) { | ||||
| 
 | ||||
|         this.fingerprint1 = fingerprint1; | ||||
|         this.fingerprint2 = fingerprint2; | ||||
| 
 | ||||
|         fingerprintSimilarity = new FingerprintSimilarity(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get fingerprint similarity of inout fingerprints | ||||
|      *  | ||||
|      * @return fingerprint similarity object | ||||
|      */ | ||||
|     public FingerprintSimilarity getFingerprintsSimilarity() { | ||||
|         HashMap<Integer, Integer> offset_Score_Table = new HashMap<>(); // offset_Score_Table<offset,count> | ||||
|         int numFrames; | ||||
|         float score = 0; | ||||
|         int mostSimilarFramePosition = Integer.MIN_VALUE; | ||||
| 
 | ||||
|         // one frame may contain several points, use the shorter one be the denominator | ||||
|         if (fingerprint1.length > fingerprint2.length) { | ||||
|             numFrames = FingerprintManager.getNumFrames(fingerprint2); | ||||
|         } else { | ||||
|             numFrames = FingerprintManager.getNumFrames(fingerprint1); | ||||
|         } | ||||
| 
 | ||||
|         // get the pairs | ||||
|         PairManager pairManager = new PairManager(); | ||||
|         HashMap<Integer, List<Integer>> this_Pair_PositionList_Table = | ||||
|                         pairManager.getPair_PositionList_Table(fingerprint1); | ||||
|         HashMap<Integer, List<Integer>> compareWave_Pair_PositionList_Table = | ||||
|                         pairManager.getPair_PositionList_Table(fingerprint2); | ||||
| 
 | ||||
|         for (Integer compareWaveHashNumber : compareWave_Pair_PositionList_Table.keySet()) { | ||||
|             // if the compareWaveHashNumber doesn't exist in both tables, no need to compare | ||||
|             if (!this_Pair_PositionList_Table.containsKey(compareWaveHashNumber) | ||||
|                             || !compareWave_Pair_PositionList_Table.containsKey(compareWaveHashNumber)) { | ||||
|                 continue; | ||||
|             } | ||||
| 
 | ||||
|             // for each compare hash number, get the positions | ||||
|             List<Integer> wavePositionList = this_Pair_PositionList_Table.get(compareWaveHashNumber); | ||||
|             List<Integer> compareWavePositionList = compareWave_Pair_PositionList_Table.get(compareWaveHashNumber); | ||||
| 
 | ||||
|             for (Integer thisPosition : wavePositionList) { | ||||
|                 for (Integer compareWavePosition : compareWavePositionList) { | ||||
|                     int offset = thisPosition - compareWavePosition; | ||||
|                     if (offset_Score_Table.containsKey(offset)) { | ||||
|                         offset_Score_Table.put(offset, offset_Score_Table.get(offset) + 1); | ||||
|                     } else { | ||||
|                         offset_Score_Table.put(offset, 1); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // map rank | ||||
|         MapRank mapRank = new MapRankInteger(offset_Score_Table, false); | ||||
| 
 | ||||
|         // get the most similar positions and scores	 | ||||
|         List<Integer> orderedKeyList = mapRank.getOrderedKeyList(100, true); | ||||
|         if (orderedKeyList.size() > 0) { | ||||
|             int key = orderedKeyList.get(0); | ||||
|             // get the highest score position | ||||
|             mostSimilarFramePosition = key; | ||||
|             score = offset_Score_Table.get(key); | ||||
| 
 | ||||
|             // accumulate the scores from neighbours | ||||
|             if (offset_Score_Table.containsKey(key - 1)) { | ||||
|                 score += offset_Score_Table.get(key - 1) / 2; | ||||
|             } | ||||
|             if (offset_Score_Table.containsKey(key + 1)) { | ||||
|                 score += offset_Score_Table.get(key + 1) / 2; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         /* | ||||
|         Iterator<Integer> orderedKeyListIterator=orderedKeyList.iterator(); | ||||
|         while (orderedKeyListIterator.hasNext()){ | ||||
|         	int offset=orderedKeyListIterator.next();					 | ||||
|         	System.out.println(offset+": "+offset_Score_Table.get(offset)); | ||||
|         } | ||||
|         */ | ||||
| 
 | ||||
|         score /= numFrames; | ||||
|         float similarity = score; | ||||
|         // similarity >1 means in average there is at least one match in every frame | ||||
|         if (similarity > 1) { | ||||
|             similarity = 1; | ||||
|         } | ||||
| 
 | ||||
|         fingerprintSimilarity.setMostSimilarFramePosition(mostSimilarFramePosition); | ||||
|         fingerprintSimilarity.setScore(score); | ||||
|         fingerprintSimilarity.setSimilarity(similarity); | ||||
| 
 | ||||
|         return fingerprintSimilarity; | ||||
|     } | ||||
| } | ||||
| @ -1,27 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.fingerprint; | ||||
| 
 | ||||
| import java.util.List; | ||||
| 
 | ||||
| public interface MapRank { | ||||
|     public List getOrderedKeyList(int numKeys, boolean sharpLimit); | ||||
| } | ||||
| @ -1,179 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.fingerprint; | ||||
| 
 | ||||
| import java.util.*; | ||||
| import java.util.Map.Entry; | ||||
| 
 | ||||
| public class MapRankDouble implements MapRank { | ||||
| 
 | ||||
|     private Map map; | ||||
|     private boolean acsending = true; | ||||
| 
 | ||||
|     public MapRankDouble(Map<?, Double> map, boolean acsending) { | ||||
|         this.map = map; | ||||
|         this.acsending = acsending; | ||||
|     } | ||||
| 
 | ||||
|     public List getOrderedKeyList(int numKeys, boolean sharpLimit) { // if sharp limited, will return sharp numKeys, otherwise will return until the values not equals the exact key's value | ||||
| 
 | ||||
|         Set mapEntrySet = map.entrySet(); | ||||
|         List keyList = new LinkedList(); | ||||
| 
 | ||||
|         // if the numKeys is larger than map size, limit it | ||||
|         if (numKeys > map.size()) { | ||||
|             numKeys = map.size(); | ||||
|         } | ||||
|         // end if the numKeys is larger than map size, limit it | ||||
| 
 | ||||
|         if (map.size() > 0) { | ||||
|             double[] array = new double[map.size()]; | ||||
|             int count = 0; | ||||
| 
 | ||||
|             // get the pass values | ||||
|             Iterator<Entry> mapIterator = mapEntrySet.iterator(); | ||||
|             while (mapIterator.hasNext()) { | ||||
|                 Entry entry = mapIterator.next(); | ||||
|                 array[count++] = (Double) entry.getValue(); | ||||
|             } | ||||
|             // end get the pass values | ||||
| 
 | ||||
|             int targetindex; | ||||
|             if (acsending) { | ||||
|                 targetindex = numKeys; | ||||
|             } else { | ||||
|                 targetindex = array.length - numKeys; | ||||
|             } | ||||
| 
 | ||||
|             double passValue = getOrderedValue(array, targetindex); // this value is the value of the numKey-th element | ||||
|             // get the passed keys and values | ||||
|             Map passedMap = new HashMap(); | ||||
|             List<Double> valueList = new LinkedList<Double>(); | ||||
|             mapIterator = mapEntrySet.iterator(); | ||||
| 
 | ||||
|             while (mapIterator.hasNext()) { | ||||
|                 Entry entry = mapIterator.next(); | ||||
|                 double value = (Double) entry.getValue(); | ||||
|                 if ((acsending && value <= passValue) || (!acsending && value >= passValue)) { | ||||
|                     passedMap.put(entry.getKey(), value); | ||||
|                     valueList.add(value); | ||||
|                 } | ||||
|             } | ||||
|             // end get the passed keys and values | ||||
| 
 | ||||
|             // sort the value list | ||||
|             Double[] listArr = new Double[valueList.size()]; | ||||
|             valueList.toArray(listArr); | ||||
|             Arrays.sort(listArr); | ||||
|             // end sort the value list | ||||
| 
 | ||||
|             // get the list of keys | ||||
|             int resultCount = 0; | ||||
|             int index; | ||||
|             if (acsending) { | ||||
|                 index = 0; | ||||
|             } else { | ||||
|                 index = listArr.length - 1; | ||||
|             } | ||||
| 
 | ||||
|             if (!sharpLimit) { | ||||
|                 numKeys = listArr.length; | ||||
|             } | ||||
| 
 | ||||
|             while (true) { | ||||
|                 double targetValue = (Double) listArr[index]; | ||||
|                 Iterator<Entry> passedMapIterator = passedMap.entrySet().iterator(); | ||||
|                 while (passedMapIterator.hasNext()) { | ||||
|                     Entry entry = passedMapIterator.next(); | ||||
|                     if ((Double) entry.getValue() == targetValue) { | ||||
|                         keyList.add(entry.getKey()); | ||||
|                         passedMapIterator.remove(); | ||||
|                         resultCount++; | ||||
|                         break; | ||||
|                     } | ||||
|                 } | ||||
| 
 | ||||
|                 if (acsending) { | ||||
|                     index++; | ||||
|                 } else { | ||||
|                     index--; | ||||
|                 } | ||||
| 
 | ||||
|                 if (resultCount >= numKeys) { | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|             // end get the list of keys | ||||
|         } | ||||
| 
 | ||||
|         return keyList; | ||||
|     } | ||||
| 
 | ||||
|     private double getOrderedValue(double[] array, int index) { | ||||
|         locate(array, 0, array.length - 1, index); | ||||
|         return array[index]; | ||||
|     } | ||||
| 
 | ||||
|     // sort the partitions by quick sort, and locate the target index | ||||
|     private void locate(double[] array, int left, int right, int index) { | ||||
| 
 | ||||
|         int mid = (left + right) / 2; | ||||
|         //System.out.println(left+" to "+right+" ("+mid+")"); | ||||
| 
 | ||||
|         if (right == left) { | ||||
|             //System.out.println("* "+array[targetIndex]); | ||||
|             //result=array[targetIndex]; | ||||
|             return; | ||||
|         } | ||||
| 
 | ||||
|         if (left < right) { | ||||
|             double s = array[mid]; | ||||
|             int i = left - 1; | ||||
|             int j = right + 1; | ||||
| 
 | ||||
|             while (true) { | ||||
|                 while (array[++i] < s); | ||||
|                 while (array[--j] > s); | ||||
|                 if (i >= j) | ||||
|                     break; | ||||
|                 swap(array, i, j); | ||||
|             } | ||||
| 
 | ||||
|             //System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right); | ||||
| 
 | ||||
|             if (i > index) { | ||||
|                 // the target index in the left partition | ||||
|                 //System.out.println("left partition"); | ||||
|                 locate(array, left, i - 1, index); | ||||
|             } else { | ||||
|                 // the target index in the right partition | ||||
|                 //System.out.println("right partition"); | ||||
|                 locate(array, j + 1, right, index); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     private void swap(double[] array, int i, int j) { | ||||
|         double t = array[i]; | ||||
|         array[i] = array[j]; | ||||
|         array[j] = t; | ||||
|     } | ||||
| } | ||||
| @ -1,179 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.fingerprint; | ||||
| 
 | ||||
| import java.util.*; | ||||
| import java.util.Map.Entry; | ||||
| 
 | ||||
| public class MapRankInteger implements MapRank { | ||||
| 
 | ||||
|     private Map map; | ||||
|     private boolean acsending = true; | ||||
| 
 | ||||
|     public MapRankInteger(Map<?, Integer> map, boolean acsending) { | ||||
|         this.map = map; | ||||
|         this.acsending = acsending; | ||||
|     } | ||||
| 
 | ||||
|     public List getOrderedKeyList(int numKeys, boolean sharpLimit) { // if sharp limited, will return sharp numKeys, otherwise will return until the values not equals the exact key's value | ||||
| 
 | ||||
|         Set mapEntrySet = map.entrySet(); | ||||
|         List keyList = new LinkedList(); | ||||
| 
 | ||||
|         // if the numKeys is larger than map size, limit it | ||||
|         if (numKeys > map.size()) { | ||||
|             numKeys = map.size(); | ||||
|         } | ||||
|         // end if the numKeys is larger than map size, limit it | ||||
| 
 | ||||
|         if (map.size() > 0) { | ||||
|             int[] array = new int[map.size()]; | ||||
|             int count = 0; | ||||
| 
 | ||||
|             // get the pass values | ||||
|             Iterator<Entry> mapIterator = mapEntrySet.iterator(); | ||||
|             while (mapIterator.hasNext()) { | ||||
|                 Entry entry = mapIterator.next(); | ||||
|                 array[count++] = (Integer) entry.getValue(); | ||||
|             } | ||||
|             // end get the pass values | ||||
| 
 | ||||
|             int targetindex; | ||||
|             if (acsending) { | ||||
|                 targetindex = numKeys; | ||||
|             } else { | ||||
|                 targetindex = array.length - numKeys; | ||||
|             } | ||||
| 
 | ||||
|             int passValue = getOrderedValue(array, targetindex); // this value is the value of the numKey-th element | ||||
|             // get the passed keys and values | ||||
|             Map passedMap = new HashMap(); | ||||
|             List<Integer> valueList = new LinkedList<Integer>(); | ||||
|             mapIterator = mapEntrySet.iterator(); | ||||
| 
 | ||||
|             while (mapIterator.hasNext()) { | ||||
|                 Entry entry = mapIterator.next(); | ||||
|                 int value = (Integer) entry.getValue(); | ||||
|                 if ((acsending && value <= passValue) || (!acsending && value >= passValue)) { | ||||
|                     passedMap.put(entry.getKey(), value); | ||||
|                     valueList.add(value); | ||||
|                 } | ||||
|             } | ||||
|             // end get the passed keys and values | ||||
| 
 | ||||
|             // sort the value list | ||||
|             Integer[] listArr = new Integer[valueList.size()]; | ||||
|             valueList.toArray(listArr); | ||||
|             Arrays.sort(listArr); | ||||
|             // end sort the value list | ||||
| 
 | ||||
|             // get the list of keys | ||||
|             int resultCount = 0; | ||||
|             int index; | ||||
|             if (acsending) { | ||||
|                 index = 0; | ||||
|             } else { | ||||
|                 index = listArr.length - 1; | ||||
|             } | ||||
| 
 | ||||
|             if (!sharpLimit) { | ||||
|                 numKeys = listArr.length; | ||||
|             } | ||||
| 
 | ||||
|             while (true) { | ||||
|                 int targetValue = (Integer) listArr[index]; | ||||
|                 Iterator<Entry> passedMapIterator = passedMap.entrySet().iterator(); | ||||
|                 while (passedMapIterator.hasNext()) { | ||||
|                     Entry entry = passedMapIterator.next(); | ||||
|                     if ((Integer) entry.getValue() == targetValue) { | ||||
|                         keyList.add(entry.getKey()); | ||||
|                         passedMapIterator.remove(); | ||||
|                         resultCount++; | ||||
|                         break; | ||||
|                     } | ||||
|                 } | ||||
| 
 | ||||
|                 if (acsending) { | ||||
|                     index++; | ||||
|                 } else { | ||||
|                     index--; | ||||
|                 } | ||||
| 
 | ||||
|                 if (resultCount >= numKeys) { | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|             // end get the list of keys | ||||
|         } | ||||
| 
 | ||||
|         return keyList; | ||||
|     } | ||||
| 
 | ||||
|     private int getOrderedValue(int[] array, int index) { | ||||
|         locate(array, 0, array.length - 1, index); | ||||
|         return array[index]; | ||||
|     } | ||||
| 
 | ||||
|     // sort the partitions by quick sort, and locate the target index | ||||
|     private void locate(int[] array, int left, int right, int index) { | ||||
| 
 | ||||
|         int mid = (left + right) / 2; | ||||
|         //System.out.println(left+" to "+right+" ("+mid+")"); | ||||
| 
 | ||||
|         if (right == left) { | ||||
|             //System.out.println("* "+array[targetIndex]); | ||||
|             //result=array[targetIndex]; | ||||
|             return; | ||||
|         } | ||||
| 
 | ||||
|         if (left < right) { | ||||
|             int s = array[mid]; | ||||
|             int i = left - 1; | ||||
|             int j = right + 1; | ||||
| 
 | ||||
|             while (true) { | ||||
|                 while (array[++i] < s); | ||||
|                 while (array[--j] > s); | ||||
|                 if (i >= j) | ||||
|                     break; | ||||
|                 swap(array, i, j); | ||||
|             } | ||||
| 
 | ||||
|             //System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right); | ||||
| 
 | ||||
|             if (i > index) { | ||||
|                 // the target index in the left partition | ||||
|                 //System.out.println("left partition"); | ||||
|                 locate(array, left, i - 1, index); | ||||
|             } else { | ||||
|                 // the target index in the right partition | ||||
|                 //System.out.println("right partition"); | ||||
|                 locate(array, j + 1, right, index); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     private void swap(int[] array, int i, int j) { | ||||
|         int t = array[i]; | ||||
|         array[i] = array[j]; | ||||
|         array[j] = t; | ||||
|     } | ||||
| } | ||||
| @ -1,230 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.fingerprint; | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| import org.datavec.audio.properties.FingerprintProperties; | ||||
| 
 | ||||
| import java.util.HashMap; | ||||
| import java.util.LinkedList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class PairManager { | ||||
| 
 | ||||
|     FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance(); | ||||
|     private int numFilterBanks = fingerprintProperties.getNumFilterBanks(); | ||||
|     private int bandwidthPerBank = fingerprintProperties.getNumFrequencyUnits() / numFilterBanks; | ||||
|     private int anchorPointsIntervalLength = fingerprintProperties.getAnchorPointsIntervalLength(); | ||||
|     private int numAnchorPointsPerInterval = fingerprintProperties.getNumAnchorPointsPerInterval(); | ||||
|     private int maxTargetZoneDistance = fingerprintProperties.getMaxTargetZoneDistance(); | ||||
|     private int numFrequencyUnits = fingerprintProperties.getNumFrequencyUnits(); | ||||
| 
 | ||||
|     private int maxPairs; | ||||
|     private boolean isReferencePairing; | ||||
|     private HashMap<Integer, Boolean> stopPairTable = new HashMap<>(); | ||||
| 
 | ||||
|     /** | ||||
|      * Constructor | ||||
|      */ | ||||
|     public PairManager() { | ||||
|         maxPairs = fingerprintProperties.getRefMaxActivePairs(); | ||||
|         isReferencePairing = true; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Constructor, number of pairs of robust points depends on the parameter isReferencePairing | ||||
|      * no. of pairs of reference and sample can be different due to environmental influence of source   | ||||
|      * @param isReferencePairing | ||||
|      */ | ||||
|     public PairManager(boolean isReferencePairing) { | ||||
|         if (isReferencePairing) { | ||||
|             maxPairs = fingerprintProperties.getRefMaxActivePairs(); | ||||
|         } else { | ||||
|             maxPairs = fingerprintProperties.getSampleMaxActivePairs(); | ||||
|         } | ||||
|         this.isReferencePairing = isReferencePairing; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get a pair-positionList table | ||||
|      * It's a hash map which the key is the hashed pair, and the value is list of positions | ||||
|      * That means the table stores the positions which have the same hashed pair | ||||
|      *  | ||||
|      * @param fingerprint	fingerprint bytes | ||||
|      * @return pair-positionList HashMap | ||||
|      */ | ||||
|     public HashMap<Integer, List<Integer>> getPair_PositionList_Table(byte[] fingerprint) { | ||||
| 
 | ||||
|         List<int[]> pairPositionList = getPairPositionList(fingerprint); | ||||
| 
 | ||||
|         // table to store pair:pos,pos,pos,...;pair2:pos,pos,pos,.... | ||||
|         HashMap<Integer, List<Integer>> pair_positionList_table = new HashMap<>(); | ||||
| 
 | ||||
|         // get all pair_positions from list, use a table to collect the data group by pair hashcode | ||||
|         for (int[] pair_position : pairPositionList) { | ||||
|             //System.out.println(pair_position[0]+","+pair_position[1]); | ||||
| 
 | ||||
|             // group by pair-hashcode, i.e.: <pair,List<position>> | ||||
|             if (pair_positionList_table.containsKey(pair_position[0])) { | ||||
|                 pair_positionList_table.get(pair_position[0]).add(pair_position[1]); | ||||
|             } else { | ||||
|                 List<Integer> positionList = new LinkedList<>(); | ||||
|                 positionList.add(pair_position[1]); | ||||
|                 pair_positionList_table.put(pair_position[0], positionList); | ||||
|             } | ||||
|             // end group by pair-hashcode, i.e.: <pair,List<position>> | ||||
|         } | ||||
|         // end get all pair_positions from list, use a table to collect the data group by pair hashcode | ||||
| 
 | ||||
|         return pair_positionList_table; | ||||
|     } | ||||
| 
 | ||||
|     // this return list contains: int[0]=pair_hashcode, int[1]=position | ||||
|     private List<int[]> getPairPositionList(byte[] fingerprint) { | ||||
| 
 | ||||
|         int numFrames = FingerprintManager.getNumFrames(fingerprint); | ||||
| 
 | ||||
|         // table for paired frames | ||||
|         byte[] pairedFrameTable = new byte[numFrames / anchorPointsIntervalLength + 1]; // each second has numAnchorPointsPerSecond pairs only | ||||
|         // end table for paired frames | ||||
| 
 | ||||
|         List<int[]> pairList = new LinkedList<>(); | ||||
|         List<int[]> sortedCoordinateList = getSortedCoordinateList(fingerprint); | ||||
| 
 | ||||
|         for (int[] anchorPoint : sortedCoordinateList) { | ||||
|             int anchorX = anchorPoint[0]; | ||||
|             int anchorY = anchorPoint[1]; | ||||
|             int numPairs = 0; | ||||
| 
 | ||||
|             for (int[] aSortedCoordinateList : sortedCoordinateList) { | ||||
| 
 | ||||
|                 if (numPairs >= maxPairs) { | ||||
|                     break; | ||||
|                 } | ||||
| 
 | ||||
|                 if (isReferencePairing && pairedFrameTable[anchorX | ||||
|                                 / anchorPointsIntervalLength] >= numAnchorPointsPerInterval) { | ||||
|                     break; | ||||
|                 } | ||||
| 
 | ||||
|                 int targetX = aSortedCoordinateList[0]; | ||||
|                 int targetY = aSortedCoordinateList[1]; | ||||
| 
 | ||||
|                 if (anchorX == targetX && anchorY == targetY) { | ||||
|                     continue; | ||||
|                 } | ||||
| 
 | ||||
|                 // pair up the points | ||||
|                 int x1, y1, x2, y2; // x2 always >= x1 | ||||
|                 if (targetX >= anchorX) { | ||||
|                     x2 = targetX; | ||||
|                     y2 = targetY; | ||||
|                     x1 = anchorX; | ||||
|                     y1 = anchorY; | ||||
|                 } else { | ||||
|                     x2 = anchorX; | ||||
|                     y2 = anchorY; | ||||
|                     x1 = targetX; | ||||
|                     y1 = targetY; | ||||
|                 } | ||||
| 
 | ||||
|                 // check target zone | ||||
|                 if ((x2 - x1) > maxTargetZoneDistance) { | ||||
|                     continue; | ||||
|                 } | ||||
|                 // end check target zone | ||||
| 
 | ||||
|                 // check filter bank zone | ||||
|                 if (!(y1 / bandwidthPerBank == y2 / bandwidthPerBank)) { | ||||
|                     continue; // same filter bank should have equal value | ||||
|                 } | ||||
|                 // end check filter bank zone | ||||
| 
 | ||||
|                 int pairHashcode = (x2 - x1) * numFrequencyUnits * numFrequencyUnits + y2 * numFrequencyUnits + y1; | ||||
| 
 | ||||
|                 // stop list applied on sample pairing only | ||||
|                 if (!isReferencePairing && stopPairTable.containsKey(pairHashcode)) { | ||||
|                     numPairs++; // no reservation | ||||
|                     continue; // escape this point only | ||||
|                 } | ||||
|                 // end stop list applied on sample pairing only | ||||
| 
 | ||||
|                 // pass all rules | ||||
|                 pairList.add(new int[] {pairHashcode, anchorX}); | ||||
|                 pairedFrameTable[anchorX / anchorPointsIntervalLength]++; | ||||
|                 numPairs++; | ||||
|                 // end pair up the points | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return pairList; | ||||
|     } | ||||
| 
 | ||||
|     private List<int[]> getSortedCoordinateList(byte[] fingerprint) { | ||||
|         // each point data is 8 bytes  | ||||
|         // first 2 bytes is x | ||||
|         // next 2 bytes is y | ||||
|         // next 4 bytes is intensity | ||||
| 
 | ||||
|         // get all intensities | ||||
|         int numCoordinates = fingerprint.length / 8; | ||||
|         int[] intensities = new int[numCoordinates]; | ||||
|         for (int i = 0; i < numCoordinates; i++) { | ||||
|             int pointer = i * 8 + 4; | ||||
|             int intensity = (fingerprint[pointer] & 0xff) << 24 | (fingerprint[pointer + 1] & 0xff) << 16 | ||||
|                             | (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff); | ||||
|             intensities[i] = intensity; | ||||
|         } | ||||
| 
 | ||||
|         QuickSortIndexPreserved quicksort = new QuickSortIndexPreserved(intensities); | ||||
|         int[] sortIndexes = quicksort.getSortIndexes(); | ||||
| 
 | ||||
|         List<int[]> sortedCoordinateList = new LinkedList<>(); | ||||
|         for (int i = sortIndexes.length - 1; i >= 0; i--) { | ||||
|             int pointer = sortIndexes[i] * 8; | ||||
|             int x = (fingerprint[pointer] & 0xff) << 8 | (fingerprint[pointer + 1] & 0xff); | ||||
|             int y = (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff); | ||||
|             sortedCoordinateList.add(new int[] {x, y}); | ||||
|         } | ||||
|         return sortedCoordinateList; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Convert hashed pair to bytes | ||||
|      *  | ||||
|      * @param pairHashcode hashed pair | ||||
|      * @return byte array | ||||
|      */ | ||||
|     public static byte[] pairHashcodeToBytes(int pairHashcode) { | ||||
|         return new byte[] {(byte) (pairHashcode >> 8), (byte) pairHashcode}; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Convert bytes to hased pair | ||||
|      *  | ||||
|      * @param pairBytes | ||||
|      * @return hashed pair | ||||
|      */ | ||||
|     public static int pairBytesToHashcode(byte[] pairBytes) { | ||||
|         return (pairBytes[0] & 0xFF) << 8 | (pairBytes[1] & 0xFF); | ||||
|     } | ||||
| } | ||||
| @ -1,25 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.fingerprint; | ||||
| 
 | ||||
| public abstract class QuickSort { | ||||
|     public abstract int[] getSortIndexes(); | ||||
| } | ||||
| @ -1,79 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.fingerprint; | ||||
| 
 | ||||
| public class QuickSortDouble extends QuickSort { | ||||
| 
 | ||||
|     private int[] indexes; | ||||
|     private double[] array; | ||||
| 
 | ||||
|     public QuickSortDouble(double[] array) { | ||||
|         this.array = array; | ||||
|         indexes = new int[array.length]; | ||||
|         for (int i = 0; i < indexes.length; i++) { | ||||
|             indexes[i] = i; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public int[] getSortIndexes() { | ||||
|         sort(); | ||||
|         return indexes; | ||||
|     } | ||||
| 
 | ||||
|     private void sort() { | ||||
|         quicksort(array, indexes, 0, indexes.length - 1); | ||||
|     } | ||||
| 
 | ||||
|     // quicksort a[left] to a[right] | ||||
|     private void quicksort(double[] a, int[] indexes, int left, int right) { | ||||
|         if (right <= left) | ||||
|             return; | ||||
|         int i = partition(a, indexes, left, right); | ||||
|         quicksort(a, indexes, left, i - 1); | ||||
|         quicksort(a, indexes, i + 1, right); | ||||
|     } | ||||
| 
 | ||||
|     // partition a[left] to a[right], assumes left < right | ||||
|     private int partition(double[] a, int[] indexes, int left, int right) { | ||||
|         int i = left - 1; | ||||
|         int j = right; | ||||
|         while (true) { | ||||
|             while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel | ||||
|             while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap | ||||
|                 if (j == left) | ||||
|                     break; // don't go out-of-bounds | ||||
|             } | ||||
|             if (i >= j) | ||||
|                 break; // check if pointers cross | ||||
|             swap(a, indexes, i, j); // swap two elements into place | ||||
|         } | ||||
|         swap(a, indexes, i, right); // swap with partition element | ||||
|         return i; | ||||
|     } | ||||
| 
 | ||||
|     // exchange a[i] and a[j] | ||||
|     private void swap(double[] a, int[] indexes, int i, int j) { | ||||
|         int swap = indexes[i]; | ||||
|         indexes[i] = indexes[j]; | ||||
|         indexes[j] = swap; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,43 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.fingerprint; | ||||
| 
 | ||||
| public class QuickSortIndexPreserved { | ||||
| 
 | ||||
|     private QuickSort quickSort; | ||||
| 
 | ||||
|     public QuickSortIndexPreserved(int[] array) { | ||||
|         quickSort = new QuickSortInteger(array); | ||||
|     } | ||||
| 
 | ||||
|     public QuickSortIndexPreserved(double[] array) { | ||||
|         quickSort = new QuickSortDouble(array); | ||||
|     } | ||||
| 
 | ||||
|     public QuickSortIndexPreserved(short[] array) { | ||||
|         quickSort = new QuickSortShort(array); | ||||
|     } | ||||
| 
 | ||||
|     public int[] getSortIndexes() { | ||||
|         return quickSort.getSortIndexes(); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,79 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.fingerprint; | ||||
| 
 | ||||
| public class QuickSortInteger extends QuickSort { | ||||
| 
 | ||||
|     private int[] indexes; | ||||
|     private int[] array; | ||||
| 
 | ||||
|     public QuickSortInteger(int[] array) { | ||||
|         this.array = array; | ||||
|         indexes = new int[array.length]; | ||||
|         for (int i = 0; i < indexes.length; i++) { | ||||
|             indexes[i] = i; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public int[] getSortIndexes() { | ||||
|         sort(); | ||||
|         return indexes; | ||||
|     } | ||||
| 
 | ||||
|     private void sort() { | ||||
|         quicksort(array, indexes, 0, indexes.length - 1); | ||||
|     } | ||||
| 
 | ||||
|     // quicksort a[left] to a[right] | ||||
|     private void quicksort(int[] a, int[] indexes, int left, int right) { | ||||
|         if (right <= left) | ||||
|             return; | ||||
|         int i = partition(a, indexes, left, right); | ||||
|         quicksort(a, indexes, left, i - 1); | ||||
|         quicksort(a, indexes, i + 1, right); | ||||
|     } | ||||
| 
 | ||||
|     // partition a[left] to a[right], assumes left < right | ||||
|     private int partition(int[] a, int[] indexes, int left, int right) { | ||||
|         int i = left - 1; | ||||
|         int j = right; | ||||
|         while (true) { | ||||
|             while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel | ||||
|             while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap | ||||
|                 if (j == left) | ||||
|                     break; // don't go out-of-bounds | ||||
|             } | ||||
|             if (i >= j) | ||||
|                 break; // check if pointers cross | ||||
|             swap(a, indexes, i, j); // swap two elements into place | ||||
|         } | ||||
|         swap(a, indexes, i, right); // swap with partition element | ||||
|         return i; | ||||
|     } | ||||
| 
 | ||||
|     // exchange a[i] and a[j] | ||||
|     private void swap(int[] a, int[] indexes, int i, int j) { | ||||
|         int swap = indexes[i]; | ||||
|         indexes[i] = indexes[j]; | ||||
|         indexes[j] = swap; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,79 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.fingerprint; | ||||
| 
 | ||||
| public class QuickSortShort extends QuickSort { | ||||
| 
 | ||||
|     private int[] indexes; | ||||
|     private short[] array; | ||||
| 
 | ||||
|     public QuickSortShort(short[] array) { | ||||
|         this.array = array; | ||||
|         indexes = new int[array.length]; | ||||
|         for (int i = 0; i < indexes.length; i++) { | ||||
|             indexes[i] = i; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public int[] getSortIndexes() { | ||||
|         sort(); | ||||
|         return indexes; | ||||
|     } | ||||
| 
 | ||||
|     private void sort() { | ||||
|         quicksort(array, indexes, 0, indexes.length - 1); | ||||
|     } | ||||
| 
 | ||||
|     // quicksort a[left] to a[right] | ||||
|     private void quicksort(short[] a, int[] indexes, int left, int right) { | ||||
|         if (right <= left) | ||||
|             return; | ||||
|         int i = partition(a, indexes, left, right); | ||||
|         quicksort(a, indexes, left, i - 1); | ||||
|         quicksort(a, indexes, i + 1, right); | ||||
|     } | ||||
| 
 | ||||
|     // partition a[left] to a[right], assumes left < right | ||||
|     private int partition(short[] a, int[] indexes, int left, int right) { | ||||
|         int i = left - 1; | ||||
|         int j = right; | ||||
|         while (true) { | ||||
|             while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel | ||||
|             while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap | ||||
|                 if (j == left) | ||||
|                     break; // don't go out-of-bounds | ||||
|             } | ||||
|             if (i >= j) | ||||
|                 break; // check if pointers cross | ||||
|             swap(a, indexes, i, j); // swap two elements into place | ||||
|         } | ||||
|         swap(a, indexes, i, right); // swap with partition element | ||||
|         return i; | ||||
|     } | ||||
| 
 | ||||
|     // exchange a[i] and a[j] | ||||
|     private void swap(short[] a, int[] indexes, int i, int j) { | ||||
|         int swap = indexes[i]; | ||||
|         indexes[i] = indexes[j]; | ||||
|         indexes[j] = swap; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,51 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.formats.input; | ||||
| 
 | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.formats.input.BaseInputFormat; | ||||
| import org.datavec.api.records.reader.RecordReader; | ||||
| import org.datavec.api.split.InputSplit; | ||||
| import org.datavec.audio.recordreader.WavFileRecordReader; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| 
 | ||||
| /** | ||||
|  * | ||||
|  * Wave file input format | ||||
|  * | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class WavInputFormat extends BaseInputFormat { | ||||
|     @Override | ||||
|     public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException { | ||||
|         return createReader(split); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public RecordReader createReader(InputSplit split) throws IOException, InterruptedException { | ||||
|         RecordReader waveRecordReader = new WavFileRecordReader(); | ||||
|         waveRecordReader.initialize(split); | ||||
|         return waveRecordReader; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,36 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.formats.output; | ||||
| 
 | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.exceptions.DataVecException; | ||||
| import org.datavec.api.formats.output.OutputFormat; | ||||
| import org.datavec.api.records.writer.RecordWriter; | ||||
| 
 | ||||
| /** | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class WaveOutputFormat implements OutputFormat { | ||||
|     @Override | ||||
|     public RecordWriter createWriter(Configuration conf) throws DataVecException { | ||||
|         return null; | ||||
|     } | ||||
| } | ||||
| @ -1,139 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.processor; | ||||
| 
 | ||||
| public class ArrayRankDouble { | ||||
| 
 | ||||
|     /** | ||||
|      * Get the index position of maximum value the given array  | ||||
|      * @param array an array | ||||
|      * @return	index of the max value in array | ||||
|      */ | ||||
|     public int getMaxValueIndex(double[] array) { | ||||
| 
 | ||||
|         int index = 0; | ||||
|         double max = Integer.MIN_VALUE; | ||||
| 
 | ||||
|         for (int i = 0; i < array.length; i++) { | ||||
|             if (array[i] > max) { | ||||
|                 max = array[i]; | ||||
|                 index = i; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return index; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the index position of minimum value in the given array  | ||||
|      * @param array an array | ||||
|      * @return	index of the min value in array | ||||
|      */ | ||||
|     public int getMinValueIndex(double[] array) { | ||||
| 
 | ||||
|         int index = 0; | ||||
|         double min = Integer.MAX_VALUE; | ||||
| 
 | ||||
|         for (int i = 0; i < array.length; i++) { | ||||
|             if (array[i] < min) { | ||||
|                 min = array[i]; | ||||
|                 index = i; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return index; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the n-th value in the array after sorted | ||||
|      * @param array an array | ||||
|      * @param n position in array | ||||
|      * @param ascending	is ascending order or not | ||||
|      * @return value at nth position of array | ||||
|      */ | ||||
|     public double getNthOrderedValue(double[] array, int n, boolean ascending) { | ||||
| 
 | ||||
|         if (n > array.length) { | ||||
|             n = array.length; | ||||
|         } | ||||
| 
 | ||||
|         int targetindex; | ||||
|         if (ascending) { | ||||
|             targetindex = n; | ||||
|         } else { | ||||
|             targetindex = array.length - n; | ||||
|         } | ||||
| 
 | ||||
|         // this value is the value of the numKey-th element | ||||
| 
 | ||||
|         return getOrderedValue(array, targetindex); | ||||
|     } | ||||
| 
 | ||||
|     private double getOrderedValue(double[] array, int index) { | ||||
|         locate(array, 0, array.length - 1, index); | ||||
|         return array[index]; | ||||
|     } | ||||
| 
 | ||||
|     // sort the partitions by quick sort, and locate the target index | ||||
|     private void locate(double[] array, int left, int right, int index) { | ||||
| 
 | ||||
|         int mid = (left + right) / 2; | ||||
|         // System.out.println(left+" to "+right+" ("+mid+")"); | ||||
| 
 | ||||
|         if (right == left) { | ||||
|             // System.out.println("* "+array[targetIndex]); | ||||
|             // result=array[targetIndex]; | ||||
|             return; | ||||
|         } | ||||
| 
 | ||||
|         if (left < right) { | ||||
|             double s = array[mid]; | ||||
|             int i = left - 1; | ||||
|             int j = right + 1; | ||||
| 
 | ||||
|             while (true) { | ||||
|                 while (array[++i] < s); | ||||
|                 while (array[--j] > s); | ||||
|                 if (i >= j) | ||||
|                     break; | ||||
|                 swap(array, i, j); | ||||
|             } | ||||
| 
 | ||||
|             // System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right); | ||||
| 
 | ||||
|             if (i > index) { | ||||
|                 // the target index in the left partition | ||||
|                 // System.out.println("left partition"); | ||||
|                 locate(array, left, i - 1, index); | ||||
|             } else { | ||||
|                 // the target index in the right partition | ||||
|                 // System.out.println("right partition"); | ||||
|                 locate(array, j + 1, right, index); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     private void swap(double[] array, int i, int j) { | ||||
|         double t = array[i]; | ||||
|         array[i] = array[j]; | ||||
|         array[j] = t; | ||||
|     } | ||||
| } | ||||
| @ -1,28 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.processor; | ||||
| 
 | ||||
| public interface IntensityProcessor { | ||||
| 
 | ||||
|     public void execute(); | ||||
| 
 | ||||
|     public double[][] getIntensities(); | ||||
| } | ||||
| @ -1,48 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.processor; | ||||
| 
 | ||||
| import java.util.LinkedList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class ProcessorChain { | ||||
| 
 | ||||
|     private double[][] intensities; | ||||
|     List<IntensityProcessor> processorList = new LinkedList<IntensityProcessor>(); | ||||
| 
 | ||||
|     public ProcessorChain(double[][] intensities) { | ||||
|         this.intensities = intensities; | ||||
|         RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, 1); | ||||
|         processorList.add(robustProcessor); | ||||
|         process(); | ||||
|     } | ||||
| 
 | ||||
|     private void process() { | ||||
|         for (IntensityProcessor processor : processorList) { | ||||
|             processor.execute(); | ||||
|             intensities = processor.getIntensities(); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public double[][] getIntensities() { | ||||
|         return intensities; | ||||
|     } | ||||
| } | ||||
| @ -1,61 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.processor; | ||||
| 
 | ||||
| 
 | ||||
| public class RobustIntensityProcessor implements IntensityProcessor { | ||||
| 
 | ||||
|     private double[][] intensities; | ||||
|     private int numPointsPerFrame; | ||||
| 
 | ||||
|     public RobustIntensityProcessor(double[][] intensities, int numPointsPerFrame) { | ||||
|         this.intensities = intensities; | ||||
|         this.numPointsPerFrame = numPointsPerFrame; | ||||
|     } | ||||
| 
 | ||||
|     public void execute() { | ||||
| 
 | ||||
|         int numX = intensities.length; | ||||
|         int numY = intensities[0].length; | ||||
|         double[][] processedIntensities = new double[numX][numY]; | ||||
| 
 | ||||
|         for (int i = 0; i < numX; i++) { | ||||
|             double[] tmpArray = new double[numY]; | ||||
|             System.arraycopy(intensities[i], 0, tmpArray, 0, numY); | ||||
| 
 | ||||
|             // pass value is the last some elements in sorted array	 | ||||
|             ArrayRankDouble arrayRankDouble = new ArrayRankDouble(); | ||||
|             double passValue = arrayRankDouble.getNthOrderedValue(tmpArray, numPointsPerFrame, false); | ||||
| 
 | ||||
|             // only passed elements will be assigned a value | ||||
|             for (int j = 0; j < numY; j++) { | ||||
|                 if (intensities[i][j] >= passValue) { | ||||
|                     processedIntensities[i][j] = intensities[i][j]; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         intensities = processedIntensities; | ||||
|     } | ||||
| 
 | ||||
|     public double[][] getIntensities() { | ||||
|         return intensities; | ||||
|     } | ||||
| } | ||||
| @ -1,49 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.processor; | ||||
| 
 | ||||
| import java.util.LinkedList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| 
 | ||||
| public class TopManyPointsProcessorChain { | ||||
| 
 | ||||
|     private double[][] intensities; | ||||
|     List<IntensityProcessor> processorList = new LinkedList<>(); | ||||
| 
 | ||||
|     public TopManyPointsProcessorChain(double[][] intensities, int numPoints) { | ||||
|         this.intensities = intensities; | ||||
|         RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, numPoints); | ||||
|         processorList.add(robustProcessor); | ||||
|         process(); | ||||
|     } | ||||
| 
 | ||||
|     private void process() { | ||||
|         for (IntensityProcessor processor : processorList) { | ||||
|             processor.execute(); | ||||
|             intensities = processor.getIntensities(); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public double[][] getIntensities() { | ||||
|         return intensities; | ||||
|     } | ||||
| } | ||||
| @ -1,121 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.properties; | ||||
| 
 | ||||
| public class FingerprintProperties { | ||||
| 
 | ||||
|     protected static FingerprintProperties instance = null; | ||||
| 
 | ||||
|     private int numRobustPointsPerFrame = 4; // number of points in each frame, i.e. top 4 intensities in fingerprint | ||||
|     private int sampleSizePerFrame = 2048; // number of audio samples in a frame, it is suggested to be the FFT Size | ||||
|     private int overlapFactor = 4; // 8 means each move 1/8 nSample length. 1 means no overlap, better 1,2,4,8 ...	32 | ||||
|     private int numFilterBanks = 4; | ||||
| 
 | ||||
|     private int upperBoundedFrequency = 1500; // low pass | ||||
|     private int lowerBoundedFrequency = 400; // high pass | ||||
|     private int fps = 5; // in order to have 5fps with 2048 sampleSizePerFrame, wave's sample rate need to be 10240 (sampleSizePerFrame*fps) | ||||
|     private int sampleRate = sampleSizePerFrame * fps; // the audio's sample rate needed to resample to this in order to fit the sampleSizePerFrame and fps | ||||
|     private int numFramesInOneSecond = overlapFactor * fps; // since the overlap factor affects the actual number of fps, so this value is used to evaluate how many frames in one second eventually   | ||||
| 
 | ||||
|     private int refMaxActivePairs = 1; // max. active pairs per anchor point for reference songs | ||||
|     private int sampleMaxActivePairs = 10; // max. active pairs per anchor point for sample clip | ||||
|     private int numAnchorPointsPerInterval = 10; | ||||
|     private int anchorPointsIntervalLength = 4; // in frames (5fps,4 overlap per second) | ||||
|     private int maxTargetZoneDistance = 4; // in frame (5fps,4 overlap per second) | ||||
| 
 | ||||
|     private int numFrequencyUnits = (upperBoundedFrequency - lowerBoundedFrequency + 1) / fps + 1; // num frequency units | ||||
| 
 | ||||
|     public static FingerprintProperties getInstance() { | ||||
|         if (instance == null) { | ||||
|             synchronized (FingerprintProperties.class) { | ||||
|                 if (instance == null) { | ||||
|                     instance = new FingerprintProperties(); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         return instance; | ||||
|     } | ||||
| 
 | ||||
|     public int getNumRobustPointsPerFrame() { | ||||
|         return numRobustPointsPerFrame; | ||||
|     } | ||||
| 
 | ||||
|     public int getSampleSizePerFrame() { | ||||
|         return sampleSizePerFrame; | ||||
|     } | ||||
| 
 | ||||
|     public int getOverlapFactor() { | ||||
|         return overlapFactor; | ||||
|     } | ||||
| 
 | ||||
|     public int getNumFilterBanks() { | ||||
|         return numFilterBanks; | ||||
|     } | ||||
| 
 | ||||
|     public int getUpperBoundedFrequency() { | ||||
|         return upperBoundedFrequency; | ||||
|     } | ||||
| 
 | ||||
|     public int getLowerBoundedFrequency() { | ||||
|         return lowerBoundedFrequency; | ||||
|     } | ||||
| 
 | ||||
|     public int getFps() { | ||||
|         return fps; | ||||
|     } | ||||
| 
 | ||||
|     public int getRefMaxActivePairs() { | ||||
|         return refMaxActivePairs; | ||||
|     } | ||||
| 
 | ||||
|     public int getSampleMaxActivePairs() { | ||||
|         return sampleMaxActivePairs; | ||||
|     } | ||||
| 
 | ||||
|     public int getNumAnchorPointsPerInterval() { | ||||
|         return numAnchorPointsPerInterval; | ||||
|     } | ||||
| 
 | ||||
|     public int getAnchorPointsIntervalLength() { | ||||
|         return anchorPointsIntervalLength; | ||||
|     } | ||||
| 
 | ||||
|     public int getMaxTargetZoneDistance() { | ||||
|         return maxTargetZoneDistance; | ||||
|     } | ||||
| 
 | ||||
|     public int getNumFrequencyUnits() { | ||||
|         return numFrequencyUnits; | ||||
|     } | ||||
| 
 | ||||
|     public int getMaxPossiblePairHashcode() { | ||||
|         return maxTargetZoneDistance * numFrequencyUnits * numFrequencyUnits + numFrequencyUnits * numFrequencyUnits | ||||
|                         + numFrequencyUnits; | ||||
|     } | ||||
| 
 | ||||
|     public int getSampleRate() { | ||||
|         return sampleRate; | ||||
|     } | ||||
| 
 | ||||
|     public int getNumFramesInOneSecond() { | ||||
|         return numFramesInOneSecond; | ||||
|     } | ||||
| } | ||||
| @ -1,225 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.recordreader; | ||||
| 
 | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.records.Record; | ||||
| import org.datavec.api.records.metadata.RecordMetaData; | ||||
| import org.datavec.api.records.reader.BaseRecordReader; | ||||
| import org.datavec.api.split.BaseInputSplit; | ||||
| import org.datavec.api.split.InputSplit; | ||||
| import org.datavec.api.split.InputStreamInputSplit; | ||||
| import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import java.io.DataInputStream; | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.io.InputStream; | ||||
| import java.net.URI; | ||||
| import java.nio.file.Path; | ||||
| import java.nio.file.Paths; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Collections; | ||||
| import java.util.Iterator; | ||||
| import java.util.List; | ||||
| 
 | ||||
| /** | ||||
|  * Base audio file loader | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public abstract class BaseAudioRecordReader extends BaseRecordReader { | ||||
|     private Iterator<File> iter; | ||||
|     private List<Writable> record; | ||||
|     private boolean hitImage = false; | ||||
|     private boolean appendLabel = false; | ||||
|     private List<String> labels = new ArrayList<>(); | ||||
|     private Configuration conf; | ||||
|     protected InputSplit inputSplit; | ||||
| 
 | ||||
|     public BaseAudioRecordReader() {} | ||||
| 
 | ||||
|     public BaseAudioRecordReader(boolean appendLabel, List<String> labels) { | ||||
|         this.appendLabel = appendLabel; | ||||
|         this.labels = labels; | ||||
|     } | ||||
| 
 | ||||
|     public BaseAudioRecordReader(List<String> labels) { | ||||
|         this.labels = labels; | ||||
|     } | ||||
| 
 | ||||
|     public BaseAudioRecordReader(boolean appendLabel) { | ||||
|         this.appendLabel = appendLabel; | ||||
|     } | ||||
| 
 | ||||
|     protected abstract List<Writable> loadData(File file, InputStream inputStream) throws IOException; | ||||
| 
 | ||||
|     @Override | ||||
|     public void initialize(InputSplit split) throws IOException, InterruptedException { | ||||
|         inputSplit = split; | ||||
|         if (split instanceof BaseInputSplit) { | ||||
|             URI[] locations = split.locations(); | ||||
|             if (locations != null && locations.length >= 1) { | ||||
|                 if (locations.length > 1) { | ||||
|                     List<File> allFiles = new ArrayList<>(); | ||||
|                     for (URI location : locations) { | ||||
|                         File iter = new File(location); | ||||
|                         if (iter.isDirectory()) { | ||||
|                             Iterator<File> allFiles2 = FileUtils.iterateFiles(iter, null, true); | ||||
|                             while (allFiles2.hasNext()) | ||||
|                                 allFiles.add(allFiles2.next()); | ||||
|                         } | ||||
| 
 | ||||
|                         else | ||||
|                             allFiles.add(iter); | ||||
|                     } | ||||
| 
 | ||||
|                     iter = allFiles.iterator(); | ||||
|                 } else { | ||||
|                     File curr = new File(locations[0]); | ||||
|                     if (curr.isDirectory()) | ||||
|                         iter = FileUtils.iterateFiles(curr, null, true); | ||||
|                     else | ||||
|                         iter = Collections.singletonList(curr).iterator(); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         else if (split instanceof InputStreamInputSplit) { | ||||
|             record = new ArrayList<>(); | ||||
|             InputStreamInputSplit split2 = (InputStreamInputSplit) split; | ||||
|             InputStream is = split2.getIs(); | ||||
|             URI[] locations = split2.locations(); | ||||
|             if (appendLabel) { | ||||
|                 Path path = Paths.get(locations[0]); | ||||
|                 String parent = path.getParent().toString(); | ||||
|                 record.add(new DoubleWritable(labels.indexOf(parent))); | ||||
|             } | ||||
| 
 | ||||
|             is.close(); | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException { | ||||
|         this.conf = conf; | ||||
|         this.appendLabel = conf.getBoolean(APPEND_LABEL, false); | ||||
|         this.labels = new ArrayList<>(conf.getStringCollection(LABELS)); | ||||
|         initialize(split); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<Writable> next() { | ||||
|         if (iter != null) { | ||||
|             File next = iter.next(); | ||||
|             invokeListeners(next); | ||||
|             try { | ||||
|                 return loadData(next, null); | ||||
|             } catch (Exception e) { | ||||
|                 throw new RuntimeException(e); | ||||
|             } | ||||
|         } else if (record != null) { | ||||
|             hitImage = true; | ||||
|             return record; | ||||
|         } | ||||
| 
 | ||||
|         throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist"); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public boolean hasNext() { | ||||
|         if (iter != null) { | ||||
|             return iter.hasNext(); | ||||
|         } else if (record != null) { | ||||
|             return !hitImage; | ||||
|         } | ||||
|         throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist"); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public void close() throws IOException { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setConf(Configuration conf) { | ||||
|         this.conf = conf; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Configuration getConf() { | ||||
|         return conf; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<String> getLabels() { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public void reset() { | ||||
|         if (inputSplit == null) | ||||
|             throw new UnsupportedOperationException("Cannot reset without first initializing"); | ||||
|         try { | ||||
|             initialize(inputSplit); | ||||
|         } catch (Exception e) { | ||||
|             throw new RuntimeException("Error during LineRecordReader reset", e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public boolean resetSupported(){ | ||||
|         if(inputSplit == null){ | ||||
|             return false; | ||||
|         } | ||||
|         return inputSplit.resetSupported(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException { | ||||
|         invokeListeners(uri); | ||||
|         try { | ||||
|             return loadData(null, dataInputStream); | ||||
|         } catch (Exception e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Record nextRecord() { | ||||
|         return new org.datavec.api.records.impl.Record(next(), null); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException { | ||||
|         throw new UnsupportedOperationException("Loading from metadata not yet implemented"); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException { | ||||
|         throw new UnsupportedOperationException("Loading from metadata not yet implemented"); | ||||
|     } | ||||
| } | ||||
| @ -1,76 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.recordreader; | ||||
| 
 | ||||
| import org.bytedeco.javacv.FFmpegFrameGrabber; | ||||
| import org.bytedeco.javacv.Frame; | ||||
| import org.datavec.api.writable.FloatWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.io.InputStream; | ||||
| import java.nio.FloatBuffer; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.bytedeco.ffmpeg.global.avutil.AV_SAMPLE_FMT_FLT; | ||||
| 
 | ||||
| /** | ||||
|  * Native audio file loader using FFmpeg. | ||||
|  * | ||||
|  * @author saudet | ||||
|  */ | ||||
| public class NativeAudioRecordReader extends BaseAudioRecordReader { | ||||
| 
 | ||||
|     public NativeAudioRecordReader() {} | ||||
| 
 | ||||
|     public NativeAudioRecordReader(boolean appendLabel, List<String> labels) { | ||||
|         super(appendLabel, labels); | ||||
|     } | ||||
| 
 | ||||
|     public NativeAudioRecordReader(List<String> labels) { | ||||
|         super(labels); | ||||
|     } | ||||
| 
 | ||||
|     public NativeAudioRecordReader(boolean appendLabel) { | ||||
|         super(appendLabel); | ||||
|     } | ||||
| 
 | ||||
|     protected List<Writable> loadData(File file, InputStream inputStream) throws IOException { | ||||
|         List<Writable> ret = new ArrayList<>(); | ||||
|         try (FFmpegFrameGrabber grabber = inputStream != null ? new FFmpegFrameGrabber(inputStream) | ||||
|                         : new FFmpegFrameGrabber(file.getAbsolutePath())) { | ||||
|             grabber.setSampleFormat(AV_SAMPLE_FMT_FLT); | ||||
|             grabber.start(); | ||||
|             Frame frame; | ||||
|             while ((frame = grabber.grab()) != null) { | ||||
|                 while (frame.samples != null && frame.samples[0].hasRemaining()) { | ||||
|                     for (int i = 0; i < frame.samples.length; i++) { | ||||
|                         ret.add(new FloatWritable(((FloatBuffer) frame.samples[i]).get())); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,57 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio.recordreader; | ||||
| 
 | ||||
| import org.datavec.api.util.RecordUtils; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.audio.Wave; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.io.InputStream; | ||||
| import java.util.List; | ||||
| 
 | ||||
| /** | ||||
|  * Wav file loader | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class WavFileRecordReader extends BaseAudioRecordReader { | ||||
| 
 | ||||
|     public WavFileRecordReader() {} | ||||
| 
 | ||||
|     public WavFileRecordReader(boolean appendLabel, List<String> labels) { | ||||
|         super(appendLabel, labels); | ||||
|     } | ||||
| 
 | ||||
|     public WavFileRecordReader(List<String> labels) { | ||||
|         super(labels); | ||||
|     } | ||||
| 
 | ||||
|     public WavFileRecordReader(boolean appendLabel) { | ||||
|         super(appendLabel); | ||||
|     } | ||||
| 
 | ||||
|     protected List<Writable> loadData(File file, InputStream inputStream) throws IOException { | ||||
|         Wave wave = inputStream != null ? new Wave(inputStream) : new Wave(file.getAbsolutePath()); | ||||
|         return RecordUtils.toRecord(wave.getNormalizedAmplitudes()); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,51 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| package org.datavec.audio; | ||||
| 
 | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.nd4j.common.tests.AbstractAssertTestsClass; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { | ||||
| 
 | ||||
| 	@Override | ||||
| 	public long getTimeoutMilliseconds() { | ||||
| 		return 60000; | ||||
| 	} | ||||
| 
 | ||||
|     @Override | ||||
|     protected Set<Class<?>> getExclusions() { | ||||
| 	    //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) | ||||
| 	    return new HashSet<>(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
| 	protected String getPackageName() { | ||||
|     	return "org.datavec.audio"; | ||||
| 	} | ||||
| 
 | ||||
| 	@Override | ||||
| 	protected Class<?> getBaseClass() { | ||||
|     	return BaseND4JTest.class; | ||||
| 	} | ||||
| } | ||||
| @ -1,68 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio; | ||||
| 
 | ||||
| import org.bytedeco.javacv.FFmpegFrameRecorder; | ||||
| import org.bytedeco.javacv.Frame; | ||||
| import org.datavec.api.records.reader.RecordReader; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.audio.recordreader.NativeAudioRecordReader; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.nio.ShortBuffer; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_VORBIS; | ||||
| import static org.junit.Assert.assertEquals; | ||||
| import static org.junit.Assert.assertTrue; | ||||
| 
 | ||||
| /** | ||||
|  * @author saudet | ||||
|  */ | ||||
| public class AudioReaderTest extends BaseND4JTest { | ||||
|     @Ignore | ||||
|     @Test | ||||
|     public void testNativeAudioReader() throws Exception { | ||||
|         File tempFile = File.createTempFile("testNativeAudioReader", ".ogg"); | ||||
|         FFmpegFrameRecorder recorder = new FFmpegFrameRecorder(tempFile, 2); | ||||
|         recorder.setAudioCodec(AV_CODEC_ID_VORBIS); | ||||
|         recorder.setSampleRate(44100); | ||||
|         recorder.start(); | ||||
|         Frame audioFrame = new Frame(); | ||||
|         ShortBuffer audioBuffer = ShortBuffer.allocate(64 * 1024); | ||||
|         audioFrame.sampleRate = 44100; | ||||
|         audioFrame.audioChannels = 2; | ||||
|         audioFrame.samples = new ShortBuffer[] {audioBuffer}; | ||||
|         recorder.record(audioFrame); | ||||
|         recorder.stop(); | ||||
|         recorder.release(); | ||||
| 
 | ||||
|         RecordReader reader = new NativeAudioRecordReader(); | ||||
|         reader.initialize(new FileSplit(tempFile)); | ||||
|         assertTrue(reader.hasNext()); | ||||
|         List<Writable> record = reader.next(); | ||||
|         assertEquals(audioBuffer.limit(), record.size()); | ||||
|     } | ||||
| } | ||||
| @ -1,69 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.audio; | ||||
| 
 | ||||
| import org.datavec.audio.dsp.FastFourierTransform; | ||||
| import org.junit.Assert; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| 
 | ||||
| public class TestFastFourierTransform extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testFastFourierTransformComplex() { | ||||
|         FastFourierTransform fft = new FastFourierTransform(); | ||||
|         double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6}; | ||||
|         double[] frequencies = fft.getMagnitudes(amplitudes); | ||||
| 
 | ||||
|         Assert.assertEquals(2, frequencies.length); | ||||
|         Assert.assertArrayEquals(new double[] {21.335, 18.513}, frequencies, 0.005); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testFastFourierTransformComplexLong() { | ||||
|         FastFourierTransform fft = new FastFourierTransform(); | ||||
|         double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6}; | ||||
|         double[] frequencies = fft.getMagnitudes(amplitudes, true); | ||||
| 
 | ||||
|         Assert.assertEquals(4, frequencies.length); | ||||
|         Assert.assertArrayEquals(new double[] {21.335, 18.5132, 14.927, 7.527}, frequencies, 0.005); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testFastFourierTransformReal() { | ||||
|         FastFourierTransform fft = new FastFourierTransform(); | ||||
|         double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6}; | ||||
|         double[] frequencies = fft.getMagnitudes(amplitudes, false); | ||||
| 
 | ||||
|         Assert.assertEquals(4, frequencies.length); | ||||
|         Assert.assertArrayEquals(new double[] {28.8, 2.107, 14.927, 19.874}, frequencies, 0.005); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testFastFourierTransformRealOddSize() { | ||||
|         FastFourierTransform fft = new FastFourierTransform(); | ||||
|         double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5}; | ||||
|         double[] frequencies = fft.getMagnitudes(amplitudes, false); | ||||
| 
 | ||||
|         Assert.assertEquals(3, frequencies.length); | ||||
|         Assert.assertArrayEquals(new double[] {24.2, 3.861, 16.876}, frequencies, 0.005); | ||||
|     } | ||||
| } | ||||
| @ -1,71 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.datavec</groupId> | ||||
|         <artifactId>datavec-data</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>datavec-data-codec</artifactId> | ||||
| 
 | ||||
|     <name>datavec-data-codec</name> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>org.datavec</groupId> | ||||
|             <artifactId>datavec-api</artifactId> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.datavec</groupId> | ||||
|             <artifactId>datavec-data-image</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.jcodec</groupId> | ||||
|             <artifactId>jcodec</artifactId> | ||||
|             <version>0.1.5</version> | ||||
|         </dependency> | ||||
|         <!-- Do not depend on FFmpeg by default due to licensing concerns. --> | ||||
|         <!-- | ||||
|                 <dependency> | ||||
|                     <groupId>org.bytedeco</groupId> | ||||
|                     <artifactId>ffmpeg-platform</artifactId> | ||||
|                     <version>${ffmpeg.version}-${javacpp-presets.version}</version> | ||||
|                 </dependency> | ||||
|         --> | ||||
|     </dependencies> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
| @ -1,41 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.codec.format.input; | ||||
| 
 | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.formats.input.BaseInputFormat; | ||||
| import org.datavec.api.records.reader.RecordReader; | ||||
| import org.datavec.api.split.InputSplit; | ||||
| import org.datavec.codec.reader.CodecRecordReader; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| 
 | ||||
| /** | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class CodecInputFormat extends BaseInputFormat { | ||||
|     @Override | ||||
|     public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException { | ||||
|         RecordReader reader = new CodecRecordReader(); | ||||
|         reader.initialize(conf, split); | ||||
|         return reader; | ||||
|     } | ||||
| } | ||||
| @ -1,144 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.codec.reader; | ||||
| 
 | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.records.SequenceRecord; | ||||
| import org.datavec.api.records.metadata.RecordMetaData; | ||||
| import org.datavec.api.records.metadata.RecordMetaDataURI; | ||||
| import org.datavec.api.records.reader.SequenceRecordReader; | ||||
| import org.datavec.api.records.reader.impl.FileRecordReader; | ||||
| import org.datavec.api.split.InputSplit; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import java.io.DataInputStream; | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.io.InputStream; | ||||
| import java.net.URI; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Collections; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public abstract class BaseCodecRecordReader extends FileRecordReader implements SequenceRecordReader { | ||||
|     protected int startFrame = 0; | ||||
|     protected int numFrames = -1; | ||||
|     protected int totalFrames = -1; | ||||
|     protected double framesPerSecond = -1; | ||||
|     protected double videoLength = -1; | ||||
|     protected int rows = 28, cols = 28; | ||||
|     protected boolean ravel = false; | ||||
| 
 | ||||
|     public final static String NAME_SPACE = "org.datavec.codec.reader"; | ||||
|     public final static String ROWS = NAME_SPACE + ".rows"; | ||||
|     public final static String COLUMNS = NAME_SPACE + ".columns"; | ||||
|     public final static String START_FRAME = NAME_SPACE + ".startframe"; | ||||
|     public final static String TOTAL_FRAMES = NAME_SPACE + ".frames"; | ||||
|     public final static String TIME_SLICE = NAME_SPACE + ".time"; | ||||
|     public final static String RAVEL = NAME_SPACE + ".ravel"; | ||||
|     public final static String VIDEO_DURATION = NAME_SPACE + ".duration"; | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public List<List<Writable>> sequenceRecord() { | ||||
|         URI next = locationsIterator.next(); | ||||
| 
 | ||||
|         try (InputStream s = streamCreatorFn.apply(next)){ | ||||
|             return loadData(null, s); | ||||
|         } catch (IOException e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<List<Writable>> sequenceRecord(URI uri, DataInputStream dataInputStream) throws IOException { | ||||
|         return loadData(null, dataInputStream); | ||||
|     } | ||||
| 
 | ||||
|     protected abstract List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException; | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException { | ||||
|         setConf(conf); | ||||
|         initialize(split); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<Writable> next() { | ||||
|         throw new UnsupportedOperationException("next() not supported for CodecRecordReader (use: sequenceRecord)"); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException { | ||||
|         throw new UnsupportedOperationException("record(URI,DataInputStream) not supported for CodecRecordReader"); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setConf(Configuration conf) { | ||||
|         super.setConf(conf); | ||||
|         startFrame = conf.getInt(START_FRAME, 0); | ||||
|         numFrames = conf.getInt(TOTAL_FRAMES, -1); | ||||
|         rows = conf.getInt(ROWS, 28); | ||||
|         cols = conf.getInt(COLUMNS, 28); | ||||
|         framesPerSecond = conf.getFloat(TIME_SLICE, -1); | ||||
|         videoLength = conf.getFloat(VIDEO_DURATION, -1); | ||||
|         ravel = conf.getBoolean(RAVEL, false); | ||||
|         totalFrames = conf.getInt(TOTAL_FRAMES, -1); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Configuration getConf() { | ||||
|         return super.getConf(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public SequenceRecord nextSequence() { | ||||
|         URI next = locationsIterator.next(); | ||||
| 
 | ||||
|         List<List<Writable>> list; | ||||
|         try (InputStream s = streamCreatorFn.apply(next)){ | ||||
|             list = loadData(null, s); | ||||
|         } catch (IOException e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|         return new org.datavec.api.records.impl.SequenceRecord(list, | ||||
|                         new RecordMetaDataURI(next, CodecRecordReader.class)); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public SequenceRecord loadSequenceFromMetaData(RecordMetaData recordMetaData) throws IOException { | ||||
|         return loadSequenceFromMetaData(Collections.singletonList(recordMetaData)).get(0); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<SequenceRecord> loadSequenceFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException { | ||||
|         List<SequenceRecord> out = new ArrayList<>(); | ||||
|         for (RecordMetaData meta : recordMetaDatas) { | ||||
|             try (InputStream s = streamCreatorFn.apply(meta.getURI())){ | ||||
|                 List<List<Writable>> list = loadData(null, s); | ||||
|                 out.add(new org.datavec.api.records.impl.SequenceRecord(list, meta)); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return out; | ||||
|     } | ||||
| } | ||||
| @ -1,138 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.codec.reader; | ||||
| 
 | ||||
| import org.apache.commons.compress.utils.IOUtils; | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.util.ndarray.RecordConverter; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.image.loader.ImageLoader; | ||||
| import org.jcodec.api.FrameGrab; | ||||
| import org.jcodec.api.JCodecException; | ||||
| import org.jcodec.common.ByteBufferSeekableByteChannel; | ||||
| import org.jcodec.common.NIOUtils; | ||||
| import org.jcodec.common.SeekableByteChannel; | ||||
| 
 | ||||
| import java.awt.image.BufferedImage; | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.io.InputStream; | ||||
| import java.lang.reflect.Field; | ||||
| import java.nio.ByteBuffer; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class CodecRecordReader extends BaseCodecRecordReader { | ||||
| 
 | ||||
|     private ImageLoader imageLoader; | ||||
| 
 | ||||
|     @Override | ||||
|     public void setConf(Configuration conf) { | ||||
|         super.setConf(conf); | ||||
|         imageLoader = new ImageLoader(rows, cols); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException { | ||||
|         SeekableByteChannel seekableByteChannel; | ||||
|         if (inputStream != null) { | ||||
|             //Reading video from DataInputStream: Need data from this stream in a SeekableByteChannel | ||||
|             //Approach used here: load entire video into memory -> ByteBufferSeekableByteChanel | ||||
|             byte[] data = IOUtils.toByteArray(inputStream); | ||||
|             ByteBuffer bb = ByteBuffer.wrap(data); | ||||
|             seekableByteChannel = new FixedByteBufferSeekableByteChannel(bb); | ||||
|         } else { | ||||
|             seekableByteChannel = NIOUtils.readableFileChannel(file); | ||||
|         } | ||||
| 
 | ||||
|         List<List<Writable>> record = new ArrayList<>(); | ||||
| 
 | ||||
|         if (numFrames >= 1) { | ||||
|             FrameGrab fg; | ||||
|             try { | ||||
|                 fg = new FrameGrab(seekableByteChannel); | ||||
|                 if (startFrame != 0) | ||||
|                     fg.seekToFramePrecise(startFrame); | ||||
|             } catch (JCodecException e) { | ||||
|                 throw new RuntimeException(e); | ||||
|             } | ||||
| 
 | ||||
|             for (int i = startFrame; i < startFrame + numFrames; i++) { | ||||
|                 try { | ||||
|                     BufferedImage grab = fg.getFrame(); | ||||
|                     if (ravel) | ||||
|                         record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab))); | ||||
|                     else | ||||
|                         record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab))); | ||||
| 
 | ||||
|                 } catch (Exception e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
|         } else { | ||||
|             if (framesPerSecond < 1) | ||||
|                 throw new IllegalStateException("No frames or frame time intervals specified"); | ||||
| 
 | ||||
| 
 | ||||
|             else { | ||||
|                 for (double i = 0; i < videoLength; i += framesPerSecond) { | ||||
|                     try { | ||||
|                         BufferedImage grab = FrameGrab.getFrame(seekableByteChannel, i); | ||||
|                         if (ravel) | ||||
|                             record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab))); | ||||
|                         else | ||||
|                             record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab))); | ||||
| 
 | ||||
|                     } catch (Exception e) { | ||||
|                         throw new RuntimeException(e); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return record; | ||||
|     } | ||||
| 
 | ||||
|     /** Ugly workaround to a bug in JCodec: https://github.com/jcodec/jcodec/issues/24 */ | ||||
|     private static class FixedByteBufferSeekableByteChannel extends ByteBufferSeekableByteChannel { | ||||
|         private ByteBuffer backing; | ||||
| 
 | ||||
|         public FixedByteBufferSeekableByteChannel(ByteBuffer backing) { | ||||
|             super(backing); | ||||
|             try { | ||||
|                 Field f = this.getClass().getSuperclass().getDeclaredField("maxPos"); | ||||
|                 f.setAccessible(true); | ||||
|                 f.set(this, backing.limit()); | ||||
|             } catch (Exception e) { | ||||
|                 throw new RuntimeException(e); | ||||
|             } | ||||
|             this.backing = backing; | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         public int read(ByteBuffer dst) throws IOException { | ||||
|             if (!backing.hasRemaining()) | ||||
|                 return -1; | ||||
|             return super.read(dst); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,82 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.codec.reader; | ||||
| 
 | ||||
| import org.bytedeco.javacv.FFmpegFrameGrabber; | ||||
| import org.bytedeco.javacv.Frame; | ||||
| import org.bytedeco.javacv.OpenCVFrameConverter; | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.util.ndarray.RecordConverter; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.image.loader.NativeImageLoader; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.io.InputStream; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class NativeCodecRecordReader extends BaseCodecRecordReader { | ||||
| 
 | ||||
|     private OpenCVFrameConverter.ToMat converter; | ||||
|     private NativeImageLoader imageLoader; | ||||
| 
 | ||||
|     @Override | ||||
|     public void setConf(Configuration conf) { | ||||
|         super.setConf(conf); | ||||
|         converter = new OpenCVFrameConverter.ToMat(); | ||||
|         imageLoader = new NativeImageLoader(rows, cols); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException { | ||||
|         List<List<Writable>> record = new ArrayList<>(); | ||||
| 
 | ||||
|         try (FFmpegFrameGrabber fg = | ||||
|                         inputStream != null ? new FFmpegFrameGrabber(inputStream) : new FFmpegFrameGrabber(file)) { | ||||
|             if (numFrames >= 1) { | ||||
|                 fg.start(); | ||||
|                 if (startFrame != 0) | ||||
|                     fg.setFrameNumber(startFrame); | ||||
| 
 | ||||
|                 for (int i = startFrame; i < startFrame + numFrames; i++) { | ||||
|                     Frame grab = fg.grabImage(); | ||||
|                     record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab)))); | ||||
|                 } | ||||
|             } else { | ||||
|                 if (framesPerSecond < 1) | ||||
|                     throw new IllegalStateException("No frames or frame time intervals specified"); | ||||
|                 else { | ||||
|                     fg.start(); | ||||
| 
 | ||||
|                     for (double i = 0; i < videoLength; i += framesPerSecond) { | ||||
|                         fg.setTimestamp(Math.round(i * 1000000L)); | ||||
|                         Frame grab = fg.grabImage(); | ||||
|                         record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab)))); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return record; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,46 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| package org.datavec.codec.reader; | ||||
| 
 | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.nd4j.common.tests.AbstractAssertTestsClass; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { | ||||
| 
 | ||||
|     @Override | ||||
|     protected Set<Class<?>> getExclusions() { | ||||
| 	    //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) | ||||
| 	    return new HashSet<>(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
| 	protected String getPackageName() { | ||||
|     	return "org.datavec.codec.reader"; | ||||
| 	} | ||||
| 
 | ||||
| 	@Override | ||||
| 	protected Class<?> getBaseClass() { | ||||
|     	return BaseND4JTest.class; | ||||
| 	} | ||||
| } | ||||
| @ -1,212 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.codec.reader; | ||||
| 
 | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.records.SequenceRecord; | ||||
| import org.datavec.api.records.metadata.RecordMetaData; | ||||
| import org.datavec.api.records.reader.SequenceRecordReader; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.writable.ArrayWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| 
 | ||||
| import java.io.DataInputStream; | ||||
| import java.io.File; | ||||
| import java.io.FileInputStream; | ||||
| import java.util.Iterator; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| import static org.junit.Assert.assertTrue; | ||||
| 
 | ||||
| /** | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class CodecReaderTest { | ||||
|     @Test | ||||
|     public void testCodecReader() throws Exception { | ||||
|         File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); | ||||
|         SequenceRecordReader reader = new CodecRecordReader(); | ||||
|         Configuration conf = new Configuration(); | ||||
|         conf.set(CodecRecordReader.RAVEL, "true"); | ||||
|         conf.set(CodecRecordReader.START_FRAME, "160"); | ||||
|         conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); | ||||
|         conf.set(CodecRecordReader.ROWS, "80"); | ||||
|         conf.set(CodecRecordReader.COLUMNS, "46"); | ||||
|         reader.initialize(new FileSplit(file)); | ||||
|         reader.setConf(conf); | ||||
|         assertTrue(reader.hasNext()); | ||||
|         List<List<Writable>> record = reader.sequenceRecord(); | ||||
|         //        System.out.println(record.size()); | ||||
| 
 | ||||
|         Iterator<List<Writable>> it = record.iterator(); | ||||
|         List<Writable> first = it.next(); | ||||
|         //        System.out.println(first); | ||||
| 
 | ||||
|         //Expected size: 80x46x3 | ||||
|         assertEquals(1, first.size()); | ||||
|         assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length()); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testCodecReaderMeta() throws Exception { | ||||
|         File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); | ||||
|         SequenceRecordReader reader = new CodecRecordReader(); | ||||
|         Configuration conf = new Configuration(); | ||||
|         conf.set(CodecRecordReader.RAVEL, "true"); | ||||
|         conf.set(CodecRecordReader.START_FRAME, "160"); | ||||
|         conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); | ||||
|         conf.set(CodecRecordReader.ROWS, "80"); | ||||
|         conf.set(CodecRecordReader.COLUMNS, "46"); | ||||
|         reader.initialize(new FileSplit(file)); | ||||
|         reader.setConf(conf); | ||||
|         assertTrue(reader.hasNext()); | ||||
|         List<List<Writable>> record = reader.sequenceRecord(); | ||||
|         assertEquals(500, record.size()); //500 frames | ||||
| 
 | ||||
|         reader.reset(); | ||||
|         SequenceRecord seqR = reader.nextSequence(); | ||||
|         assertEquals(record, seqR.getSequenceRecord()); | ||||
|         RecordMetaData meta = seqR.getMetaData(); | ||||
|         //        System.out.println(meta); | ||||
|         assertTrue(meta.getURI().toString().endsWith(file.getName())); | ||||
| 
 | ||||
|         SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta); | ||||
|         assertEquals(seqR, fromMeta); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testViaDataInputStream() throws Exception { | ||||
| 
 | ||||
|         File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); | ||||
|         SequenceRecordReader reader = new CodecRecordReader(); | ||||
|         Configuration conf = new Configuration(); | ||||
|         conf.set(CodecRecordReader.RAVEL, "true"); | ||||
|         conf.set(CodecRecordReader.START_FRAME, "160"); | ||||
|         conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); | ||||
|         conf.set(CodecRecordReader.ROWS, "80"); | ||||
|         conf.set(CodecRecordReader.COLUMNS, "46"); | ||||
| 
 | ||||
|         Configuration conf2 = new Configuration(conf); | ||||
| 
 | ||||
|         reader.initialize(new FileSplit(file)); | ||||
|         reader.setConf(conf); | ||||
|         assertTrue(reader.hasNext()); | ||||
|         List<List<Writable>> expected = reader.sequenceRecord(); | ||||
| 
 | ||||
| 
 | ||||
|         SequenceRecordReader reader2 = new CodecRecordReader(); | ||||
|         reader2.setConf(conf2); | ||||
| 
 | ||||
|         DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file)); | ||||
|         List<List<Writable>> actual = reader2.sequenceRecord(null, dataInputStream); | ||||
| 
 | ||||
|         assertEquals(expected, actual); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Ignore | ||||
|     @Test | ||||
|     public void testNativeCodecReader() throws Exception { | ||||
|         File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); | ||||
|         SequenceRecordReader reader = new NativeCodecRecordReader(); | ||||
|         Configuration conf = new Configuration(); | ||||
|         conf.set(CodecRecordReader.RAVEL, "true"); | ||||
|         conf.set(CodecRecordReader.START_FRAME, "160"); | ||||
|         conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); | ||||
|         conf.set(CodecRecordReader.ROWS, "80"); | ||||
|         conf.set(CodecRecordReader.COLUMNS, "46"); | ||||
|         reader.initialize(new FileSplit(file)); | ||||
|         reader.setConf(conf); | ||||
|         assertTrue(reader.hasNext()); | ||||
|         List<List<Writable>> record = reader.sequenceRecord(); | ||||
|         //        System.out.println(record.size()); | ||||
| 
 | ||||
|         Iterator<List<Writable>> it = record.iterator(); | ||||
|         List<Writable> first = it.next(); | ||||
|         //        System.out.println(first); | ||||
| 
 | ||||
|         //Expected size: 80x46x3 | ||||
|         assertEquals(1, first.size()); | ||||
|         assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length()); | ||||
|     } | ||||
| 
 | ||||
|     @Ignore | ||||
|     @Test | ||||
|     public void testNativeCodecReaderMeta() throws Exception { | ||||
|         File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); | ||||
|         SequenceRecordReader reader = new NativeCodecRecordReader(); | ||||
|         Configuration conf = new Configuration(); | ||||
|         conf.set(CodecRecordReader.RAVEL, "true"); | ||||
|         conf.set(CodecRecordReader.START_FRAME, "160"); | ||||
|         conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); | ||||
|         conf.set(CodecRecordReader.ROWS, "80"); | ||||
|         conf.set(CodecRecordReader.COLUMNS, "46"); | ||||
|         reader.initialize(new FileSplit(file)); | ||||
|         reader.setConf(conf); | ||||
|         assertTrue(reader.hasNext()); | ||||
|         List<List<Writable>> record = reader.sequenceRecord(); | ||||
|         assertEquals(500, record.size()); //500 frames | ||||
| 
 | ||||
|         reader.reset(); | ||||
|         SequenceRecord seqR = reader.nextSequence(); | ||||
|         assertEquals(record, seqR.getSequenceRecord()); | ||||
|         RecordMetaData meta = seqR.getMetaData(); | ||||
|         //        System.out.println(meta); | ||||
|         assertTrue(meta.getURI().toString().endsWith("fire_lowres.mp4")); | ||||
| 
 | ||||
|         SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta); | ||||
|         assertEquals(seqR, fromMeta); | ||||
|     } | ||||
| 
 | ||||
|     @Ignore | ||||
|     @Test | ||||
|     public void testNativeViaDataInputStream() throws Exception { | ||||
| 
 | ||||
|         File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); | ||||
|         SequenceRecordReader reader = new NativeCodecRecordReader(); | ||||
|         Configuration conf = new Configuration(); | ||||
|         conf.set(CodecRecordReader.RAVEL, "true"); | ||||
|         conf.set(CodecRecordReader.START_FRAME, "160"); | ||||
|         conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); | ||||
|         conf.set(CodecRecordReader.ROWS, "80"); | ||||
|         conf.set(CodecRecordReader.COLUMNS, "46"); | ||||
| 
 | ||||
|         Configuration conf2 = new Configuration(conf); | ||||
| 
 | ||||
|         reader.initialize(new FileSplit(file)); | ||||
|         reader.setConf(conf); | ||||
|         assertTrue(reader.hasNext()); | ||||
|         List<List<Writable>> expected = reader.sequenceRecord(); | ||||
| 
 | ||||
| 
 | ||||
|         SequenceRecordReader reader2 = new NativeCodecRecordReader(); | ||||
|         reader2.setConf(conf2); | ||||
| 
 | ||||
|         DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file)); | ||||
|         List<List<Writable>> actual = reader2.sequenceRecord(null, dataInputStream); | ||||
| 
 | ||||
|         assertEquals(expected, actual); | ||||
|     } | ||||
| } | ||||
| @ -1,77 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.datavec</groupId> | ||||
|         <artifactId>datavec-data</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>datavec-data-nlp</artifactId> | ||||
| 
 | ||||
|     <name>datavec-data-nlp</name> | ||||
| 
 | ||||
|     <properties> | ||||
|         <cleartk.version>2.0.0</cleartk.version> | ||||
|     </properties> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>org.datavec</groupId> | ||||
|             <artifactId>datavec-api</artifactId> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.datavec</groupId> | ||||
|             <artifactId>datavec-local</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.apache.commons</groupId> | ||||
|             <artifactId>commons-lang3</artifactId> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.cleartk</groupId> | ||||
|             <artifactId>cleartk-snowball</artifactId> | ||||
|             <version>${cleartk.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.cleartk</groupId> | ||||
|             <artifactId>cleartk-opennlp-tools</artifactId> | ||||
|             <version>${cleartk.version}</version> | ||||
|         </dependency> | ||||
|     </dependencies> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
| @ -1,237 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.annotator; | ||||
| 
 | ||||
| import opennlp.tools.postag.POSModel; | ||||
| import opennlp.tools.postag.POSTaggerME; | ||||
| import opennlp.uima.postag.POSModelResource; | ||||
| import opennlp.uima.postag.POSModelResourceImpl; | ||||
| import opennlp.uima.util.AnnotationComboIterator; | ||||
| import opennlp.uima.util.AnnotationIteratorPair; | ||||
| import opennlp.uima.util.AnnotatorUtil; | ||||
| import opennlp.uima.util.UimaUtil; | ||||
| import org.apache.uima.UimaContext; | ||||
| import org.apache.uima.analysis_engine.AnalysisEngineDescription; | ||||
| import org.apache.uima.analysis_engine.AnalysisEngineProcessException; | ||||
| import org.apache.uima.cas.CAS; | ||||
| import org.apache.uima.cas.Feature; | ||||
| import org.apache.uima.cas.Type; | ||||
| import org.apache.uima.cas.TypeSystem; | ||||
| import org.apache.uima.cas.text.AnnotationFS; | ||||
| import org.apache.uima.fit.component.CasAnnotator_ImplBase; | ||||
| import org.apache.uima.fit.factory.AnalysisEngineFactory; | ||||
| import org.apache.uima.fit.factory.ExternalResourceFactory; | ||||
| import org.apache.uima.resource.ResourceAccessException; | ||||
| import org.apache.uima.resource.ResourceInitializationException; | ||||
| import org.apache.uima.util.Level; | ||||
| import org.apache.uima.util.Logger; | ||||
| import org.cleartk.token.type.Sentence; | ||||
| import org.cleartk.token.type.Token; | ||||
| import org.datavec.nlp.movingwindow.Util; | ||||
| 
 | ||||
| import java.util.Iterator; | ||||
| import java.util.LinkedList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| 
 | ||||
| public class PoStagger extends CasAnnotator_ImplBase { | ||||
| 
 | ||||
|     static { | ||||
|         //UIMA logging | ||||
|         Util.disableLogging(); | ||||
|     } | ||||
| 
 | ||||
|     private POSTaggerME posTagger; | ||||
| 
 | ||||
|     private Type sentenceType; | ||||
| 
 | ||||
|     private Type tokenType; | ||||
| 
 | ||||
|     private Feature posFeature; | ||||
| 
 | ||||
|     private Feature probabilityFeature; | ||||
| 
 | ||||
|     private UimaContext context; | ||||
| 
 | ||||
|     private Logger logger; | ||||
| 
 | ||||
|     /** | ||||
|      * Initializes a new instance. | ||||
|      *  | ||||
|      * Note: Use {@link #initialize(org.apache.uima.UimaContext) } to initialize this instance. Not use the | ||||
|      * constructor. | ||||
|      */ | ||||
|     public PoStagger() { | ||||
|         // must not be implemented ! | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Initializes the current instance with the given context. | ||||
|      * | ||||
|      * Note: Do all initialization in this method, do not use the constructor. | ||||
|      */ | ||||
|     @Override | ||||
|     public void initialize(UimaContext context) throws ResourceInitializationException { | ||||
| 
 | ||||
|         super.initialize(context); | ||||
| 
 | ||||
|         this.context = context; | ||||
| 
 | ||||
|         this.logger = context.getLogger(); | ||||
| 
 | ||||
|         if (this.logger.isLoggable(Level.INFO)) { | ||||
|             this.logger.log(Level.INFO, "Initializing the OpenNLP " + "Part of Speech annotator."); | ||||
|         } | ||||
| 
 | ||||
|         POSModel model; | ||||
| 
 | ||||
|         try { | ||||
|             POSModelResource modelResource = (POSModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER); | ||||
| 
 | ||||
|             model = modelResource.getModel(); | ||||
|         } catch (ResourceAccessException e) { | ||||
|             throw new ResourceInitializationException(e); | ||||
|         } | ||||
| 
 | ||||
|         Integer beamSize = AnnotatorUtil.getOptionalIntegerParameter(context, UimaUtil.BEAM_SIZE_PARAMETER); | ||||
| 
 | ||||
|         if (beamSize == null) | ||||
|             beamSize = POSTaggerME.DEFAULT_BEAM_SIZE; | ||||
| 
 | ||||
|         this.posTagger = new POSTaggerME(model, beamSize, 0); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Initializes the type system. | ||||
|      */ | ||||
|     @Override | ||||
|     public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException { | ||||
| 
 | ||||
|         // sentence type | ||||
|         this.sentenceType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem, | ||||
|                         UimaUtil.SENTENCE_TYPE_PARAMETER); | ||||
| 
 | ||||
|         // token type | ||||
|         this.tokenType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem, | ||||
|                         UimaUtil.TOKEN_TYPE_PARAMETER); | ||||
| 
 | ||||
|         // pos feature | ||||
|         this.posFeature = AnnotatorUtil.getRequiredFeatureParameter(this.context, this.tokenType, | ||||
|                         UimaUtil.POS_FEATURE_PARAMETER, CAS.TYPE_NAME_STRING); | ||||
| 
 | ||||
|         this.probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(this.context, this.tokenType, | ||||
|                         UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Performs pos-tagging on the given tcas object. | ||||
|      */ | ||||
|     @Override | ||||
|     public synchronized void process(CAS tcas) { | ||||
| 
 | ||||
|         final AnnotationComboIterator comboIterator = | ||||
|                         new AnnotationComboIterator(tcas, this.sentenceType, this.tokenType); | ||||
| 
 | ||||
|         for (AnnotationIteratorPair annotationIteratorPair : comboIterator) { | ||||
| 
 | ||||
|             final List<AnnotationFS> sentenceTokenAnnotationList = new LinkedList<AnnotationFS>(); | ||||
| 
 | ||||
|             final List<String> sentenceTokenList = new LinkedList<String>(); | ||||
| 
 | ||||
|             for (AnnotationFS tokenAnnotation : annotationIteratorPair.getSubIterator()) { | ||||
| 
 | ||||
|                 sentenceTokenAnnotationList.add(tokenAnnotation); | ||||
| 
 | ||||
|                 sentenceTokenList.add(tokenAnnotation.getCoveredText()); | ||||
|             } | ||||
| 
 | ||||
|             final List<String> posTags = this.posTagger.tag(sentenceTokenList); | ||||
| 
 | ||||
|             double posProbabilities[] = null; | ||||
| 
 | ||||
|             if (this.probabilityFeature != null) { | ||||
|                 posProbabilities = this.posTagger.probs(); | ||||
|             } | ||||
| 
 | ||||
|             final Iterator<String> posTagIterator = posTags.iterator(); | ||||
|             final Iterator<AnnotationFS> sentenceTokenIterator = sentenceTokenAnnotationList.iterator(); | ||||
| 
 | ||||
|             int index = 0; | ||||
|             while (posTagIterator.hasNext() && sentenceTokenIterator.hasNext()) { | ||||
|                 final String posTag = posTagIterator.next(); | ||||
|                 final AnnotationFS tokenAnnotation = sentenceTokenIterator.next(); | ||||
| 
 | ||||
|                 tokenAnnotation.setStringValue(this.posFeature, posTag); | ||||
| 
 | ||||
|                 if (posProbabilities != null) { | ||||
|                     tokenAnnotation.setDoubleValue(this.posFeature, posProbabilities[index]); | ||||
|                 } | ||||
| 
 | ||||
|                 index++; | ||||
|             } | ||||
| 
 | ||||
|             // log tokens with pos | ||||
|             if (this.logger.isLoggable(Level.FINER)) { | ||||
| 
 | ||||
|                 final StringBuilder sentenceWithPos = new StringBuilder(); | ||||
| 
 | ||||
|                 sentenceWithPos.append("\""); | ||||
| 
 | ||||
|                 for (final Iterator<AnnotationFS> it = sentenceTokenAnnotationList.iterator(); it.hasNext();) { | ||||
|                     final AnnotationFS token = it.next(); | ||||
|                     sentenceWithPos.append(token.getCoveredText()); | ||||
|                     sentenceWithPos.append('\\'); | ||||
|                     sentenceWithPos.append(token.getStringValue(this.posFeature)); | ||||
|                     sentenceWithPos.append(' '); | ||||
|                 } | ||||
| 
 | ||||
|                 // delete last whitespace | ||||
|                 if (sentenceWithPos.length() > 1) // not 0 because it contains already the " char | ||||
|                     sentenceWithPos.setLength(sentenceWithPos.length() - 1); | ||||
| 
 | ||||
|                 sentenceWithPos.append("\""); | ||||
| 
 | ||||
|                 this.logger.log(Level.FINER, sentenceWithPos.toString()); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Releases allocated resources. | ||||
|      */ | ||||
|     @Override | ||||
|     public void destroy() { | ||||
|         this.posTagger = null; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException { | ||||
|         String modelPath = String.format("/models/%s-pos-maxent.bin", languageCode); | ||||
|         return AnalysisEngineFactory.createEngineDescription(PoStagger.class, UimaUtil.MODEL_PARAMETER, | ||||
|                         ExternalResourceFactory.createExternalResourceDescription(POSModelResourceImpl.class, | ||||
|                                         PoStagger.class.getResource(modelPath).toString()), | ||||
|                         UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), UimaUtil.TOKEN_TYPE_PARAMETER, | ||||
|                         Token.class.getName(), UimaUtil.POS_FEATURE_PARAMETER, "pos"); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,52 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.annotator; | ||||
| 
 | ||||
| import org.apache.uima.analysis_engine.AnalysisEngineDescription; | ||||
| import org.apache.uima.analysis_engine.AnalysisEngineProcessException; | ||||
| import org.apache.uima.fit.factory.AnalysisEngineFactory; | ||||
| import org.apache.uima.jcas.JCas; | ||||
| import org.apache.uima.resource.ResourceInitializationException; | ||||
| import org.cleartk.util.ParamUtil; | ||||
| import org.datavec.nlp.movingwindow.Util; | ||||
| 
 | ||||
| public class SentenceAnnotator extends org.cleartk.opennlp.tools.SentenceAnnotator { | ||||
| 
 | ||||
|     static { | ||||
|         //UIMA logging | ||||
|         Util.disableLogging(); | ||||
|     } | ||||
| 
 | ||||
|     public static AnalysisEngineDescription getDescription() throws ResourceInitializationException { | ||||
|         return AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.class, PARAM_SENTENCE_MODEL_PATH, | ||||
|                         ParamUtil.getParameterValue(PARAM_SENTENCE_MODEL_PATH, "/models/en-sent.bin"), | ||||
|                         PARAM_WINDOW_CLASS_NAMES, ParamUtil.getParameterValue(PARAM_WINDOW_CLASS_NAMES, null)); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public synchronized void process(JCas jCas) throws AnalysisEngineProcessException { | ||||
|         super.process(jCas); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,58 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.annotator; | ||||
| 
 | ||||
| import org.apache.uima.analysis_engine.AnalysisEngineDescription; | ||||
| import org.apache.uima.analysis_engine.AnalysisEngineProcessException; | ||||
| import org.apache.uima.fit.factory.AnalysisEngineFactory; | ||||
| import org.apache.uima.jcas.JCas; | ||||
| import org.apache.uima.resource.ResourceInitializationException; | ||||
| import org.cleartk.snowball.SnowballStemmer; | ||||
| import org.cleartk.token.type.Token; | ||||
| 
 | ||||
| 
 | ||||
| public class StemmerAnnotator extends SnowballStemmer<Token> { | ||||
| 
 | ||||
|     public static AnalysisEngineDescription getDescription() throws ResourceInitializationException { | ||||
|         return getDescription("English"); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public static AnalysisEngineDescription getDescription(String language) throws ResourceInitializationException { | ||||
|         return AnalysisEngineFactory.createEngineDescription(StemmerAnnotator.class, SnowballStemmer.PARAM_STEMMER_NAME, | ||||
|                         language); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @SuppressWarnings("unchecked") | ||||
|     @Override | ||||
|     public synchronized void process(JCas jCas) throws AnalysisEngineProcessException { | ||||
|         super.process(jCas); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public void setStem(Token token, String stem) { | ||||
|         token.setStem(stem); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,70 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.annotator; | ||||
| 
 | ||||
| 
 | ||||
| import opennlp.uima.tokenize.TokenizerModelResourceImpl; | ||||
| import org.apache.uima.analysis_engine.AnalysisEngineDescription; | ||||
| import org.apache.uima.fit.factory.AnalysisEngineFactory; | ||||
| import org.apache.uima.fit.factory.ExternalResourceFactory; | ||||
| import org.apache.uima.resource.ResourceInitializationException; | ||||
| import org.cleartk.opennlp.tools.Tokenizer; | ||||
| import org.cleartk.token.type.Sentence; | ||||
| import org.cleartk.token.type.Token; | ||||
| import org.datavec.nlp.movingwindow.Util; | ||||
| import org.datavec.nlp.tokenization.tokenizer.ConcurrentTokenizer; | ||||
| 
 | ||||
| 
 | ||||
| /** | ||||
|  * Overrides OpenNLP tokenizer to be thread safe | ||||
|  */ | ||||
| public class TokenizerAnnotator extends Tokenizer { | ||||
| 
 | ||||
|     static { | ||||
|         //UIMA logging | ||||
|         Util.disableLogging(); | ||||
|     } | ||||
| 
 | ||||
|     public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException { | ||||
|         String modelPath = String.format("/models/%s-token.bin", languageCode); | ||||
|         return AnalysisEngineFactory.createEngineDescription(ConcurrentTokenizer.class, | ||||
|                         opennlp.uima.util.UimaUtil.MODEL_PARAMETER, | ||||
|                         ExternalResourceFactory.createExternalResourceDescription(TokenizerModelResourceImpl.class, | ||||
|                                         ConcurrentTokenizer.class.getResource(modelPath).toString()), | ||||
|                         opennlp.uima.util.UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), | ||||
|                         opennlp.uima.util.UimaUtil.TOKEN_TYPE_PARAMETER, Token.class.getName()); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     public static AnalysisEngineDescription getDescription() throws ResourceInitializationException { | ||||
|         String modelPath = String.format("/models/%s-token.bin", "en"); | ||||
|         return AnalysisEngineFactory.createEngineDescription(ConcurrentTokenizer.class, | ||||
|                         opennlp.uima.util.UimaUtil.MODEL_PARAMETER, | ||||
|                         ExternalResourceFactory.createExternalResourceDescription(TokenizerModelResourceImpl.class, | ||||
|                                         ConcurrentTokenizer.class.getResource(modelPath).toString()), | ||||
|                         opennlp.uima.util.UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), | ||||
|                         opennlp.uima.util.UimaUtil.TOKEN_TYPE_PARAMETER, Token.class.getName()); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,41 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.input; | ||||
| 
 | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.formats.input.BaseInputFormat; | ||||
| import org.datavec.api.records.reader.RecordReader; | ||||
| import org.datavec.api.split.InputSplit; | ||||
| import org.datavec.nlp.reader.TfidfRecordReader; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| 
 | ||||
| /** | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class TextInputFormat extends BaseInputFormat { | ||||
|     @Override | ||||
|     public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException { | ||||
|         RecordReader reader = new TfidfRecordReader(); | ||||
|         reader.initialize(conf, split); | ||||
|         return reader; | ||||
|     } | ||||
| } | ||||
| @ -1,148 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.metadata; | ||||
| 
 | ||||
| import org.nd4j.common.primitives.Counter; | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.nlp.vectorizer.TextVectorizer; | ||||
| import org.nd4j.common.util.MathUtils; | ||||
| import org.nd4j.common.util.Index; | ||||
| 
 | ||||
| public class DefaultVocabCache implements VocabCache { | ||||
| 
 | ||||
|     private Counter<String> wordFrequencies = new Counter<>(); | ||||
|     private Counter<String> docFrequencies = new Counter<>(); | ||||
|     private int minWordFrequency; | ||||
|     private Index vocabWords = new Index(); | ||||
|     private double numDocs = 0; | ||||
| 
 | ||||
|     /** | ||||
|      * Instantiate with a given min word frequency | ||||
|      * @param minWordFrequency | ||||
|      */ | ||||
|     public DefaultVocabCache(int minWordFrequency) { | ||||
|         this.minWordFrequency = minWordFrequency; | ||||
|     } | ||||
| 
 | ||||
|     /* | ||||
|      * Constructor for use with initialize() | ||||
|      */ | ||||
|     public DefaultVocabCache() { | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void incrementNumDocs(double by) { | ||||
|         numDocs += by; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public double numDocs() { | ||||
|         return numDocs; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String wordAt(int i) { | ||||
|         return vocabWords.get(i).toString(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int wordIndex(String word) { | ||||
|         return vocabWords.indexOf(word); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void initialize(Configuration conf) { | ||||
|         minWordFrequency = conf.getInt(TextVectorizer.MIN_WORD_FREQUENCY, 5); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public double wordFrequency(String word) { | ||||
|         return wordFrequencies.getCount(word); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int minWordFrequency() { | ||||
|         return minWordFrequency; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Index vocabWords() { | ||||
|         return vocabWords; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void incrementDocCount(String word) { | ||||
|         incrementDocCount(word, 1.0); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void incrementDocCount(String word, double by) { | ||||
|         docFrequencies.incrementCount(word, by); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void incrementCount(String word) { | ||||
|         incrementCount(word, 1.0); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void incrementCount(String word, double by) { | ||||
|         wordFrequencies.incrementCount(word, by); | ||||
|         if (wordFrequencies.getCount(word) >= minWordFrequency && vocabWords.indexOf(word) < 0) | ||||
|             vocabWords.add(word); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public double idf(String word) { | ||||
|         return docFrequencies.getCount(word); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public double tfidf(String word, double frequency, boolean smoothIdf) { | ||||
|         double tf = tf((int) frequency); | ||||
|         double docFreq = docFrequencies.getCount(word); | ||||
| 
 | ||||
|         double idf = idf(numDocs, docFreq, smoothIdf); | ||||
|         double tfidf = MathUtils.tfidf(tf, idf); | ||||
|         return tfidf; | ||||
|     } | ||||
| 
 | ||||
|     public double idf(double totalDocs, double numTimesWordAppearedInADocument, boolean smooth) { | ||||
|         if(smooth){ | ||||
|             return Math.log((1 + totalDocs) / (1 + numTimesWordAppearedInADocument)) + 1.0; | ||||
|         } else { | ||||
|             return Math.log(totalDocs / numTimesWordAppearedInADocument) + 1.0; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public static double tf(int count) { | ||||
|         return count; | ||||
|     } | ||||
| 
 | ||||
|     public int getMinWordFrequency() { | ||||
|         return minWordFrequency; | ||||
|     } | ||||
| 
 | ||||
|     public void setMinWordFrequency(int minWordFrequency) { | ||||
|         this.minWordFrequency = minWordFrequency; | ||||
|     } | ||||
| } | ||||
| @ -1,121 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.metadata; | ||||
| 
 | ||||
| 
 | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.nd4j.common.util.Index; | ||||
| 
 | ||||
| public interface VocabCache { | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Increment the number of documents | ||||
|      * @param by | ||||
|      */ | ||||
|     void incrementNumDocs(double by); | ||||
| 
 | ||||
|     /** | ||||
|      * Number of documents | ||||
|      * @return the number of documents | ||||
|      */ | ||||
|     double numDocs(); | ||||
| 
 | ||||
|     /** | ||||
|      * Returns a word in the vocab at a particular index | ||||
|      * @param i the index to get | ||||
|      * @return the word at that index in the vocab | ||||
|      */ | ||||
|     String wordAt(int i); | ||||
| 
 | ||||
|     int wordIndex(String word); | ||||
| 
 | ||||
|     /** | ||||
|      * Configuration for initializing | ||||
|      * @param conf the configuration to initialize with | ||||
|      */ | ||||
|     void initialize(Configuration conf); | ||||
| 
 | ||||
|     /** | ||||
|      * Get the word frequency for a word | ||||
|      * @param word the word to get frequency for | ||||
|      * @return the frequency for a given word | ||||
|      */ | ||||
|     double wordFrequency(String word); | ||||
| 
 | ||||
|     /** | ||||
|      * The min word frequency | ||||
|      * needed to be included in the vocab | ||||
|      * (default 5) | ||||
|      * @return the min word frequency to | ||||
|      * be included in the vocab | ||||
|      */ | ||||
|     int minWordFrequency(); | ||||
| 
 | ||||
|     /** | ||||
|      * All of the vocab words (ordered) | ||||
|      * note that these are not all the possible tokens | ||||
|      * @return the list of vocab words | ||||
|      */ | ||||
|     Index vocabWords(); | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Increment the doc count for a word by 1 | ||||
|      * @param word the word to increment the count for | ||||
|      */ | ||||
|     void incrementDocCount(String word); | ||||
| 
 | ||||
|     /** | ||||
|      * Increment the document count for a particular word | ||||
|      * @param word the word to increment the count for | ||||
|      * @param by the amount to increment by | ||||
|      */ | ||||
|     void incrementDocCount(String word, double by); | ||||
| 
 | ||||
|     /** | ||||
|      * Increment a word count by 1 | ||||
|      * @param word the word to increment the count for | ||||
|      */ | ||||
|     void incrementCount(String word); | ||||
| 
 | ||||
|     /** | ||||
|      * Increment count for a word | ||||
|      * @param word the word to increment the count for | ||||
|      * @param by the amount to increment by | ||||
|      */ | ||||
|     void incrementCount(String word, double by); | ||||
| 
 | ||||
|     /** | ||||
|      * Number of documents word has occurred in | ||||
|      * @param word the word to get the idf for | ||||
|      */ | ||||
|     double idf(String word); | ||||
| 
 | ||||
|     /** | ||||
|      * Calculate the tfidf of the word given the document frequency | ||||
|      * @param word the word to get frequency for | ||||
|      * @param frequency the frequency | ||||
|      * @return the tfidf for a word | ||||
|      */ | ||||
|     double tfidf(String word, double frequency, boolean smoothIdf); | ||||
| 
 | ||||
| } | ||||
| @ -1,125 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.movingwindow; | ||||
| 
 | ||||
| 
 | ||||
| import org.apache.commons.lang3.StringUtils; | ||||
| import org.nd4j.common.base.Preconditions; | ||||
| import org.nd4j.common.collection.MultiDimensionalMap; | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; | ||||
| import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class ContextLabelRetriever { | ||||
| 
 | ||||
| 
 | ||||
|     private static String BEGIN_LABEL = "<([A-Za-z]+|\\d+)>"; | ||||
|     private static String END_LABEL = "</([A-Za-z]+|\\d+)>"; | ||||
| 
 | ||||
| 
 | ||||
|     private ContextLabelRetriever() {} | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Returns a stripped sentence with the indices of words | ||||
|      * with certain kinds of labels. | ||||
|      * | ||||
|      * @param sentence the sentence to process | ||||
|      * @return a pair of a post processed sentence | ||||
|      * with labels stripped and the spans of | ||||
|      * the labels | ||||
|      */ | ||||
|     public static Pair<String, MultiDimensionalMap<Integer, Integer, String>> stringWithLabels(String sentence, | ||||
|                                                        TokenizerFactory tokenizerFactory) { | ||||
|         MultiDimensionalMap<Integer, Integer, String> map = MultiDimensionalMap.newHashBackedMap(); | ||||
|         Tokenizer t = tokenizerFactory.create(sentence); | ||||
|         List<String> currTokens = new ArrayList<>(); | ||||
|         String currLabel = null; | ||||
|         String endLabel = null; | ||||
|         List<Pair<String, List<String>>> tokensWithSameLabel = new ArrayList<>(); | ||||
|         while (t.hasMoreTokens()) { | ||||
|             String token = t.nextToken(); | ||||
|             if (token.matches(BEGIN_LABEL)) { | ||||
|                 currLabel = token; | ||||
| 
 | ||||
|                 //no labels; add these as NONE and begin the new label | ||||
|                 if (!currTokens.isEmpty()) { | ||||
|                     tokensWithSameLabel.add(new Pair<>("NONE", (List<String>) new ArrayList<>(currTokens))); | ||||
|                     currTokens.clear(); | ||||
| 
 | ||||
|                 } | ||||
| 
 | ||||
|             } else if (token.matches(END_LABEL)) { | ||||
|                 if (currLabel == null) | ||||
|                     throw new IllegalStateException("Found an ending label with no matching begin label"); | ||||
|                 endLabel = token; | ||||
|             } else | ||||
|                 currTokens.add(token); | ||||
| 
 | ||||
|             if (currLabel != null && endLabel != null) { | ||||
|                 currLabel = currLabel.replaceAll("[<>/]", ""); | ||||
|                 endLabel = endLabel.replaceAll("[<>/]", ""); | ||||
|                 Preconditions.checkState(!currLabel.isEmpty(), "Current label is empty!"); | ||||
|                 Preconditions.checkState(!endLabel.isEmpty(), "End label is empty!"); | ||||
|                 Preconditions.checkState(currLabel.equals(endLabel), "Current label begin and end did not match for the parse. Was: %s ending with %s", | ||||
|                         currLabel, endLabel); | ||||
| 
 | ||||
|                 tokensWithSameLabel.add(new Pair<>(currLabel, (List<String>) new ArrayList<>(currTokens))); | ||||
|                 currTokens.clear(); | ||||
| 
 | ||||
| 
 | ||||
|                 //clear out the tokens | ||||
|                 currLabel = null; | ||||
|                 endLabel = null; | ||||
|             } | ||||
| 
 | ||||
| 
 | ||||
|         } | ||||
| 
 | ||||
|         //no labels; add these as NONE and begin the new label | ||||
|         if (!currTokens.isEmpty()) { | ||||
|             tokensWithSameLabel.add(new Pair<>("none", (List<String>) new ArrayList<>(currTokens))); | ||||
|             currTokens.clear(); | ||||
| 
 | ||||
|         } | ||||
| 
 | ||||
|         //now join the output | ||||
|         StringBuilder strippedSentence = new StringBuilder(); | ||||
|         for (Pair<String, List<String>> tokensWithLabel : tokensWithSameLabel) { | ||||
|             String joinedSentence = StringUtils.join(tokensWithLabel.getSecond(), " "); | ||||
|             //spaces between separate parts of the sentence | ||||
|             if (!(strippedSentence.length() < 1)) | ||||
|                 strippedSentence.append(" "); | ||||
|             strippedSentence.append(joinedSentence); | ||||
|             int begin = strippedSentence.toString().indexOf(joinedSentence); | ||||
|             int end = begin + joinedSentence.length(); | ||||
|             map.put(begin, end, tokensWithLabel.getFirst()); | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         return new Pair<>(strippedSentence.toString(), map); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,60 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.movingwindow; | ||||
| 
 | ||||
| 
 | ||||
| import org.nd4j.common.primitives.Counter; | ||||
| 
 | ||||
| import java.util.List; | ||||
| import java.util.logging.Level; | ||||
| import java.util.logging.Logger; | ||||
| 
 | ||||
| 
 | ||||
| public class Util { | ||||
| 
 | ||||
|     /** | ||||
|      * Returns a thread safe counter | ||||
|      * | ||||
|      * @return | ||||
|      */ | ||||
|     public static Counter<String> parallelCounter() { | ||||
|         return new Counter<>(); | ||||
|     } | ||||
| 
 | ||||
|     public static boolean matchesAnyStopWord(List<String> stopWords, String word) { | ||||
|         for (String s : stopWords) | ||||
|             if (s.equalsIgnoreCase(word)) | ||||
|                 return true; | ||||
|         return false; | ||||
|     } | ||||
| 
 | ||||
|     public static Level disableLogging() { | ||||
|         Logger logger = Logger.getLogger("org.apache.uima"); | ||||
|         while (logger.getLevel() == null) { | ||||
|             logger = logger.getParent(); | ||||
|         } | ||||
|         Level level = logger.getLevel(); | ||||
|         logger.setLevel(Level.OFF); | ||||
|         return level; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,177 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.movingwindow; | ||||
| 
 | ||||
| import org.apache.commons.lang3.StringUtils; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Collection; | ||||
| import java.util.List; | ||||
| 
 | ||||
| 
 | ||||
| public class Window implements Serializable { | ||||
|     /** | ||||
|      *  | ||||
|      */ | ||||
|     private static final long serialVersionUID = 6359906393699230579L; | ||||
|     private List<String> words; | ||||
|     private String label = "NONE"; | ||||
|     private boolean beginLabel; | ||||
|     private boolean endLabel; | ||||
|     private int median; | ||||
|     private static String BEGIN_LABEL = "<([A-Z]+|\\d+)>"; | ||||
|     private static String END_LABEL = "</([A-Z]+|\\d+)>"; | ||||
|     private int begin, end; | ||||
| 
 | ||||
|     /** | ||||
|      * Creates a window with a context of size 3 | ||||
|      * @param words a collection of strings of size 3 | ||||
|      */ | ||||
|     public Window(Collection<String> words, int begin, int end) { | ||||
|         this(words, 5, begin, end); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     public String asTokens() { | ||||
|         return StringUtils.join(words, " "); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Initialize a window with the given size | ||||
|      * @param words the words to use  | ||||
|      * @param windowSize the size of the window | ||||
|      * @param begin the begin index for the window | ||||
|      * @param end the end index for the window | ||||
|      */ | ||||
|     public Window(Collection<String> words, int windowSize, int begin, int end) { | ||||
|         if (words == null) | ||||
|             throw new IllegalArgumentException("Words must be a list of size 3"); | ||||
| 
 | ||||
|         this.words = new ArrayList<>(words); | ||||
|         int windowSize1 = windowSize; | ||||
|         this.begin = begin; | ||||
|         this.end = end; | ||||
|         initContext(); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     private void initContext() { | ||||
|         int median = (int) Math.floor(words.size() / 2); | ||||
|         List<String> begin = words.subList(0, median); | ||||
|         List<String> after = words.subList(median + 1, words.size()); | ||||
| 
 | ||||
| 
 | ||||
|         for (String s : begin) { | ||||
|             if (s.matches(BEGIN_LABEL)) { | ||||
|                 this.label = s.replaceAll("(<|>)", "").replace("/", ""); | ||||
|                 beginLabel = true; | ||||
|             } else if (s.matches(END_LABEL)) { | ||||
|                 endLabel = true; | ||||
|                 this.label = s.replaceAll("(<|>|/)", "").replace("/", ""); | ||||
| 
 | ||||
|             } | ||||
| 
 | ||||
|         } | ||||
| 
 | ||||
|         for (String s1 : after) { | ||||
| 
 | ||||
|             if (s1.matches(BEGIN_LABEL)) { | ||||
|                 this.label = s1.replaceAll("(<|>)", "").replace("/", ""); | ||||
|                 beginLabel = true; | ||||
|             } | ||||
| 
 | ||||
|             if (s1.matches(END_LABEL)) { | ||||
|                 endLabel = true; | ||||
|                 this.label = s1.replaceAll("(<|>)", ""); | ||||
| 
 | ||||
|             } | ||||
|         } | ||||
|         this.median = median; | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public String toString() { | ||||
|         return words.toString(); | ||||
|     } | ||||
| 
 | ||||
|     public List<String> getWords() { | ||||
|         return words; | ||||
|     } | ||||
| 
 | ||||
|     public void setWords(List<String> words) { | ||||
|         this.words = words; | ||||
|     } | ||||
| 
 | ||||
|     public String getWord(int i) { | ||||
|         return words.get(i); | ||||
|     } | ||||
| 
 | ||||
|     public String getFocusWord() { | ||||
|         return words.get(median); | ||||
|     } | ||||
| 
 | ||||
|     public boolean isBeginLabel() { | ||||
|         return !label.equals("NONE") && beginLabel; | ||||
|     } | ||||
| 
 | ||||
|     public boolean isEndLabel() { | ||||
|         return !label.equals("NONE") && endLabel; | ||||
|     } | ||||
| 
 | ||||
|     public String getLabel() { | ||||
|         return label.replace("/", ""); | ||||
|     } | ||||
| 
 | ||||
|     public int getWindowSize() { | ||||
|         return words.size(); | ||||
|     } | ||||
| 
 | ||||
|     public int getMedian() { | ||||
|         return median; | ||||
|     } | ||||
| 
 | ||||
|     public void setLabel(String label) { | ||||
|         this.label = label; | ||||
|     } | ||||
| 
 | ||||
|     public int getBegin() { | ||||
|         return begin; | ||||
|     } | ||||
| 
 | ||||
|     public void setBegin(int begin) { | ||||
|         this.begin = begin; | ||||
|     } | ||||
| 
 | ||||
|     public int getEnd() { | ||||
|         return end; | ||||
|     } | ||||
| 
 | ||||
|     public void setEnd(int end) { | ||||
|         this.end = end; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,188 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.movingwindow; | ||||
| 
 | ||||
| 
 | ||||
| import org.apache.commons.lang3.StringUtils; | ||||
| import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer; | ||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; | ||||
| import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; | ||||
| 
 | ||||
| import java.io.InputStream; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| import java.util.StringTokenizer; | ||||
| 
 | ||||
| public class Windows { | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Constructs a list of window of size windowSize. | ||||
|      * Note that padding for each window is created as well. | ||||
|      * @param words the words to tokenize and construct windows from | ||||
|      * @param windowSize the window size to generate | ||||
|      * @return the list of windows for the tokenized string | ||||
|      */ | ||||
|     public static List<Window> windows(InputStream words, int windowSize) { | ||||
|         Tokenizer tokenizer = new DefaultStreamTokenizer(words); | ||||
|         List<String> list = new ArrayList<>(); | ||||
|         while (tokenizer.hasMoreTokens()) | ||||
|             list.add(tokenizer.nextToken()); | ||||
|         return windows(list, windowSize); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Constructs a list of window of size windowSize. | ||||
|      * Note that padding for each window is created as well. | ||||
|      * @param words the words to tokenize and construct windows from | ||||
|      * @param tokenizerFactory tokenizer factory to use | ||||
|      * @param windowSize the window size to generate | ||||
|      * @return the list of windows for the tokenized string | ||||
|      */ | ||||
|     public static List<Window> windows(InputStream words, TokenizerFactory tokenizerFactory, int windowSize) { | ||||
|         Tokenizer tokenizer = tokenizerFactory.create(words); | ||||
|         List<String> list = new ArrayList<>(); | ||||
|         while (tokenizer.hasMoreTokens()) | ||||
|             list.add(tokenizer.nextToken()); | ||||
| 
 | ||||
|         if (list.isEmpty()) | ||||
|             throw new IllegalStateException("No tokens found for windows"); | ||||
| 
 | ||||
|         return windows(list, windowSize); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Constructs a list of window of size windowSize. | ||||
|      * Note that padding for each window is created as well. | ||||
|      * @param words the words to tokenize and construct windows from | ||||
|      * @param windowSize the window size to generate | ||||
|      * @return the list of windows for the tokenized string | ||||
|      */ | ||||
|     public static List<Window> windows(String words, int windowSize) { | ||||
|         StringTokenizer tokenizer = new StringTokenizer(words); | ||||
|         List<String> list = new ArrayList<String>(); | ||||
|         while (tokenizer.hasMoreTokens()) | ||||
|             list.add(tokenizer.nextToken()); | ||||
|         return windows(list, windowSize); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Constructs a list of window of size windowSize. | ||||
|      * Note that padding for each window is created as well. | ||||
|      * @param words the words to tokenize and construct windows from | ||||
|      * @param tokenizerFactory tokenizer factory to use | ||||
|      * @param windowSize the window size to generate | ||||
|      * @return the list of windows for the tokenized string | ||||
|      */ | ||||
|     public static List<Window> windows(String words, TokenizerFactory tokenizerFactory, int windowSize) { | ||||
|         Tokenizer tokenizer = tokenizerFactory.create(words); | ||||
|         List<String> list = new ArrayList<>(); | ||||
|         while (tokenizer.hasMoreTokens()) | ||||
|             list.add(tokenizer.nextToken()); | ||||
| 
 | ||||
|         if (list.isEmpty()) | ||||
|             throw new IllegalStateException("No tokens found for windows"); | ||||
| 
 | ||||
|         return windows(list, windowSize); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Constructs a list of window of size windowSize. | ||||
|      * Note that padding for each window is created as well. | ||||
|      * @param words the words to tokenize and construct windows from | ||||
|      * @return the list of windows for the tokenized string | ||||
|      */ | ||||
|     public static List<Window> windows(String words) { | ||||
|         StringTokenizer tokenizer = new StringTokenizer(words); | ||||
|         List<String> list = new ArrayList<String>(); | ||||
|         while (tokenizer.hasMoreTokens()) | ||||
|             list.add(tokenizer.nextToken()); | ||||
|         return windows(list, 5); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Constructs a list of window of size windowSize. | ||||
|      * Note that padding for each window is created as well. | ||||
|      * @param words the words to tokenize and construct windows from | ||||
|      * @param tokenizerFactory tokenizer factory to use | ||||
|      * @return the list of windows for the tokenized string | ||||
|      */ | ||||
|     public static List<Window> windows(String words, TokenizerFactory tokenizerFactory) { | ||||
|         Tokenizer tokenizer = tokenizerFactory.create(words); | ||||
|         List<String> list = new ArrayList<>(); | ||||
|         while (tokenizer.hasMoreTokens()) | ||||
|             list.add(tokenizer.nextToken()); | ||||
|         return windows(list, 5); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Creates a sliding window from text | ||||
|      * @param windowSize the window size to use | ||||
|      * @param wordPos the position of the word to center | ||||
|      * @param sentence the sentence to createComplex a window for | ||||
|      * @return a window based on the given sentence | ||||
|      */ | ||||
|     public static Window windowForWordInPosition(int windowSize, int wordPos, List<String> sentence) { | ||||
|         List<String> window = new ArrayList<>(); | ||||
|         List<String> onlyTokens = new ArrayList<>(); | ||||
|         int contextSize = (int) Math.floor((windowSize - 1) / 2); | ||||
| 
 | ||||
|         for (int i = wordPos - contextSize; i <= wordPos + contextSize; i++) { | ||||
|             if (i < 0) | ||||
|                 window.add("<s>"); | ||||
|             else if (i >= sentence.size()) | ||||
|                 window.add("</s>"); | ||||
|             else { | ||||
|                 onlyTokens.add(sentence.get(i)); | ||||
|                 window.add(sentence.get(i)); | ||||
| 
 | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         String wholeSentence = StringUtils.join(sentence); | ||||
|         String window2 = StringUtils.join(onlyTokens); | ||||
|         int begin = wholeSentence.indexOf(window2); | ||||
|         int end = begin + window2.length(); | ||||
|         return new Window(window, begin, end); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Constructs a list of window of size windowSize | ||||
|      * @param words the words to  construct windows from | ||||
|      * @return the list of windows for the tokenized string | ||||
|      */ | ||||
|     public static List<Window> windows(List<String> words, int windowSize) { | ||||
| 
 | ||||
|         List<Window> ret = new ArrayList<>(); | ||||
| 
 | ||||
|         for (int i = 0; i < words.size(); i++) | ||||
|             ret.add(windowForWordInPosition(windowSize, i, words)); | ||||
| 
 | ||||
| 
 | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,189 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.reader; | ||||
| 
 | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.records.Record; | ||||
| import org.datavec.api.records.metadata.RecordMetaData; | ||||
| import org.datavec.api.records.metadata.RecordMetaDataURI; | ||||
| import org.datavec.api.records.reader.impl.FileRecordReader; | ||||
| import org.datavec.api.split.InputSplit; | ||||
| import org.datavec.api.vector.Vectorizer; | ||||
| import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.nlp.vectorizer.TfidfVectorizer; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| import java.util.*; | ||||
| 
 | ||||
| public class TfidfRecordReader extends FileRecordReader { | ||||
|     private TfidfVectorizer tfidfVectorizer; | ||||
|     private List<Record> records = new ArrayList<>(); | ||||
|     private Iterator<Record> recordIter; | ||||
|     private int numFeatures; | ||||
|     private boolean initialized = false; | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public void initialize(InputSplit split) throws IOException, InterruptedException { | ||||
|         initialize(new Configuration(), split); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException { | ||||
|         super.initialize(conf, split); | ||||
|         //train  a new one since it hasn't been specified | ||||
|         if (tfidfVectorizer == null) { | ||||
|             tfidfVectorizer = new TfidfVectorizer(); | ||||
|             tfidfVectorizer.initialize(conf); | ||||
| 
 | ||||
|             //clear out old strings | ||||
|             records.clear(); | ||||
| 
 | ||||
|             INDArray ret = tfidfVectorizer.fitTransform(this, new Vectorizer.RecordCallBack() { | ||||
|                 @Override | ||||
|                 public void onRecord(Record fullRecord) { | ||||
|                     records.add(fullRecord); | ||||
|                 } | ||||
|             }); | ||||
| 
 | ||||
|             //cache the number of features used for each document | ||||
|             numFeatures = ret.columns(); | ||||
|             recordIter = records.iterator(); | ||||
|         } else { | ||||
|             records = new ArrayList<>(); | ||||
| 
 | ||||
|             //the record reader has 2 phases, we are skipping the | ||||
|             //document frequency phase and just using the super() to get the file contents | ||||
|             //and pass it to the already existing vectorizer. | ||||
|             while (super.hasNext()) { | ||||
|                 Record fileContents = super.nextRecord(); | ||||
|                 INDArray transform = tfidfVectorizer.transform(fileContents); | ||||
| 
 | ||||
|                 org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record( | ||||
|                                 new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))), | ||||
|                                 new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class)); | ||||
| 
 | ||||
|                 if (appendLabel) | ||||
|                     record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1)); | ||||
| 
 | ||||
|                 records.add(record); | ||||
|             } | ||||
| 
 | ||||
|             recordIter = records.iterator(); | ||||
|         } | ||||
| 
 | ||||
|         this.initialized = true; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void reset() { | ||||
|         if (inputSplit == null) | ||||
|             throw new UnsupportedOperationException("Cannot reset without first initializing"); | ||||
|         recordIter = records.iterator(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Record nextRecord() { | ||||
|         if (recordIter == null) | ||||
|             return super.nextRecord(); | ||||
|         return recordIter.next(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<Writable> next() { | ||||
|         return nextRecord().getRecord(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public boolean hasNext() { | ||||
|         //we aren't done vectorizing yet | ||||
|         if (recordIter == null) | ||||
|             return super.hasNext(); | ||||
|         return recordIter.hasNext(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void close() throws IOException { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setConf(Configuration conf) { | ||||
|         this.conf = conf; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Configuration getConf() { | ||||
|         return conf; | ||||
|     } | ||||
| 
 | ||||
|     public TfidfVectorizer getTfidfVectorizer() { | ||||
|         return tfidfVectorizer; | ||||
|     } | ||||
| 
 | ||||
|     public void setTfidfVectorizer(TfidfVectorizer tfidfVectorizer) { | ||||
|         if (initialized) { | ||||
|             throw new IllegalArgumentException( | ||||
|                             "Setting TfidfVectorizer after TfidfRecordReader initialization doesn't have an effect"); | ||||
|         } | ||||
|         this.tfidfVectorizer = tfidfVectorizer; | ||||
|     } | ||||
| 
 | ||||
|     public int getNumFeatures() { | ||||
|         return numFeatures; | ||||
|     } | ||||
| 
 | ||||
|     public void shuffle() { | ||||
|         this.shuffle(new Random()); | ||||
|     } | ||||
| 
 | ||||
|     public void shuffle(Random random) { | ||||
|         Collections.shuffle(this.records, random); | ||||
|         this.reset(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException { | ||||
|         return loadFromMetaData(Collections.singletonList(recordMetaData)).get(0); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException { | ||||
|         List<Record> out = new ArrayList<>(); | ||||
| 
 | ||||
|         for (Record fileContents : super.loadFromMetaData(recordMetaDatas)) { | ||||
|             INDArray transform = tfidfVectorizer.transform(fileContents); | ||||
| 
 | ||||
|             org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record( | ||||
|                             new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))), | ||||
|                             new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class)); | ||||
| 
 | ||||
|             if (appendLabel) | ||||
|                 record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1)); | ||||
|             out.add(record); | ||||
|         } | ||||
| 
 | ||||
|         return out; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| @ -1,44 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.stopwords; | ||||
| 
 | ||||
| import org.apache.commons.io.IOUtils; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class StopWords { | ||||
| 
 | ||||
|     private static List<String> stopWords; | ||||
| 
 | ||||
|     @SuppressWarnings("unchecked") | ||||
|     public static List<String> getStopWords() { | ||||
| 
 | ||||
|         try { | ||||
|             if (stopWords == null) | ||||
|                 stopWords = IOUtils.readLines(StopWords.class.getResourceAsStream("/stopwords")); | ||||
|         } catch (IOException e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|         return stopWords; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,125 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizer; | ||||
| 
 | ||||
| import opennlp.tools.tokenize.TokenizerME; | ||||
| import opennlp.tools.tokenize.TokenizerModel; | ||||
| import opennlp.tools.util.Span; | ||||
| import opennlp.uima.tokenize.AbstractTokenizer; | ||||
| import opennlp.uima.tokenize.TokenizerModelResource; | ||||
| import opennlp.uima.util.AnnotatorUtil; | ||||
| import opennlp.uima.util.UimaUtil; | ||||
| import org.apache.uima.UimaContext; | ||||
| import org.apache.uima.analysis_engine.AnalysisEngineProcessException; | ||||
| import org.apache.uima.cas.CAS; | ||||
| import org.apache.uima.cas.Feature; | ||||
| import org.apache.uima.cas.TypeSystem; | ||||
| import org.apache.uima.cas.text.AnnotationFS; | ||||
| import org.apache.uima.resource.ResourceAccessException; | ||||
| import org.apache.uima.resource.ResourceInitializationException; | ||||
| 
 | ||||
| public class ConcurrentTokenizer extends AbstractTokenizer { | ||||
| 
 | ||||
|     /** | ||||
|      * The OpenNLP tokenizer. | ||||
|      */ | ||||
|     private TokenizerME tokenizer; | ||||
| 
 | ||||
|     private Feature probabilityFeature; | ||||
| 
 | ||||
|     @Override | ||||
|     public synchronized void process(CAS cas) throws AnalysisEngineProcessException { | ||||
|         super.process(cas); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|        * Initializes a new instance. | ||||
|        * | ||||
|        * Note: Use {@link #initialize(UimaContext) } to initialize  | ||||
|        * this instance. Not use the constructor. | ||||
|        */ | ||||
|     public ConcurrentTokenizer() { | ||||
|         super("OpenNLP Tokenizer"); | ||||
| 
 | ||||
|         // must not be implemented ! | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Initializes the current instance with the given context. | ||||
|      *  | ||||
|      * Note: Do all initialization in this method, do not use the constructor. | ||||
|      */ | ||||
|     public void initialize(UimaContext context) throws ResourceInitializationException { | ||||
| 
 | ||||
|         super.initialize(context); | ||||
| 
 | ||||
|         TokenizerModel model; | ||||
| 
 | ||||
|         try { | ||||
|             TokenizerModelResource modelResource = | ||||
|                             (TokenizerModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER); | ||||
| 
 | ||||
|             model = modelResource.getModel(); | ||||
|         } catch (ResourceAccessException e) { | ||||
|             throw new ResourceInitializationException(e); | ||||
|         } | ||||
| 
 | ||||
|         tokenizer = new TokenizerME(model); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Initializes the type system. | ||||
|      */ | ||||
|     public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException { | ||||
| 
 | ||||
|         super.typeSystemInit(typeSystem); | ||||
| 
 | ||||
|         probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(context, tokenType, | ||||
|                         UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     protected Span[] tokenize(CAS cas, AnnotationFS sentence) { | ||||
|         return tokenizer.tokenizePos(sentence.getCoveredText()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     protected void postProcessAnnotations(Span[] tokens, AnnotationFS[] tokenAnnotations) { | ||||
|         // if interest | ||||
|         if (probabilityFeature != null) { | ||||
|             double tokenProbabilties[] = tokenizer.getTokenProbabilities(); | ||||
| 
 | ||||
|             for (int i = 0; i < tokenAnnotations.length; i++) { | ||||
|                 tokenAnnotations[i].setDoubleValue(probabilityFeature, tokenProbabilties[i]); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Releases allocated resources. | ||||
|      */ | ||||
|     public void destroy() { | ||||
|         // dereference model to allow garbage collection  | ||||
|         tokenizer = null; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| @ -1,107 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizer; | ||||
| 
 | ||||
| 
 | ||||
| import java.io.*; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| /** | ||||
|  * Tokenizer based on the {@link java.io.StreamTokenizer} | ||||
|  * @author Adam Gibson | ||||
|  * | ||||
|  */ | ||||
| public class DefaultStreamTokenizer implements Tokenizer { | ||||
| 
 | ||||
|     private StreamTokenizer streamTokenizer; | ||||
|     private TokenPreProcess tokenPreProcess; | ||||
| 
 | ||||
| 
 | ||||
|     public DefaultStreamTokenizer(InputStream is) { | ||||
|         Reader r = new BufferedReader(new InputStreamReader(is)); | ||||
|         streamTokenizer = new StreamTokenizer(r); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public boolean hasMoreTokens() { | ||||
|         if (streamTokenizer.ttype != StreamTokenizer.TT_EOF) { | ||||
|             try { | ||||
|                 streamTokenizer.nextToken(); | ||||
|             } catch (IOException e1) { | ||||
|                 throw new RuntimeException(e1); | ||||
|             } | ||||
|         } | ||||
|         return streamTokenizer.ttype != StreamTokenizer.TT_EOF && streamTokenizer.ttype != -1; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int countTokens() { | ||||
|         return getTokens().size(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String nextToken() { | ||||
|         StringBuilder sb = new StringBuilder(); | ||||
| 
 | ||||
| 
 | ||||
|         if (streamTokenizer.ttype == StreamTokenizer.TT_WORD) { | ||||
|             sb.append(streamTokenizer.sval); | ||||
|         } else if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) { | ||||
|             sb.append(streamTokenizer.nval); | ||||
|         } else if (streamTokenizer.ttype == StreamTokenizer.TT_EOL) { | ||||
|             try { | ||||
|                 while (streamTokenizer.ttype == StreamTokenizer.TT_EOL) | ||||
|                     streamTokenizer.nextToken(); | ||||
|             } catch (IOException e) { | ||||
|                 throw new RuntimeException(e); | ||||
| 
 | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         else if (hasMoreTokens()) | ||||
|             return nextToken(); | ||||
| 
 | ||||
| 
 | ||||
|         String ret = sb.toString(); | ||||
| 
 | ||||
|         if (tokenPreProcess != null) | ||||
|             ret = tokenPreProcess.preProcess(ret); | ||||
|         return ret; | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<String> getTokens() { | ||||
|         List<String> tokens = new ArrayList<>(); | ||||
|         while (hasMoreTokens()) { | ||||
|             tokens.add(nextToken()); | ||||
|         } | ||||
|         return tokens; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) { | ||||
|         this.tokenPreProcess = tokenPreProcessor; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,75 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizer; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| import java.util.StringTokenizer; | ||||
| 
 | ||||
| /** | ||||
|  * Default tokenizer | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class DefaultTokenizer implements Tokenizer { | ||||
| 
 | ||||
|     public DefaultTokenizer(String tokens) { | ||||
|         tokenizer = new StringTokenizer(tokens); | ||||
|     } | ||||
| 
 | ||||
|     private StringTokenizer tokenizer; | ||||
|     private TokenPreProcess tokenPreProcess; | ||||
| 
 | ||||
|     @Override | ||||
|     public boolean hasMoreTokens() { | ||||
|         return tokenizer.hasMoreTokens(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int countTokens() { | ||||
|         return tokenizer.countTokens(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String nextToken() { | ||||
|         String base = tokenizer.nextToken(); | ||||
|         if (tokenPreProcess != null) | ||||
|             base = tokenPreProcess.preProcess(base); | ||||
|         return base; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<String> getTokens() { | ||||
|         List<String> tokens = new ArrayList<>(); | ||||
|         while (hasMoreTokens()) { | ||||
|             tokens.add(nextToken()); | ||||
|         } | ||||
|         return tokens; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) { | ||||
|         this.tokenPreProcess = tokenPreProcessor; | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,136 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizer; | ||||
| 
 | ||||
| import org.apache.uima.analysis_engine.AnalysisEngine; | ||||
| import org.apache.uima.cas.CAS; | ||||
| import org.apache.uima.fit.factory.AnalysisEngineFactory; | ||||
| import org.apache.uima.fit.util.JCasUtil; | ||||
| import org.cleartk.token.type.Sentence; | ||||
| import org.cleartk.token.type.Token; | ||||
| import org.datavec.nlp.annotator.PoStagger; | ||||
| import org.datavec.nlp.annotator.SentenceAnnotator; | ||||
| import org.datavec.nlp.annotator.StemmerAnnotator; | ||||
| import org.datavec.nlp.annotator.TokenizerAnnotator; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Collection; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class PosUimaTokenizer implements Tokenizer { | ||||
| 
 | ||||
|     private static AnalysisEngine engine; | ||||
|     private List<String> tokens; | ||||
|     private Collection<String> allowedPosTags; | ||||
|     private int index; | ||||
|     private static CAS cas; | ||||
| 
 | ||||
|     public PosUimaTokenizer(String tokens, AnalysisEngine engine, Collection<String> allowedPosTags) { | ||||
|         if (engine == null) | ||||
|             PosUimaTokenizer.engine = engine; | ||||
|         this.allowedPosTags = allowedPosTags; | ||||
|         this.tokens = new ArrayList<>(); | ||||
|         try { | ||||
|             if (cas == null) | ||||
|                 cas = engine.newCAS(); | ||||
| 
 | ||||
|             cas.reset(); | ||||
|             cas.setDocumentText(tokens); | ||||
|             PosUimaTokenizer.engine.process(cas); | ||||
|             for (Sentence s : JCasUtil.select(cas.getJCas(), Sentence.class)) { | ||||
|                 for (Token t : JCasUtil.selectCovered(Token.class, s)) { | ||||
|                     //add NONE for each invalid token | ||||
|                     if (valid(t)) | ||||
|                         if (t.getLemma() != null) | ||||
|                             this.tokens.add(t.getLemma()); | ||||
|                         else if (t.getStem() != null) | ||||
|                             this.tokens.add(t.getStem()); | ||||
|                         else | ||||
|                             this.tokens.add(t.getCoveredText()); | ||||
|                     else | ||||
|                         this.tokens.add("NONE"); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|         } catch (Exception e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     private boolean valid(Token token) { | ||||
|         String check = token.getCoveredText(); | ||||
|         if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>")) | ||||
|             return false; | ||||
|         else if (token.getPos() != null && !this.allowedPosTags.contains(token.getPos())) | ||||
|             return false; | ||||
|         return true; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public boolean hasMoreTokens() { | ||||
|         return index < tokens.size(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int countTokens() { | ||||
|         return tokens.size(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String nextToken() { | ||||
|         String ret = tokens.get(index); | ||||
|         index++; | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<String> getTokens() { | ||||
|         List<String> tokens = new ArrayList<String>(); | ||||
|         while (hasMoreTokens()) { | ||||
|             tokens.add(nextToken()); | ||||
|         } | ||||
|         return tokens; | ||||
|     } | ||||
| 
 | ||||
|     public static AnalysisEngine defaultAnalysisEngine() { | ||||
|         try { | ||||
|             return AnalysisEngineFactory.createEngine(AnalysisEngineFactory.createEngineDescription( | ||||
|                             SentenceAnnotator.getDescription(), TokenizerAnnotator.getDescription(), | ||||
|                             PoStagger.getDescription("en"), StemmerAnnotator.getDescription("English"))); | ||||
|         } catch (Exception e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) { | ||||
|         // TODO Auto-generated method stub | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,34 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizer; | ||||
| 
 | ||||
| 
 | ||||
| public interface TokenPreProcess { | ||||
| 
 | ||||
|     /** | ||||
|      * Pre process a token | ||||
|      * @param token the token to pre process | ||||
|      * @return the preprocessed token | ||||
|      */ | ||||
|     String preProcess(String token); | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,61 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizer; | ||||
| 
 | ||||
| import java.util.List; | ||||
| 
 | ||||
| public interface Tokenizer { | ||||
| 
 | ||||
|     /** | ||||
|      * An iterator for tracking whether | ||||
|      * more tokens are left in the iterator not | ||||
|      * @return whether there is anymore tokens | ||||
|      * to iterate over | ||||
|      */ | ||||
|     boolean hasMoreTokens(); | ||||
| 
 | ||||
|     /** | ||||
|      * The number of tokens in the tokenizer | ||||
|      * @return the number of tokens | ||||
|      */ | ||||
|     int countTokens(); | ||||
| 
 | ||||
|     /** | ||||
|      * The next token (word usually) in the string | ||||
|      * @return the next token in the string if any | ||||
|      */ | ||||
|     String nextToken(); | ||||
| 
 | ||||
|     /** | ||||
|      * Returns a list of all the tokens | ||||
|      * @return a list of all the tokens | ||||
|      */ | ||||
|     List<String> getTokens(); | ||||
| 
 | ||||
|     /** | ||||
|      * Set the token pre process | ||||
|      * @param tokenPreProcessor the token pre processor to set | ||||
|      */ | ||||
|     void setTokenPreProcessor(TokenPreProcess tokenPreProcessor); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,123 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizer; | ||||
| 
 | ||||
| import org.apache.uima.cas.CAS; | ||||
| import org.apache.uima.fit.util.JCasUtil; | ||||
| import org.cleartk.token.type.Token; | ||||
| import org.datavec.nlp.uima.UimaResource; | ||||
| import org.slf4j.Logger; | ||||
| import org.slf4j.LoggerFactory; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Collection; | ||||
| import java.util.List; | ||||
| 
 | ||||
| /** | ||||
|  * Tokenizer based on the passed in analysis engine | ||||
|  * @author Adam Gibson | ||||
|  * | ||||
|  */ | ||||
| public class UimaTokenizer implements Tokenizer { | ||||
| 
 | ||||
|     private List<String> tokens; | ||||
|     private int index; | ||||
|     private static Logger log = LoggerFactory.getLogger(UimaTokenizer.class); | ||||
|     private boolean checkForLabel; | ||||
|     private TokenPreProcess tokenPreProcessor; | ||||
| 
 | ||||
| 
 | ||||
|     public UimaTokenizer(String tokens, UimaResource resource, boolean checkForLabel) { | ||||
| 
 | ||||
|         this.checkForLabel = checkForLabel; | ||||
|         this.tokens = new ArrayList<>(); | ||||
|         try { | ||||
|             CAS cas = resource.process(tokens); | ||||
| 
 | ||||
|             Collection<Token> tokenList = JCasUtil.select(cas.getJCas(), Token.class); | ||||
| 
 | ||||
|             for (Token t : tokenList) { | ||||
| 
 | ||||
|                 if (!checkForLabel || valid(t.getCoveredText())) | ||||
|                     if (t.getLemma() != null) | ||||
|                         this.tokens.add(t.getLemma()); | ||||
|                     else if (t.getStem() != null) | ||||
|                         this.tokens.add(t.getStem()); | ||||
|                     else | ||||
|                         this.tokens.add(t.getCoveredText()); | ||||
|             } | ||||
| 
 | ||||
| 
 | ||||
|             resource.release(cas); | ||||
| 
 | ||||
| 
 | ||||
|         } catch (Exception e) { | ||||
|             log.error("",e); | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     private boolean valid(String check) { | ||||
|         if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>")) | ||||
|             return false; | ||||
|         return true; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public boolean hasMoreTokens() { | ||||
|         return index < tokens.size(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int countTokens() { | ||||
|         return tokens.size(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String nextToken() { | ||||
|         String ret = tokens.get(index); | ||||
|         index++; | ||||
|         if (tokenPreProcessor != null) { | ||||
|             ret = tokenPreProcessor.preProcess(ret); | ||||
|         } | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<String> getTokens() { | ||||
|         List<String> tokens = new ArrayList<>(); | ||||
|         while (hasMoreTokens()) { | ||||
|             tokens.add(nextToken()); | ||||
|         } | ||||
|         return tokens; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) { | ||||
|         this.tokenPreProcessor = tokenPreProcessor; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,47 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizer.preprocessor; | ||||
| 
 | ||||
| 
 | ||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; | ||||
| 
 | ||||
| /** | ||||
|  * Gets rid of endings: | ||||
|  * | ||||
|  *    ed,ing, ly, s, . | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class EndingPreProcessor implements TokenPreProcess { | ||||
|     @Override | ||||
|     public String preProcess(String token) { | ||||
|         if (token.endsWith("s") && !token.endsWith("ss")) | ||||
|             token = token.substring(0, token.length() - 1); | ||||
|         if (token.endsWith(".")) | ||||
|             token = token.substring(0, token.length() - 1); | ||||
|         if (token.endsWith("ed")) | ||||
|             token = token.substring(0, token.length() - 2); | ||||
|         if (token.endsWith("ing")) | ||||
|             token = token.substring(0, token.length() - 3); | ||||
|         if (token.endsWith("ly")) | ||||
|             token = token.substring(0, token.length() - 2); | ||||
|         return token; | ||||
|     } | ||||
| } | ||||
| @ -1,30 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizer.preprocessor; | ||||
| 
 | ||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; | ||||
| 
 | ||||
| public class LowerCasePreProcessor implements TokenPreProcess { | ||||
|     @Override | ||||
|     public String preProcess(String token) { | ||||
|         return token.toLowerCase(); | ||||
|     } | ||||
| } | ||||
| @ -1,60 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizerfactory; | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer; | ||||
| import org.datavec.nlp.tokenization.tokenizer.DefaultTokenizer; | ||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; | ||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; | ||||
| 
 | ||||
| import java.io.InputStream; | ||||
| 
 | ||||
| /** | ||||
|  * Default tokenizer based on string tokenizer or stream tokenizer | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class DefaultTokenizerFactory implements TokenizerFactory { | ||||
| 
 | ||||
|     private TokenPreProcess tokenPreProcess; | ||||
| 
 | ||||
|     @Override | ||||
|     public Tokenizer create(String toTokenize) { | ||||
|         DefaultTokenizer t = new DefaultTokenizer(toTokenize); | ||||
|         t.setTokenPreProcessor(tokenPreProcess); | ||||
|         return t; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Tokenizer create(InputStream toTokenize) { | ||||
|         Tokenizer t = new DefaultStreamTokenizer(toTokenize); | ||||
|         t.setTokenPreProcessor(tokenPreProcess); | ||||
|         return t; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setTokenPreProcessor(TokenPreProcess preProcessor) { | ||||
|         this.tokenPreProcess = preProcessor; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,85 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizerfactory; | ||||
| 
 | ||||
| 
 | ||||
| import org.apache.uima.analysis_engine.AnalysisEngine; | ||||
| import org.datavec.nlp.annotator.PoStagger; | ||||
| import org.datavec.nlp.annotator.SentenceAnnotator; | ||||
| import org.datavec.nlp.annotator.StemmerAnnotator; | ||||
| import org.datavec.nlp.annotator.TokenizerAnnotator; | ||||
| import org.datavec.nlp.tokenization.tokenizer.PosUimaTokenizer; | ||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; | ||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; | ||||
| 
 | ||||
| import java.io.InputStream; | ||||
| import java.util.Collection; | ||||
| 
 | ||||
| import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngine; | ||||
| import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription; | ||||
| 
 | ||||
| public class PosUimaTokenizerFactory implements TokenizerFactory { | ||||
| 
 | ||||
|     private AnalysisEngine tokenizer; | ||||
|     private Collection<String> allowedPoSTags; | ||||
|     private TokenPreProcess tokenPreProcess; | ||||
| 
 | ||||
| 
 | ||||
|     public PosUimaTokenizerFactory(Collection<String> allowedPoSTags) { | ||||
|         this(defaultAnalysisEngine(), allowedPoSTags); | ||||
|     } | ||||
| 
 | ||||
|     public PosUimaTokenizerFactory(AnalysisEngine tokenizer, Collection<String> allowedPosTags) { | ||||
|         this.tokenizer = tokenizer; | ||||
|         this.allowedPoSTags = allowedPosTags; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public static AnalysisEngine defaultAnalysisEngine() { | ||||
|         try { | ||||
|             return createEngine(createEngineDescription(SentenceAnnotator.getDescription(), | ||||
|                             TokenizerAnnotator.getDescription(), PoStagger.getDescription("en"), | ||||
|                             StemmerAnnotator.getDescription("English"))); | ||||
|         } catch (Exception e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public Tokenizer create(String toTokenize) { | ||||
|         PosUimaTokenizer t = new PosUimaTokenizer(toTokenize, tokenizer, allowedPoSTags); | ||||
|         t.setTokenPreProcessor(tokenPreProcess); | ||||
|         return t; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Tokenizer create(InputStream toTokenize) { | ||||
|         throw new UnsupportedOperationException(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setTokenPreProcessor(TokenPreProcess preProcessor) { | ||||
|         this.tokenPreProcess = preProcessor; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,62 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizerfactory; | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; | ||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; | ||||
| import org.nd4j.shade.jackson.annotation.JsonTypeInfo; | ||||
| 
 | ||||
| import java.io.InputStream; | ||||
| 
 | ||||
| /** | ||||
|  * Generates a tokenizer for a given string | ||||
|  * @author Adam Gibson | ||||
|  * | ||||
|  */ | ||||
| @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") | ||||
| public interface TokenizerFactory { | ||||
| 
 | ||||
|     /** | ||||
|      * The tokenizer to createComplex | ||||
|      * @param toTokenize the string to createComplex the tokenizer with | ||||
|      * @return the new tokenizer | ||||
|      */ | ||||
|     Tokenizer create(String toTokenize); | ||||
| 
 | ||||
|     /** | ||||
|      * Create a tokenizer based on an input stream | ||||
|      * @param toTokenize | ||||
|      * @return | ||||
|      */ | ||||
|     Tokenizer create(InputStream toTokenize); | ||||
| 
 | ||||
|     /** | ||||
|      * Sets a token pre processor to be used | ||||
|      * with every tokenizer | ||||
|      * @param preProcessor the token pre processor to use | ||||
|      */ | ||||
|     void setTokenPreProcessor(TokenPreProcess preProcessor); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,138 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.tokenization.tokenizerfactory; | ||||
| 
 | ||||
| import org.apache.uima.analysis_engine.AnalysisEngine; | ||||
| import org.apache.uima.fit.factory.AnalysisEngineFactory; | ||||
| import org.apache.uima.resource.ResourceInitializationException; | ||||
| import org.datavec.nlp.annotator.SentenceAnnotator; | ||||
| import org.datavec.nlp.annotator.TokenizerAnnotator; | ||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; | ||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; | ||||
| import org.datavec.nlp.tokenization.tokenizer.UimaTokenizer; | ||||
| import org.datavec.nlp.uima.UimaResource; | ||||
| 
 | ||||
| import java.io.InputStream; | ||||
| 
 | ||||
| 
 | ||||
| /** | ||||
|  * Uses a uima {@link AnalysisEngine} to  | ||||
|  * tokenize text. | ||||
|  * | ||||
|  * | ||||
|  * @author Adam Gibson | ||||
|  * | ||||
|  */ | ||||
| public class UimaTokenizerFactory implements TokenizerFactory { | ||||
| 
 | ||||
| 
 | ||||
|     private UimaResource uimaResource; | ||||
|     private boolean checkForLabel; | ||||
|     private static AnalysisEngine defaultAnalysisEngine; | ||||
|     private TokenPreProcess preProcess; | ||||
| 
 | ||||
|     public UimaTokenizerFactory() throws ResourceInitializationException { | ||||
|         this(defaultAnalysisEngine(), true); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public UimaTokenizerFactory(UimaResource resource) { | ||||
|         this(resource, true); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public UimaTokenizerFactory(AnalysisEngine tokenizer) { | ||||
|         this(tokenizer, true); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     public UimaTokenizerFactory(UimaResource resource, boolean checkForLabel) { | ||||
|         this.uimaResource = resource; | ||||
|         this.checkForLabel = checkForLabel; | ||||
|     } | ||||
| 
 | ||||
|     public UimaTokenizerFactory(boolean checkForLabel) throws ResourceInitializationException { | ||||
|         this(defaultAnalysisEngine(), checkForLabel); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     public UimaTokenizerFactory(AnalysisEngine tokenizer, boolean checkForLabel) { | ||||
|         super(); | ||||
|         this.checkForLabel = checkForLabel; | ||||
|         try { | ||||
|             this.uimaResource = new UimaResource(tokenizer); | ||||
| 
 | ||||
| 
 | ||||
|         } catch (Exception e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public Tokenizer create(String toTokenize) { | ||||
|         if (toTokenize == null || toTokenize.isEmpty()) | ||||
|             throw new IllegalArgumentException("Unable to proceed; on sentence to tokenize"); | ||||
|         Tokenizer ret = new UimaTokenizer(toTokenize, uimaResource, checkForLabel); | ||||
|         ret.setTokenPreProcessor(preProcess); | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public UimaResource getUimaResource() { | ||||
|         return uimaResource; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Creates a tokenization,/stemming pipeline | ||||
|      * @return a tokenization/stemming pipeline | ||||
|      */ | ||||
|     public static AnalysisEngine defaultAnalysisEngine() { | ||||
|         try { | ||||
|             if (defaultAnalysisEngine == null) | ||||
| 
 | ||||
|                 defaultAnalysisEngine = AnalysisEngineFactory.createEngine( | ||||
|                                 AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.getDescription(), | ||||
|                                                 TokenizerAnnotator.getDescription())); | ||||
| 
 | ||||
|             return defaultAnalysisEngine; | ||||
|         } catch (Exception e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public Tokenizer create(InputStream toTokenize) { | ||||
|         throw new UnsupportedOperationException(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setTokenPreProcessor(TokenPreProcess preProcessor) { | ||||
|         this.preProcess = preProcessor; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,69 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.transforms; | ||||
| 
 | ||||
| import org.datavec.api.transform.Transform; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.shade.jackson.annotation.JsonTypeInfo; | ||||
| 
 | ||||
| import java.util.List; | ||||
| 
 | ||||
| @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") | ||||
| public interface BagOfWordsTransform extends Transform { | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * The output shape of the transform (usually 1 x number of words) | ||||
|      * @return | ||||
|      */ | ||||
|     long[] outputShape(); | ||||
| 
 | ||||
|     /** | ||||
|      * The vocab words in the transform. | ||||
|      * This is the words that were accumulated | ||||
|      * when building a vocabulary. | ||||
|      * (This is generally associated with some form of | ||||
|      * mininmum words frequency scanning to build a vocab | ||||
|      * you then map on to a list of vocab words as a list) | ||||
|      * @return the vocab words for the transform | ||||
|      */ | ||||
|     List<String> vocabWords(); | ||||
| 
 | ||||
|     /** | ||||
|      * Transform for a list of tokens | ||||
|      * that are objects. This is to allow loose | ||||
|      * typing for tokens that are unique (non string) | ||||
|      * @param tokens the token objects to transform | ||||
|      * @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array) | ||||
|      */ | ||||
|     INDArray transformFromObject(List<List<Object>> tokens); | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Transform for a list of tokens | ||||
|      * that are {@link Writable} (Generally {@link org.datavec.api.writable.Text} | ||||
|      * @param tokens the token objects to transform | ||||
|      * @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array) | ||||
|      */ | ||||
|     INDArray transformFrom(List<List<Writable>> tokens); | ||||
| 
 | ||||
| } | ||||
| @ -1,24 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.transforms; | ||||
| 
 | ||||
| public class BaseWordMapTransform { | ||||
| } | ||||
| @ -1,146 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.transforms; | ||||
| 
 | ||||
| import lombok.Data; | ||||
| import lombok.EqualsAndHashCode; | ||||
| import org.datavec.api.transform.metadata.ColumnMetaData; | ||||
| import org.datavec.api.transform.metadata.NDArrayMetaData; | ||||
| import org.datavec.api.transform.transform.BaseColumnTransform; | ||||
| import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.shade.jackson.annotation.JsonCreator; | ||||
| import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; | ||||
| import org.nd4j.shade.jackson.annotation.JsonInclude; | ||||
| import org.nd4j.shade.jackson.annotation.JsonProperty; | ||||
| 
 | ||||
| import java.util.Collections; | ||||
| import java.util.HashSet; | ||||
| import java.util.List; | ||||
| import java.util.Set; | ||||
| 
 | ||||
| @Data | ||||
| @EqualsAndHashCode(callSuper = true) | ||||
| @JsonInclude(JsonInclude.Include.NON_NULL) | ||||
| @JsonIgnoreProperties({"gazeteer"}) | ||||
| public class GazeteerTransform extends BaseColumnTransform implements BagOfWordsTransform { | ||||
| 
 | ||||
|     private String newColumnName; | ||||
|     private List<String> wordList; | ||||
|     private Set<String> gazeteer; | ||||
| 
 | ||||
|     @JsonCreator | ||||
|     public GazeteerTransform(@JsonProperty("columnName") String columnName, | ||||
|                              @JsonProperty("newColumnName")String newColumnName, | ||||
|                              @JsonProperty("wordList") List<String> wordList) { | ||||
|         super(columnName); | ||||
|         this.newColumnName = newColumnName; | ||||
|         this.wordList = wordList; | ||||
|         this.gazeteer = new HashSet<>(wordList); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) { | ||||
|         return new NDArrayMetaData(newName,new long[]{wordList.size()}); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Writable map(Writable columnWritable) { | ||||
|        throw new UnsupportedOperationException(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Object mapSequence(Object sequence) { | ||||
|         List<List<Object>> sequenceInput = (List<List<Object>>) sequence; | ||||
|         INDArray ret = Nd4j.create(DataType.FLOAT, wordList.size()); | ||||
| 
 | ||||
|         for(List<Object> list : sequenceInput) { | ||||
|             for(Object token : list) { | ||||
|                 String s = token.toString(); | ||||
|                 if(gazeteer.contains(s)) { | ||||
|                     ret.putScalar(wordList.indexOf(s),1); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public List<List<Writable>> mapSequence(List<List<Writable>> sequence) { | ||||
|         INDArray arr = (INDArray) mapSequence((Object) sequence); | ||||
|         return Collections.singletonList(Collections.<Writable>singletonList(new NDArrayWritable(arr))); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String toString() { | ||||
|         return newColumnName; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Object map(Object input) { | ||||
|         return gazeteer.contains(input.toString()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String outputColumnName() { | ||||
|         return newColumnName; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String[] outputColumnNames() { | ||||
|         return new String[]{newColumnName}; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String[] columnNames() { | ||||
|         return new String[]{columnName()}; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String columnName() { | ||||
|         return columnName; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public long[] outputShape() { | ||||
|         return new long[]{wordList.size()}; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<String> vocabWords() { | ||||
|         return wordList; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray transformFromObject(List<List<Object>> tokens) { | ||||
|         return (INDArray) mapSequence(tokens); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray transformFrom(List<List<Writable>> tokens) { | ||||
|         return (INDArray) mapSequence((Object) tokens); | ||||
|     } | ||||
| } | ||||
| @ -1,150 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| 
 | ||||
| package org.datavec.nlp.transforms; | ||||
| 
 | ||||
| import org.datavec.api.transform.metadata.ColumnMetaData; | ||||
| import org.datavec.api.transform.metadata.NDArrayMetaData; | ||||
| import org.datavec.api.transform.transform.BaseColumnTransform; | ||||
| import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.list.NDArrayList; | ||||
| import org.nd4j.shade.jackson.annotation.JsonCreator; | ||||
| import org.nd4j.shade.jackson.annotation.JsonProperty; | ||||
| 
 | ||||
| import java.util.Collections; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class MultiNlpTransform extends BaseColumnTransform implements BagOfWordsTransform { | ||||
| 
 | ||||
|     private BagOfWordsTransform[] transforms; | ||||
|     private String newColumnName; | ||||
|     private List<String> vocabWords; | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param columnName | ||||
|      * @param transforms | ||||
|      * @param newColumnName | ||||
|      */ | ||||
|     @JsonCreator | ||||
|     public MultiNlpTransform(@JsonProperty("columnName") String columnName, | ||||
|                              @JsonProperty("transforms") BagOfWordsTransform[] transforms, | ||||
|                              @JsonProperty("newColumnName") String newColumnName) { | ||||
|         super(columnName); | ||||
|         this.transforms = transforms; | ||||
|         this.vocabWords = transforms[0].vocabWords(); | ||||
|         if(transforms.length > 1) { | ||||
|             for(int i = 1; i < transforms.length; i++) { | ||||
|                 if(!transforms[i].vocabWords().equals(vocabWords)) { | ||||
|                     throw new IllegalArgumentException("Vocab words not consistent across transforms!"); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         this.newColumnName = newColumnName; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Object mapSequence(Object sequence) { | ||||
|         NDArrayList ndArrayList = new NDArrayList(); | ||||
|         for(BagOfWordsTransform bagofWordsTransform : transforms) { | ||||
|             ndArrayList.addAll(new NDArrayList(bagofWordsTransform.transformFromObject((List<List<Object>>) sequence))); | ||||
|         } | ||||
| 
 | ||||
|         return ndArrayList.array(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<List<Writable>> mapSequence(List<List<Writable>> sequence) { | ||||
|      return Collections.singletonList(Collections.<Writable>singletonList(new NDArrayWritable(transformFrom(sequence)))); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) { | ||||
|         return new NDArrayMetaData(newName,outputShape()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Writable map(Writable columnWritable) { | ||||
|         throw new UnsupportedOperationException("Only able to add for time series"); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String toString() { | ||||
|         return newColumnName; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Object map(Object input) { | ||||
|         throw new UnsupportedOperationException("Only able to add for time series"); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public long[] outputShape() { | ||||
|         long[] ret = new long[transforms[0].outputShape().length]; | ||||
|         int validatedRank = transforms[0].outputShape().length; | ||||
|         for(int i = 1; i < transforms.length; i++) { | ||||
|             if(transforms[i].outputShape().length != validatedRank) { | ||||
|                 throw new IllegalArgumentException("Inconsistent shape length at transform " + i + " , should have been: " + validatedRank); | ||||
|             } | ||||
|         } | ||||
|         for(int i = 0; i < transforms.length; i++) { | ||||
|             for(int j = 0; j < validatedRank; j++) | ||||
|             ret[j] += transforms[i].outputShape()[j]; | ||||
|         } | ||||
| 
 | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<String> vocabWords() { | ||||
|         return vocabWords; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray transformFromObject(List<List<Object>> tokens) { | ||||
|         NDArrayList ndArrayList = new NDArrayList(); | ||||
|         for(BagOfWordsTransform bagofWordsTransform : transforms) { | ||||
|             INDArray arr2 = bagofWordsTransform.transformFromObject(tokens); | ||||
|             arr2 = arr2.reshape(arr2.length()); | ||||
|             NDArrayList newList = new NDArrayList(arr2,(int) arr2.length()); | ||||
|             ndArrayList.addAll(newList);        } | ||||
| 
 | ||||
|         return ndArrayList.array(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray transformFrom(List<List<Writable>> tokens) { | ||||
|         NDArrayList ndArrayList = new NDArrayList(); | ||||
|         for(BagOfWordsTransform bagofWordsTransform : transforms) { | ||||
|             INDArray arr2 = bagofWordsTransform.transformFrom(tokens); | ||||
|             arr2 = arr2.reshape(arr2.length()); | ||||
|             NDArrayList newList = new NDArrayList(arr2,(int) arr2.length()); | ||||
|             ndArrayList.addAll(newList); | ||||
|         } | ||||
| 
 | ||||
|         return ndArrayList.array(); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,226 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.transforms; | ||||
| 
 | ||||
| import lombok.Data; | ||||
| import lombok.EqualsAndHashCode; | ||||
| import org.datavec.api.transform.metadata.ColumnMetaData; | ||||
| import org.datavec.api.transform.metadata.NDArrayMetaData; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.transform.transform.BaseColumnTransform; | ||||
| import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; | ||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; | ||||
| import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory; | ||||
| import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.primitives.Counter; | ||||
| import org.nd4j.common.util.MathUtils; | ||||
| import org.nd4j.shade.jackson.annotation.JsonCreator; | ||||
| import org.nd4j.shade.jackson.annotation.JsonInclude; | ||||
| import org.nd4j.shade.jackson.annotation.JsonProperty; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| 
 | ||||
| @Data | ||||
| @EqualsAndHashCode(callSuper = true, exclude = {"tokenizerFactory"}) | ||||
| @JsonInclude(JsonInclude.Include.NON_NULL) | ||||
| public class TokenizerBagOfWordsTermSequenceIndexTransform extends BaseColumnTransform { | ||||
| 
 | ||||
|     private String newColumName; | ||||
|     private  Map<String,Integer> wordIndexMap; | ||||
|     private Map<String,Double> weightMap; | ||||
|     private boolean exceptionOnUnknown; | ||||
|     private String tokenizerFactoryClass; | ||||
|     private String preprocessorClass; | ||||
|     private TokenizerFactory tokenizerFactory; | ||||
| 
 | ||||
|     @JsonCreator | ||||
|     public TokenizerBagOfWordsTermSequenceIndexTransform(@JsonProperty("columnName") String columnName, | ||||
|                                                          @JsonProperty("newColumnName") String newColumnName, | ||||
|                                                          @JsonProperty("wordIndexMap") Map<String,Integer> wordIndexMap, | ||||
|                                                          @JsonProperty("idfMap") Map<String,Double> idfMap, | ||||
|                                                          @JsonProperty("exceptionOnUnknown") boolean exceptionOnUnknown, | ||||
|                                                          @JsonProperty("tokenizerFactoryClass") String tokenizerFactoryClass, | ||||
|                                                          @JsonProperty("preprocessorClass") String preprocessorClass) { | ||||
|         super(columnName); | ||||
|         this.newColumName = newColumnName; | ||||
|         this.wordIndexMap = wordIndexMap; | ||||
|         this.exceptionOnUnknown = exceptionOnUnknown; | ||||
|         this.weightMap = idfMap; | ||||
|         this.tokenizerFactoryClass = tokenizerFactoryClass; | ||||
|         this.preprocessorClass = preprocessorClass; | ||||
|         if(this.tokenizerFactoryClass == null) { | ||||
|             this.tokenizerFactoryClass = DefaultTokenizerFactory.class.getName(); | ||||
|         } | ||||
|         try { | ||||
|             tokenizerFactory = (TokenizerFactory) Class.forName(this.tokenizerFactoryClass).newInstance(); | ||||
|         } catch (Exception e) { | ||||
|             throw new IllegalStateException("Unable to instantiate tokenizer factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?"); | ||||
|         } | ||||
| 
 | ||||
|         if(preprocessorClass != null){ | ||||
|             try { | ||||
|                 TokenPreProcess tpp = (TokenPreProcess) Class.forName(this.preprocessorClass).newInstance(); | ||||
|                 tokenizerFactory.setTokenPreProcessor(tpp); | ||||
|             } catch (Exception e){ | ||||
|                 throw new IllegalStateException("Unable to instantiate preprocessor factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?"); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public List<Writable> map(List<Writable> writables) { | ||||
|         Text text = (Text) writables.get(inputSchema.getIndexOfColumn(columnName)); | ||||
|         List<Writable> ret = new ArrayList<>(writables); | ||||
|         ret.set(inputSchema.getIndexOfColumn(columnName),new NDArrayWritable(convert(text.toString()))); | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Object map(Object input) { | ||||
|         return convert(input.toString()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Object mapSequence(Object sequence) { | ||||
|         return convert(sequence.toString()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Schema transform(Schema inputSchema) { | ||||
|         Schema.Builder newSchema = new Schema.Builder(); | ||||
|         for(int i = 0; i < inputSchema.numColumns(); i++) { | ||||
|             if(inputSchema.getName(i).equals(this.columnName)) { | ||||
|                 newSchema.addColumnNDArray(newColumName,new long[]{1,wordIndexMap.size()}); | ||||
|             } | ||||
|             else { | ||||
|                 newSchema.addColumn(inputSchema.getMetaData(i)); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return newSchema.build(); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Convert the given text | ||||
|      * in to an {@link INDArray} | ||||
|      * using the {@link TokenizerFactory} | ||||
|      * specified in the constructor. | ||||
|      * @param text the text to transform | ||||
|      * @return the created {@link INDArray} | ||||
|      * based on the {@link #wordIndexMap} for the column indices | ||||
|      * of the word. | ||||
|      */ | ||||
|     public INDArray convert(String text) { | ||||
|         Tokenizer tokenizer = tokenizerFactory.create(text); | ||||
|         List<String> tokens = tokenizer.getTokens(); | ||||
|         INDArray create = Nd4j.create(1,wordIndexMap.size()); | ||||
|         Counter<String> tokenizedCounter = new Counter<>(); | ||||
| 
 | ||||
|         for(int i = 0; i < tokens.size(); i++) { | ||||
|             tokenizedCounter.incrementCount(tokens.get(i),1.0); | ||||
|         } | ||||
| 
 | ||||
|         for(int i = 0; i < tokens.size(); i++) { | ||||
|             if(wordIndexMap.containsKey(tokens.get(i))) { | ||||
|                 int idx = wordIndexMap.get(tokens.get(i)); | ||||
|                 int count = (int) tokenizedCounter.getCount(tokens.get(i)); | ||||
|                 double weight = tfidfWord(tokens.get(i),count,tokens.size()); | ||||
|                 create.putScalar(idx,weight); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return create; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Calculate the tifdf for a word | ||||
|      * given the word, word count, and document length | ||||
|      * @param word the word to calculate | ||||
|      * @param wordCount the word frequency | ||||
|      * @param documentLength the number of words in the document | ||||
|      * @return the tfidf weight for a given word | ||||
|      */ | ||||
|     public double tfidfWord(String word, long wordCount, long documentLength) { | ||||
|         double tf = tfForWord(wordCount, documentLength); | ||||
|         double idf = idfForWord(word); | ||||
|         return MathUtils.tfidf(tf, idf); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Calculate the weight term frequency for a given | ||||
|      * word normalized by the dcoument length | ||||
|      * @param wordCount the word frequency | ||||
|      * @param documentLength the number of words in the edocument | ||||
|      * @return | ||||
|      */ | ||||
|     private double tfForWord(long wordCount, long documentLength) { | ||||
|         return wordCount; | ||||
|     } | ||||
| 
 | ||||
|     private double idfForWord(String word) { | ||||
|         if(weightMap.containsKey(word)) | ||||
|             return weightMap.get(word); | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) { | ||||
|         return new NDArrayMetaData(outputColumnName(),new long[]{1,wordIndexMap.size()}); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String outputColumnName() { | ||||
|         return newColumName; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String[] outputColumnNames() { | ||||
|         return new String[]{newColumName}; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String[] columnNames() { | ||||
|         return new String[]{columnName()}; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String columnName() { | ||||
|         return columnName; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Writable map(Writable columnWritable) { | ||||
|         return new NDArrayWritable(convert(columnWritable.toString())); | ||||
|     } | ||||
| } | ||||
| @ -1,107 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.uima; | ||||
| 
 | ||||
| import org.apache.uima.analysis_engine.AnalysisEngine; | ||||
| import org.apache.uima.analysis_engine.AnalysisEngineProcessException; | ||||
| import org.apache.uima.cas.CAS; | ||||
| import org.apache.uima.resource.ResourceInitializationException; | ||||
| import org.apache.uima.util.CasPool; | ||||
| 
 | ||||
| public class UimaResource { | ||||
| 
 | ||||
|     private AnalysisEngine analysisEngine; | ||||
|     private CasPool casPool; | ||||
| 
 | ||||
|     public UimaResource(AnalysisEngine analysisEngine) throws ResourceInitializationException { | ||||
|         this.analysisEngine = analysisEngine; | ||||
|         this.casPool = new CasPool(Runtime.getRuntime().availableProcessors() * 10, analysisEngine); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     public UimaResource(AnalysisEngine analysisEngine, CasPool casPool) { | ||||
|         this.analysisEngine = analysisEngine; | ||||
|         this.casPool = casPool; | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public AnalysisEngine getAnalysisEngine() { | ||||
|         return analysisEngine; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public void setAnalysisEngine(AnalysisEngine analysisEngine) { | ||||
|         this.analysisEngine = analysisEngine; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public CasPool getCasPool() { | ||||
|         return casPool; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public void setCasPool(CasPool casPool) { | ||||
|         this.casPool = casPool; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Use the given analysis engine and process the given text | ||||
|      * You must release the return cas yourself | ||||
|      * @param text the text to rpocess | ||||
|      * @return the processed cas | ||||
|      */ | ||||
|     public CAS process(String text) { | ||||
|         CAS cas = retrieve(); | ||||
| 
 | ||||
|         cas.setDocumentText(text); | ||||
|         try { | ||||
|             analysisEngine.process(cas); | ||||
|         } catch (AnalysisEngineProcessException e) { | ||||
|             if (text != null && !text.isEmpty()) | ||||
|                 return process(text); | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
| 
 | ||||
|         return cas; | ||||
| 
 | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public CAS retrieve() { | ||||
|         CAS ret = casPool.getCas(); | ||||
|         try { | ||||
|             return ret == null ? analysisEngine.newCAS() : ret; | ||||
|         } catch (ResourceInitializationException e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public void release(CAS cas) { | ||||
|         casPool.releaseCas(cas); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,77 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.vectorizer; | ||||
| 
 | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.records.Record; | ||||
| import org.datavec.api.records.reader.RecordReader; | ||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; | ||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; | ||||
| import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory; | ||||
| import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; | ||||
| 
 | ||||
| import java.util.HashSet; | ||||
| import java.util.Set; | ||||
| 
 | ||||
| public abstract class AbstractTfidfVectorizer<VECTOR_TYPE> extends TextVectorizer<VECTOR_TYPE> { | ||||
| 
 | ||||
|     @Override | ||||
|     public void doWithTokens(Tokenizer tokenizer) { | ||||
|         Set<String> seen = new HashSet<>(); | ||||
|         while (tokenizer.hasMoreTokens()) { | ||||
|             String token = tokenizer.nextToken(); | ||||
|             if (!stopWords.contains(token)) { | ||||
|                 cache.incrementCount(token); | ||||
|                 if (!seen.contains(token)) { | ||||
|                     cache.incrementDocCount(token); | ||||
|                 } | ||||
|                 seen.add(token); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public TokenizerFactory createTokenizerFactory(Configuration conf) { | ||||
|         String clazz = conf.get(TOKENIZER, DefaultTokenizerFactory.class.getName()); | ||||
|         try { | ||||
|             Class<? extends TokenizerFactory> tokenizerFactoryClazz = | ||||
|                             (Class<? extends TokenizerFactory>) Class.forName(clazz); | ||||
|             TokenizerFactory tf = tokenizerFactoryClazz.newInstance(); | ||||
|             String preproc = conf.get(PREPROCESSOR, null); | ||||
|             if(preproc != null){ | ||||
|                 TokenPreProcess tpp = (TokenPreProcess) Class.forName(preproc).newInstance(); | ||||
|                 tf.setTokenPreProcessor(tpp); | ||||
|             } | ||||
|             return tf; | ||||
|         } catch (Exception e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public abstract VECTOR_TYPE createVector(Object[] args); | ||||
| 
 | ||||
|     @Override | ||||
|     public abstract VECTOR_TYPE fitTransform(RecordReader reader); | ||||
| 
 | ||||
|     @Override | ||||
|     public abstract VECTOR_TYPE transform(Record record); | ||||
| } | ||||
| @ -1,121 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.vectorizer; | ||||
| 
 | ||||
| import lombok.Getter; | ||||
| import org.nd4j.common.primitives.Counter; | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.records.Record; | ||||
| import org.datavec.api.records.reader.RecordReader; | ||||
| import org.datavec.api.vector.Vectorizer; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.nlp.metadata.DefaultVocabCache; | ||||
| import org.datavec.nlp.metadata.VocabCache; | ||||
| import org.datavec.nlp.stopwords.StopWords; | ||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; | ||||
| import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; | ||||
| 
 | ||||
| import java.util.Collection; | ||||
| 
 | ||||
| public abstract class TextVectorizer<VECTOR_TYPE> implements Vectorizer<VECTOR_TYPE> { | ||||
| 
 | ||||
|     protected TokenizerFactory tokenizerFactory; | ||||
|     protected int minWordFrequency = 0; | ||||
|     public final static String MIN_WORD_FREQUENCY = "org.nd4j.nlp.minwordfrequency"; | ||||
|     public final static String STOP_WORDS = "org.nd4j.nlp.stopwords"; | ||||
|     public final static String TOKENIZER = "org.datavec.nlp.tokenizerfactory"; | ||||
|     public static final String PREPROCESSOR = "org.datavec.nlp.preprocessor"; | ||||
|     public final static String VOCAB_CACHE = "org.datavec.nlp.vocabcache"; | ||||
|     protected Collection<String> stopWords; | ||||
|     @Getter | ||||
|     protected VocabCache cache; | ||||
| 
 | ||||
|     @Override | ||||
|     public void initialize(Configuration conf) { | ||||
|         tokenizerFactory = createTokenizerFactory(conf); | ||||
|         minWordFrequency = conf.getInt(MIN_WORD_FREQUENCY, 5); | ||||
|         if(conf.get(STOP_WORDS) != null) | ||||
|             stopWords = conf.getStringCollection(STOP_WORDS); | ||||
|         if (stopWords == null) | ||||
|             stopWords = StopWords.getStopWords(); | ||||
| 
 | ||||
|         String clazz = conf.get(VOCAB_CACHE, DefaultVocabCache.class.getName()); | ||||
|         try { | ||||
|             Class<? extends VocabCache> tokenizerFactoryClazz = (Class<? extends VocabCache>) Class.forName(clazz); | ||||
|             cache = tokenizerFactoryClazz.newInstance(); | ||||
|             cache.initialize(conf); | ||||
|         } catch (Exception e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void fit(RecordReader reader) { | ||||
|         fit(reader, null); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void fit(RecordReader reader, RecordCallBack callBack) { | ||||
|         while (reader.hasNext()) { | ||||
|             Record record = reader.nextRecord(); | ||||
|             String s = toString(record.getRecord()); | ||||
|             Tokenizer tokenizer = tokenizerFactory.create(s); | ||||
|             doWithTokens(tokenizer); | ||||
|             if (callBack != null) | ||||
|                 callBack.onRecord(record); | ||||
|             cache.incrementNumDocs(1); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     protected Counter<String> wordFrequenciesForRecord(Collection<Writable> record) { | ||||
|         String s = toString(record); | ||||
|         Tokenizer tokenizer = tokenizerFactory.create(s); | ||||
|         Counter<String> ret = new Counter<>(); | ||||
|         while (tokenizer.hasMoreTokens()) | ||||
|             ret.incrementCount(tokenizer.nextToken(), 1.0); | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     protected String toString(Collection<Writable> record) { | ||||
|         StringBuilder sb = new StringBuilder(); | ||||
|         for(Writable w : record){ | ||||
|             sb.append(w.toString()); | ||||
|         } | ||||
|         return sb.toString(); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Increment counts, add to collection,... | ||||
|      * @param tokenizer | ||||
|      */ | ||||
|     public abstract void doWithTokens(Tokenizer tokenizer); | ||||
| 
 | ||||
|     /** | ||||
|      * Create tokenizer factory based on the configuration | ||||
|      * @param conf the configuration to use | ||||
|      * @return the tokenizer factory based on the configuration | ||||
|      */ | ||||
|     public abstract TokenizerFactory createTokenizerFactory(Configuration conf); | ||||
| 
 | ||||
| } | ||||
| @ -1,105 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.vectorizer; | ||||
| 
 | ||||
| 
 | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.nd4j.common.primitives.Counter; | ||||
| import org.datavec.api.records.Record; | ||||
| import org.datavec.api.records.metadata.RecordMetaDataURI; | ||||
| import org.datavec.api.records.reader.RecordReader; | ||||
| import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class TfidfVectorizer extends AbstractTfidfVectorizer<INDArray> { | ||||
|     /** | ||||
|      * Default: True.<br> | ||||
|      * If true: use idf(d, t) = log [ (1 + n) / (1 + df(d, t)) ] + 1<br> | ||||
|      * If false: use idf(t) = log [ n / df(t) ] + 1<br> | ||||
|      */ | ||||
|     public static final String SMOOTH_IDF = "org.datavec.nlp.TfidfVectorizer.smooth_idf"; | ||||
| 
 | ||||
|     protected boolean smooth_idf; | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray createVector(Object[] args) { | ||||
|         Counter<String> docFrequencies = (Counter<String>) args[0]; | ||||
|         double[] vector = new double[cache.vocabWords().size()]; | ||||
|         for (int i = 0; i < cache.vocabWords().size(); i++) { | ||||
|             String word = cache.wordAt(i); | ||||
|             double freq = docFrequencies.getCount(word); | ||||
|             vector[i] = cache.tfidf(word, freq, smooth_idf); | ||||
|         } | ||||
|         return Nd4j.create(vector); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray fitTransform(RecordReader reader) { | ||||
|         return fitTransform(reader, null); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray fitTransform(final RecordReader reader, RecordCallBack callBack) { | ||||
|         final List<Record> records = new ArrayList<>(); | ||||
|         fit(reader, new RecordCallBack() { | ||||
|             @Override | ||||
|             public void onRecord(Record record) { | ||||
|                 records.add(record); | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         if (records.isEmpty()) | ||||
|             throw new IllegalStateException("No records found!"); | ||||
|         INDArray ret = Nd4j.create(records.size(), cache.vocabWords().size()); | ||||
|         int i = 0; | ||||
|         for (Record record : records) { | ||||
|             INDArray transformed = transform(record); | ||||
|             org.datavec.api.records.impl.Record transformedRecord = new org.datavec.api.records.impl.Record( | ||||
|                             Arrays.asList(new NDArrayWritable(transformed), | ||||
|                                             record.getRecord().get(record.getRecord().size() - 1)), | ||||
|                             new RecordMetaDataURI(record.getMetaData().getURI(), reader.getClass())); | ||||
|             ret.putRow(i++, transformed); | ||||
|             if (callBack != null) { | ||||
|                 callBack.onRecord(transformedRecord); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray transform(Record record) { | ||||
|         Counter<String> wordFrequencies = wordFrequenciesForRecord(record.getRecord()); | ||||
|         return createVector(new Object[] {wordFrequencies}); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public void initialize(Configuration conf){ | ||||
|         super.initialize(conf); | ||||
|         this.smooth_idf = conf.getBoolean(SMOOTH_IDF, true); | ||||
|     } | ||||
| } | ||||
| @ -1,194 +0,0 @@ | ||||
| a | ||||
| ----s | ||||
| act | ||||
| "the | ||||
| "The | ||||
| about | ||||
| above | ||||
| after | ||||
| again | ||||
| against | ||||
| all | ||||
| am | ||||
| an | ||||
| and | ||||
| any | ||||
| are | ||||
| aren't | ||||
| as | ||||
| at | ||||
| be | ||||
| because | ||||
| been | ||||
| before | ||||
| being | ||||
| below | ||||
| between | ||||
| both | ||||
| but | ||||
| by | ||||
| can't | ||||
| cannot | ||||
| could | ||||
| couldn't | ||||
| did | ||||
| didn't | ||||
| do | ||||
| does | ||||
| doesn't | ||||
| doing | ||||
| don't | ||||
| down | ||||
| during | ||||
| each | ||||
| few | ||||
| for | ||||
| from | ||||
| further | ||||
| had | ||||
| hadn't | ||||
| has | ||||
| hasn't | ||||
| have | ||||
| haven't | ||||
| having | ||||
| he | ||||
| he'd | ||||
| he'll | ||||
| he's | ||||
| her | ||||
| here | ||||
| here's | ||||
| hers | ||||
| herself | ||||
| him | ||||
| himself | ||||
| his | ||||
| how | ||||
| how's | ||||
| i | ||||
| i'd | ||||
| i'll | ||||
| i'm | ||||
| i've | ||||
| if | ||||
| in | ||||
| into | ||||
| is | ||||
| isn't | ||||
| it | ||||
| it's | ||||
| its | ||||
| itself | ||||
| let's | ||||
| me | ||||
| more | ||||
| most | ||||
| mustn't | ||||
| my | ||||
| myself | ||||
| no | ||||
| nor | ||||
| not | ||||
| of | ||||
| off | ||||
| on | ||||
| once | ||||
| only | ||||
| or | ||||
| other | ||||
| ought | ||||
| our | ||||
| ours  | ||||
| ourselves | ||||
| out | ||||
| over | ||||
| own | ||||
| put | ||||
| same | ||||
| shan't | ||||
| she | ||||
| she'd | ||||
| she'll | ||||
| she's | ||||
| should | ||||
| somebody | ||||
| something | ||||
| shouldn't | ||||
| so | ||||
| some | ||||
| such | ||||
| take | ||||
| than | ||||
| that | ||||
| that's | ||||
| the | ||||
| their | ||||
| theirs | ||||
| them | ||||
| themselves | ||||
| then | ||||
| there | ||||
| there's | ||||
| these | ||||
| they | ||||
| they'd | ||||
| they'll | ||||
| they're | ||||
| they've | ||||
| this | ||||
| those | ||||
| through | ||||
| to | ||||
| too | ||||
| under | ||||
| until | ||||
| up | ||||
| very | ||||
| was | ||||
| wasn't | ||||
| we | ||||
| we'd | ||||
| we'll | ||||
| we're | ||||
| we've | ||||
| were | ||||
| weren't | ||||
| what | ||||
| what's | ||||
| when | ||||
| when's | ||||
| where | ||||
| where's | ||||
| which | ||||
| while | ||||
| who | ||||
| who's | ||||
| whom | ||||
| why | ||||
| why's | ||||
| will | ||||
| with | ||||
| without | ||||
| won't | ||||
| would | ||||
| wouldn't | ||||
| you | ||||
| you'd | ||||
| you'll | ||||
| you're | ||||
| you've | ||||
| your | ||||
| yours | ||||
| yourself | ||||
| yourselves | ||||
| . | ||||
| ? | ||||
| ! | ||||
| , | ||||
| + | ||||
| = | ||||
| also | ||||
| - | ||||
| ; | ||||
| : | ||||
| @ -1,46 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| package org.datavec.nlp; | ||||
| 
 | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.nd4j.common.tests.AbstractAssertTestsClass; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { | ||||
| 
 | ||||
| 	@Override | ||||
| 	protected Set<Class<?>> getExclusions() { | ||||
| 		//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) | ||||
| 		return new HashSet<>(); | ||||
| 	} | ||||
| 
 | ||||
| 	@Override | ||||
| 	protected String getPackageName() { | ||||
| 		return "org.datavec.nlp"; | ||||
| 	} | ||||
| 
 | ||||
| 	@Override | ||||
| 	protected Class<?> getBaseClass() { | ||||
| 		return BaseND4JTest.class; | ||||
| 	} | ||||
| } | ||||
| @ -1,130 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.reader; | ||||
| 
 | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.records.Record; | ||||
| import org.datavec.api.records.reader.RecordReader; | ||||
| import org.datavec.api.split.CollectionInputSplit; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.nlp.vectorizer.TfidfVectorizer; | ||||
| import org.junit.Rule; | ||||
| import org.junit.Test; | ||||
| import org.junit.rules.TemporaryFolder; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.net.URI; | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.junit.Assert.*; | ||||
| 
 | ||||
| /** | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class TfidfRecordReaderTest { | ||||
| 
 | ||||
|     @Rule | ||||
|     public TemporaryFolder testDir = new TemporaryFolder(); | ||||
| 
 | ||||
|     @Test | ||||
|     public void testReader() throws Exception { | ||||
|         TfidfVectorizer vectorizer = new TfidfVectorizer(); | ||||
|         Configuration conf = new Configuration(); | ||||
|         conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1); | ||||
|         conf.setBoolean(RecordReader.APPEND_LABEL, true); | ||||
|         vectorizer.initialize(conf); | ||||
|         TfidfRecordReader reader = new TfidfRecordReader(); | ||||
|         File f = testDir.newFolder(); | ||||
|         new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f); | ||||
|         List<URI> u = new ArrayList<>(); | ||||
|         for(File f2 : f.listFiles()){ | ||||
|             if(f2.isDirectory()){ | ||||
|                 for(File f3 : f2.listFiles()){ | ||||
|                     u.add(f3.toURI()); | ||||
|                 } | ||||
|             } else { | ||||
|                 u.add(f2.toURI()); | ||||
|             } | ||||
|         } | ||||
|         Collections.sort(u); | ||||
|         CollectionInputSplit c = new CollectionInputSplit(u); | ||||
|         reader.initialize(conf, c); | ||||
|         int count = 0; | ||||
|         int[] labelAssertions = new int[3]; | ||||
|         while (reader.hasNext()) { | ||||
|             Collection<Writable> record = reader.next(); | ||||
|             Iterator<Writable> recordIter = record.iterator(); | ||||
|             NDArrayWritable writable = (NDArrayWritable) recordIter.next(); | ||||
|             labelAssertions[count] = recordIter.next().toInt(); | ||||
|             count++; | ||||
|         } | ||||
| 
 | ||||
|         assertArrayEquals(new int[] {0, 1, 2}, labelAssertions); | ||||
|         assertEquals(3, reader.getLabels().size()); | ||||
|         assertEquals(3, count); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testRecordMetaData() throws Exception { | ||||
|         TfidfVectorizer vectorizer = new TfidfVectorizer(); | ||||
|         Configuration conf = new Configuration(); | ||||
|         conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1); | ||||
|         conf.setBoolean(RecordReader.APPEND_LABEL, true); | ||||
|         vectorizer.initialize(conf); | ||||
|         TfidfRecordReader reader = new TfidfRecordReader(); | ||||
|         File f = testDir.newFolder(); | ||||
|         new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f); | ||||
|         reader.initialize(conf, new FileSplit(f)); | ||||
| 
 | ||||
|         while (reader.hasNext()) { | ||||
|             Record record = reader.nextRecord(); | ||||
|             assertNotNull(record.getMetaData().getURI()); | ||||
|             assertEquals(record.getMetaData().getReaderClass(), TfidfRecordReader.class); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     public void testReadRecordFromMetaData() throws Exception { | ||||
|         TfidfVectorizer vectorizer = new TfidfVectorizer(); | ||||
|         Configuration conf = new Configuration(); | ||||
|         conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1); | ||||
|         conf.setBoolean(RecordReader.APPEND_LABEL, true); | ||||
|         vectorizer.initialize(conf); | ||||
|         TfidfRecordReader reader = new TfidfRecordReader(); | ||||
|         File f = testDir.newFolder(); | ||||
|         new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f); | ||||
|         reader.initialize(conf, new FileSplit(f)); | ||||
| 
 | ||||
|         Record record = reader.nextRecord(); | ||||
| 
 | ||||
|         Record reread = reader.loadFromMetaData(record.getMetaData()); | ||||
| 
 | ||||
|         assertEquals(record.getRecord().size(), 2); | ||||
|         assertEquals(reread.getRecord().size(), 2); | ||||
|         assertEquals(record.getRecord().get(0), reread.getRecord().get(0)); | ||||
|         assertEquals(record.getRecord().get(1), reread.getRecord().get(1)); | ||||
|         assertEquals(record.getMetaData(), reread.getMetaData()); | ||||
|     } | ||||
| } | ||||
| @ -1,96 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.transforms; | ||||
| 
 | ||||
| import org.datavec.api.transform.TransformProcess; | ||||
| import org.datavec.api.transform.schema.SequenceSchema; | ||||
| import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.local.transforms.LocalTransformExecutor; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.Collections; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| public class TestGazeteerTransform { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testGazeteerTransform(){ | ||||
| 
 | ||||
|         String[] corpus = { | ||||
|                 "hello I like apple".toLowerCase(), | ||||
|                 "cherry date eggplant potato".toLowerCase() | ||||
|         }; | ||||
| 
 | ||||
|         //Gazeteer transform: basically 0/1 if word is present. Assumes already tokenized input | ||||
|         List<String> words = Arrays.asList("apple", "banana", "cherry", "date", "eggplant"); | ||||
| 
 | ||||
|         GazeteerTransform t = new GazeteerTransform("words", "out", words); | ||||
| 
 | ||||
|         SequenceSchema schema = (SequenceSchema) new SequenceSchema.Builder() | ||||
|                 .addColumnString("words").build(); | ||||
| 
 | ||||
|         TransformProcess tp = new TransformProcess.Builder(schema) | ||||
|                 .transform(t) | ||||
|                 .build(); | ||||
| 
 | ||||
|         List<List<List<Writable>>> input = new ArrayList<>(); | ||||
|         for(String s : corpus){ | ||||
|             String[] split = s.split(" "); | ||||
|             List<List<Writable>> seq = new ArrayList<>(); | ||||
|             for(String s2 : split){ | ||||
|                 seq.add(Collections.<Writable>singletonList(new Text(s2))); | ||||
|             } | ||||
|             input.add(seq); | ||||
|         } | ||||
| 
 | ||||
|         List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, tp); | ||||
| 
 | ||||
|         INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); | ||||
|         INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); | ||||
| 
 | ||||
|         INDArray exp0 = Nd4j.create(new float[]{1, 0, 0, 0, 0}); | ||||
|         INDArray exp1 = Nd4j.create(new float[]{0, 0, 1, 1, 1}); | ||||
| 
 | ||||
|         assertEquals(exp0, arr0); | ||||
|         assertEquals(exp1, arr1); | ||||
| 
 | ||||
| 
 | ||||
|         String json = tp.toJson(); | ||||
|         TransformProcess tp2 = TransformProcess.fromJson(json); | ||||
|         assertEquals(tp, tp2); | ||||
| 
 | ||||
|         List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, tp); | ||||
|         INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get(); | ||||
|         INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get(); | ||||
| 
 | ||||
|         assertEquals(exp0, arr0a); | ||||
|         assertEquals(exp1, arr1a); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,96 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.transforms; | ||||
| 
 | ||||
| import org.datavec.api.transform.TransformProcess; | ||||
| import org.datavec.api.transform.schema.SequenceSchema; | ||||
| import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.local.transforms.LocalTransformExecutor; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| public class TestMultiNLPTransform { | ||||
| 
 | ||||
|     @Test | ||||
|     public void test(){ | ||||
| 
 | ||||
|         List<String> words = Arrays.asList("apple", "banana", "cherry", "date", "eggplant"); | ||||
|         GazeteerTransform t1 = new GazeteerTransform("words", "out", words); | ||||
|         GazeteerTransform t2 = new GazeteerTransform("out", "out", words); | ||||
| 
 | ||||
| 
 | ||||
|         MultiNlpTransform multi = new MultiNlpTransform("text", new BagOfWordsTransform[]{t1, t2}, "out"); | ||||
| 
 | ||||
|         String[] corpus = { | ||||
|                 "hello I like apple".toLowerCase(), | ||||
|                 "date eggplant potato".toLowerCase() | ||||
|         }; | ||||
| 
 | ||||
|         List<List<List<Writable>>> input = new ArrayList<>(); | ||||
|         for(String s : corpus){ | ||||
|             String[] split = s.split(" "); | ||||
|             List<List<Writable>> seq = new ArrayList<>(); | ||||
|             for(String s2 : split){ | ||||
|                 seq.add(Collections.<Writable>singletonList(new Text(s2))); | ||||
|             } | ||||
|             input.add(seq); | ||||
|         } | ||||
| 
 | ||||
|         SequenceSchema schema = (SequenceSchema) new SequenceSchema.Builder() | ||||
|                 .addColumnString("text").build(); | ||||
| 
 | ||||
|         TransformProcess tp = new TransformProcess.Builder(schema) | ||||
|                 .transform(multi) | ||||
|                 .build(); | ||||
| 
 | ||||
|         List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, tp); | ||||
| 
 | ||||
|         INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); | ||||
|         INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); | ||||
| 
 | ||||
|         INDArray exp0 = Nd4j.create(new float[]{1, 0, 0, 0, 0, 1, 0, 0, 0, 0}); | ||||
|         INDArray exp1 = Nd4j.create(new float[]{0, 0, 0, 1, 1, 0, 0, 0, 1, 1}); | ||||
| 
 | ||||
|         assertEquals(exp0, arr0); | ||||
|         assertEquals(exp1, arr1); | ||||
| 
 | ||||
| 
 | ||||
|         String json = tp.toJson(); | ||||
|         TransformProcess tp2 = TransformProcess.fromJson(json); | ||||
|         assertEquals(tp, tp2); | ||||
| 
 | ||||
|         List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, tp); | ||||
|         INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get(); | ||||
|         INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get(); | ||||
| 
 | ||||
|         assertEquals(exp0, arr0a); | ||||
|         assertEquals(exp1, arr1a); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,414 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.nlp.transforms; | ||||
| 
 | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.records.reader.impl.collection.CollectionRecordReader; | ||||
| import org.datavec.api.transform.TransformProcess; | ||||
| import org.datavec.api.transform.schema.SequenceSchema; | ||||
| import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.local.transforms.LocalTransformExecutor; | ||||
| import org.datavec.nlp.metadata.VocabCache; | ||||
| import org.datavec.nlp.tokenization.tokenizer.preprocessor.LowerCasePreProcessor; | ||||
| import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory; | ||||
| import org.datavec.nlp.vectorizer.TfidfVectorizer; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.primitives.Triple; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.datavec.nlp.vectorizer.TextVectorizer.*; | ||||
| import static org.junit.Assert.assertArrayEquals; | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| public class TokenizerBagOfWordsTermSequenceIndexTransformTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSequenceExecution() { | ||||
|         //credit: https://stackoverflow.com/questions/23792781/tf-idf-feature-weights-using-sklearn-feature-extraction-text-tfidfvectorizer | ||||
|         String[] corpus = { | ||||
|                 "This is very strange".toLowerCase(), | ||||
|                 "This is very nice".toLowerCase() | ||||
|         }; | ||||
|         //{'is': 1.0, 'nice': 1.4054651081081644, 'strange': 1.4054651081081644, 'this': 1.0, 'very': 1.0} | ||||
| 
 | ||||
|         /* | ||||
|          ## Reproduce with: | ||||
|          from sklearn.feature_extraction.text import TfidfVectorizer | ||||
|          corpus = ["This is very strange", "This is very nice"] | ||||
| 
 | ||||
|          ## SMOOTH = FALSE case: | ||||
|          vectorizer = TfidfVectorizer(min_df=0, norm=None, smooth_idf=False) | ||||
|          X = vectorizer.fit_transform(corpus) | ||||
|          idf = vectorizer.idf_ | ||||
|          print(dict(zip(vectorizer.get_feature_names(), idf))) | ||||
| 
 | ||||
|          newText = ["This is very strange", "This is very nice"] | ||||
|          out = vectorizer.transform(newText) | ||||
|          print(out) | ||||
| 
 | ||||
|          {'is': 1.0, 'nice': 1.6931471805599454, 'strange': 1.6931471805599454, 'this': 1.0, 'very': 1.0} | ||||
|          (0, 4)	1.0 | ||||
|          (0, 3)	1.0 | ||||
|          (0, 2)	1.6931471805599454 | ||||
|          (0, 0)	1.0 | ||||
|          (1, 4)	1.0 | ||||
|          (1, 3)	1.0 | ||||
|          (1, 1)	1.6931471805599454 | ||||
|          (1, 0)	1.0 | ||||
| 
 | ||||
|          ## SMOOTH + TRUE case: | ||||
|          {'is': 1.0, 'nice': 1.4054651081081644, 'strange': 1.4054651081081644, 'this': 1.0, 'very': 1.0} | ||||
|           (0, 4)	1.0 | ||||
|           (0, 3)	1.0 | ||||
|           (0, 2)	1.4054651081081644 | ||||
|           (0, 0)	1.0 | ||||
|           (1, 4)	1.0 | ||||
|           (1, 3)	1.0 | ||||
|           (1, 1)	1.4054651081081644 | ||||
|           (1, 0)	1.0 | ||||
|          */ | ||||
| 
 | ||||
|         List<List<List<Writable>>> input = new ArrayList<>(); | ||||
|         input.add(Arrays.asList(Arrays.<Writable>asList(new Text(corpus[0])),Arrays.<Writable>asList(new Text(corpus[1])))); | ||||
| 
 | ||||
|         // First: Check TfidfVectorizer vs. scikit: | ||||
| 
 | ||||
|         Map<String,Double> idfMapNoSmooth = new HashMap<>(); | ||||
|         idfMapNoSmooth.put("is",1.0); | ||||
|         idfMapNoSmooth.put("nice",1.6931471805599454); | ||||
|         idfMapNoSmooth.put("strange",1.6931471805599454); | ||||
|         idfMapNoSmooth.put("this",1.0); | ||||
|         idfMapNoSmooth.put("very",1.0); | ||||
| 
 | ||||
|         Map<String,Double> idfMapSmooth = new HashMap<>(); | ||||
|         idfMapSmooth.put("is",1.0); | ||||
|         idfMapSmooth.put("nice",1.4054651081081644); | ||||
|         idfMapSmooth.put("strange",1.4054651081081644); | ||||
|         idfMapSmooth.put("this",1.0); | ||||
|         idfMapSmooth.put("very",1.0); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|         TfidfVectorizer tfidfVectorizer = new TfidfVectorizer(); | ||||
|         Configuration configuration = new Configuration(); | ||||
|         configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName()); | ||||
|         configuration.set(MIN_WORD_FREQUENCY,"1"); | ||||
|         configuration.set(STOP_WORDS,""); | ||||
|         configuration.set(TfidfVectorizer.SMOOTH_IDF, "false"); | ||||
| 
 | ||||
|         tfidfVectorizer.initialize(configuration); | ||||
| 
 | ||||
|         CollectionRecordReader collectionRecordReader = new CollectionRecordReader(input.get(0)); | ||||
|         INDArray array = tfidfVectorizer.fitTransform(collectionRecordReader); | ||||
| 
 | ||||
|         INDArray expNoSmooth = Nd4j.create(DataType.FLOAT, 2, 5); | ||||
|         VocabCache vc = tfidfVectorizer.getCache(); | ||||
|         expNoSmooth.putScalar(0, vc.wordIndex("very"), 1.0); | ||||
|         expNoSmooth.putScalar(0, vc.wordIndex("this"), 1.0); | ||||
|         expNoSmooth.putScalar(0, vc.wordIndex("strange"), 1.6931471805599454); | ||||
|         expNoSmooth.putScalar(0, vc.wordIndex("is"), 1.0); | ||||
| 
 | ||||
|         expNoSmooth.putScalar(1, vc.wordIndex("very"), 1.0); | ||||
|         expNoSmooth.putScalar(1, vc.wordIndex("this"), 1.0); | ||||
|         expNoSmooth.putScalar(1, vc.wordIndex("nice"), 1.6931471805599454); | ||||
|         expNoSmooth.putScalar(1, vc.wordIndex("is"), 1.0); | ||||
| 
 | ||||
|         assertEquals(expNoSmooth, array); | ||||
| 
 | ||||
| 
 | ||||
|         //------------------------------------------------------------ | ||||
|         //Smooth version: | ||||
|         tfidfVectorizer = new TfidfVectorizer(); | ||||
|         configuration = new Configuration(); | ||||
|         configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName()); | ||||
|         configuration.set(MIN_WORD_FREQUENCY,"1"); | ||||
|         configuration.set(STOP_WORDS,""); | ||||
|         configuration.set(TfidfVectorizer.SMOOTH_IDF, "true"); | ||||
| 
 | ||||
|         tfidfVectorizer.initialize(configuration); | ||||
| 
 | ||||
|         collectionRecordReader.reset(); | ||||
|         array = tfidfVectorizer.fitTransform(collectionRecordReader); | ||||
| 
 | ||||
|         INDArray expSmooth = Nd4j.create(DataType.FLOAT, 2, 5); | ||||
|         expSmooth.putScalar(0, vc.wordIndex("very"), 1.0); | ||||
|         expSmooth.putScalar(0, vc.wordIndex("this"), 1.0); | ||||
|         expSmooth.putScalar(0, vc.wordIndex("strange"), 1.4054651081081644); | ||||
|         expSmooth.putScalar(0, vc.wordIndex("is"), 1.0); | ||||
| 
 | ||||
|         expSmooth.putScalar(1, vc.wordIndex("very"), 1.0); | ||||
|         expSmooth.putScalar(1, vc.wordIndex("this"), 1.0); | ||||
|         expSmooth.putScalar(1, vc.wordIndex("nice"), 1.4054651081081644); | ||||
|         expSmooth.putScalar(1, vc.wordIndex("is"), 1.0); | ||||
| 
 | ||||
|         assertEquals(expSmooth, array); | ||||
| 
 | ||||
| 
 | ||||
|         ////////////////////////////////////////////////////////// | ||||
| 
 | ||||
|         //Second: Check transform vs scikit/TfidfVectorizer | ||||
| 
 | ||||
|         List<String> vocab = new ArrayList<>(5);    //Arrays.asList("is","nice","strange","this","very"); | ||||
|         for( int i=0; i<5; i++ ){ | ||||
|             vocab.add(vc.wordAt(i)); | ||||
|         } | ||||
| 
 | ||||
|         String inputColumnName = "input"; | ||||
|         String outputColumnName = "output"; | ||||
|         Map<String,Integer> wordIndexMap = new HashMap<>(); | ||||
|         for(int i = 0; i < vocab.size(); i++) { | ||||
|             wordIndexMap.put(vocab.get(i),i); | ||||
|         } | ||||
| 
 | ||||
|         TokenizerBagOfWordsTermSequenceIndexTransform tokenizerBagOfWordsTermSequenceIndexTransform = new TokenizerBagOfWordsTermSequenceIndexTransform( | ||||
|                 inputColumnName, | ||||
|                 outputColumnName, | ||||
|                 wordIndexMap, | ||||
|                 idfMapNoSmooth, | ||||
|                 false, | ||||
|                 null, null); | ||||
| 
 | ||||
|         SequenceSchema.Builder sequenceSchemaBuilder = new SequenceSchema.Builder(); | ||||
|         sequenceSchemaBuilder.addColumnString("input"); | ||||
|         SequenceSchema schema = sequenceSchemaBuilder.build(); | ||||
|         assertEquals("input",schema.getName(0)); | ||||
| 
 | ||||
|         TransformProcess transformProcess = new TransformProcess.Builder(schema) | ||||
|                 .transform(tokenizerBagOfWordsTermSequenceIndexTransform) | ||||
|                 .build(); | ||||
| 
 | ||||
|         List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|         //System.out.println(execute); | ||||
|         INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); | ||||
|         INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); | ||||
| 
 | ||||
|         assertEquals(expNoSmooth.getRow(0, true), arr0); | ||||
|         assertEquals(expNoSmooth.getRow(1, true), arr1); | ||||
| 
 | ||||
| 
 | ||||
|         //-------------------------------- | ||||
|         //Check smooth: | ||||
| 
 | ||||
|         tokenizerBagOfWordsTermSequenceIndexTransform = new TokenizerBagOfWordsTermSequenceIndexTransform( | ||||
|                 inputColumnName, | ||||
|                 outputColumnName, | ||||
|                 wordIndexMap, | ||||
|                 idfMapSmooth, | ||||
|                 false, | ||||
|                 null, null); | ||||
| 
 | ||||
|         schema = (SequenceSchema) new SequenceSchema.Builder().addColumnString("input").build(); | ||||
| 
 | ||||
|         transformProcess = new TransformProcess.Builder(schema) | ||||
|                 .transform(tokenizerBagOfWordsTermSequenceIndexTransform) | ||||
|                 .build(); | ||||
| 
 | ||||
|         execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess); | ||||
| 
 | ||||
|         arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); | ||||
|         arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); | ||||
| 
 | ||||
|         assertEquals(expSmooth.getRow(0, true), arr0); | ||||
|         assertEquals(expSmooth.getRow(1, true), arr1); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|         //Test JSON serialization: | ||||
| 
 | ||||
|         String json = transformProcess.toJson(); | ||||
|         TransformProcess fromJson = TransformProcess.fromJson(json); | ||||
|         assertEquals(transformProcess, fromJson); | ||||
|         List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, fromJson); | ||||
| 
 | ||||
|         INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get(); | ||||
|         INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get(); | ||||
| 
 | ||||
|         assertEquals(expSmooth.getRow(0, true), arr0a); | ||||
|         assertEquals(expSmooth.getRow(1, true), arr1a); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void additionalTest(){ | ||||
|         /* | ||||
|         ## To reproduce: | ||||
|         from sklearn.feature_extraction.text import TfidfVectorizer | ||||
|         corpus = [ | ||||
|            'This is the first document', | ||||
|            'This document is the second document', | ||||
|            'And this is the third one', | ||||
|            'Is this the first document', | ||||
|         ] | ||||
|         vectorizer = TfidfVectorizer(min_df=0, norm=None, smooth_idf=False) | ||||
|         X = vectorizer.fit_transform(corpus) | ||||
|         print(vectorizer.get_feature_names()) | ||||
| 
 | ||||
|         out = vectorizer.transform(corpus) | ||||
|         print(out) | ||||
| 
 | ||||
|         ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this'] | ||||
|           (0, 8)	1.0 | ||||
|           (0, 6)	1.0 | ||||
|           (0, 3)	1.0 | ||||
|           (0, 2)	1.6931471805599454 | ||||
|           (0, 1)	1.2876820724517808 | ||||
|           (1, 8)	1.0 | ||||
|           (1, 6)	1.0 | ||||
|           (1, 5)	2.386294361119891 | ||||
|           (1, 3)	1.0 | ||||
|           (1, 1)	2.5753641449035616 | ||||
|           (2, 8)	1.0 | ||||
|           (2, 7)	2.386294361119891 | ||||
|           (2, 6)	1.0 | ||||
|           (2, 4)	2.386294361119891 | ||||
|           (2, 3)	1.0 | ||||
|           (2, 0)	2.386294361119891 | ||||
|           (3, 8)	1.0 | ||||
|           (3, 6)	1.0 | ||||
|           (3, 3)	1.0 | ||||
|           (3, 2)	1.6931471805599454 | ||||
|           (3, 1)	1.2876820724517808 | ||||
|           {'and': 2.386294361119891, 'document': 1.2876820724517808, 'first': 1.6931471805599454, 'is': 1.0, 'one': 2.386294361119891, 'second': 2.386294361119891, 'the': 1.0, 'third': 2.386294361119891, 'this': 1.0} | ||||
|          */ | ||||
| 
 | ||||
|         String[] corpus = { | ||||
|                 "This is the first document", | ||||
|                 "This document is the second document", | ||||
|                 "And this is the third one", | ||||
|                 "Is this the first document"}; | ||||
| 
 | ||||
|         TfidfVectorizer tfidfVectorizer = new TfidfVectorizer(); | ||||
|         Configuration configuration = new Configuration(); | ||||
|         configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName()); | ||||
|         configuration.set(MIN_WORD_FREQUENCY,"1"); | ||||
|         configuration.set(STOP_WORDS,""); | ||||
|         configuration.set(TfidfVectorizer.SMOOTH_IDF, "false"); | ||||
|         configuration.set(PREPROCESSOR, LowerCasePreProcessor.class.getName()); | ||||
| 
 | ||||
|         tfidfVectorizer.initialize(configuration); | ||||
| 
 | ||||
|         List<List<List<Writable>>> input = new ArrayList<>(); | ||||
|         //input.add(Arrays.asList(Arrays.<Writable>asList(new Text(corpus[0])),Arrays.<Writable>asList(new Text(corpus[1])))); | ||||
|         List<List<Writable>> seq = new ArrayList<>(); | ||||
|         for(String s : corpus){ | ||||
|             seq.add(Collections.<Writable>singletonList(new Text(s))); | ||||
|         } | ||||
|         input.add(seq); | ||||
| 
 | ||||
|         CollectionRecordReader crr = new CollectionRecordReader(seq); | ||||
|         INDArray arr = tfidfVectorizer.fitTransform(crr); | ||||
| 
 | ||||
|         //System.out.println(arr); | ||||
|         assertArrayEquals(new long[]{4, 9}, arr.shape()); | ||||
| 
 | ||||
|         List<String> pyVocab = Arrays.asList("and", "document", "first", "is", "one", "second", "the", "third", "this"); | ||||
|         List<Triple<Integer,Integer,Double>> l = new ArrayList<>(); | ||||
|         l.add(new Triple<>(0, 8, 1.0)); | ||||
|         l.add(new Triple<>(0, 6, 1.0)); | ||||
|         l.add(new Triple<>(0, 3, 1.0)); | ||||
|         l.add(new Triple<>(0, 2, 1.6931471805599454)); | ||||
|         l.add(new Triple<>(0, 1, 1.2876820724517808)); | ||||
|         l.add(new Triple<>(1, 8, 1.0)); | ||||
|         l.add(new Triple<>(1, 6, 1.0)); | ||||
|         l.add(new Triple<>(1, 5, 2.386294361119891)); | ||||
|         l.add(new Triple<>(1, 3, 1.0)); | ||||
|         l.add(new Triple<>(1, 1, 2.5753641449035616)); | ||||
|         l.add(new Triple<>(2, 8, 1.0)); | ||||
|         l.add(new Triple<>(2, 7, 2.386294361119891)); | ||||
|         l.add(new Triple<>(2, 6, 1.0)); | ||||
|         l.add(new Triple<>(2, 4, 2.386294361119891)); | ||||
|         l.add(new Triple<>(2, 3, 1.0)); | ||||
|         l.add(new Triple<>(2, 0, 2.386294361119891)); | ||||
|         l.add(new Triple<>(3, 8, 1.0)); | ||||
|         l.add(new Triple<>(3, 6, 1.0)); | ||||
|         l.add(new Triple<>(3, 3, 1.0)); | ||||
|         l.add(new Triple<>(3, 2, 1.6931471805599454)); | ||||
|         l.add(new Triple<>(3, 1, 1.2876820724517808)); | ||||
| 
 | ||||
|         INDArray exp = Nd4j.create(DataType.FLOAT, 4, 9); | ||||
|         for(Triple<Integer,Integer,Double> t : l){ | ||||
|             //Work out work index, accounting for different vocab/word orders: | ||||
|             int wIdx = tfidfVectorizer.getCache().wordIndex(pyVocab.get(t.getSecond())); | ||||
|             exp.putScalar(t.getFirst(), wIdx, t.getThird()); | ||||
|         } | ||||
| 
 | ||||
|         assertEquals(exp, arr); | ||||
| 
 | ||||
| 
 | ||||
|         Map<String,Double> idfWeights = new HashMap<>(); | ||||
|         idfWeights.put("and", 2.386294361119891); | ||||
|         idfWeights.put("document", 1.2876820724517808); | ||||
|         idfWeights.put("first", 1.6931471805599454); | ||||
|         idfWeights.put("is", 1.0); | ||||
|         idfWeights.put("one", 2.386294361119891); | ||||
|         idfWeights.put("second", 2.386294361119891); | ||||
|         idfWeights.put("the", 1.0); | ||||
|         idfWeights.put("third", 2.386294361119891); | ||||
|         idfWeights.put("this", 1.0); | ||||
| 
 | ||||
| 
 | ||||
|         List<String> vocab = new ArrayList<>(9);    //Arrays.asList("is","nice","strange","this","very"); | ||||
|         for( int i=0; i<9; i++ ){ | ||||
|             vocab.add(tfidfVectorizer.getCache().wordAt(i)); | ||||
|         } | ||||
| 
 | ||||
|         String inputColumnName = "input"; | ||||
|         String outputColumnName = "output"; | ||||
|         Map<String,Integer> wordIndexMap = new HashMap<>(); | ||||
|         for(int i = 0; i < vocab.size(); i++) { | ||||
|             wordIndexMap.put(vocab.get(i),i); | ||||
|         } | ||||
| 
 | ||||
|         TokenizerBagOfWordsTermSequenceIndexTransform transform = new TokenizerBagOfWordsTermSequenceIndexTransform( | ||||
|                 inputColumnName, | ||||
|                 outputColumnName, | ||||
|                 wordIndexMap, | ||||
|                 idfWeights, | ||||
|                 false, | ||||
|                 null, LowerCasePreProcessor.class.getName()); | ||||
| 
 | ||||
|         SequenceSchema.Builder sequenceSchemaBuilder = new SequenceSchema.Builder(); | ||||
|         sequenceSchemaBuilder.addColumnString("input"); | ||||
|         SequenceSchema schema = sequenceSchemaBuilder.build(); | ||||
|         assertEquals("input",schema.getName(0)); | ||||
| 
 | ||||
|         TransformProcess transformProcess = new TransformProcess.Builder(schema) | ||||
|                 .transform(transform) | ||||
|                 .build(); | ||||
| 
 | ||||
|         List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess); | ||||
| 
 | ||||
|         INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); | ||||
|         INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); | ||||
| 
 | ||||
|         assertEquals(exp.getRow(0, true), arr0); | ||||
|         assertEquals(exp.getRow(1, true), arr1); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,53 +0,0 @@ | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <configuration> | ||||
| 
 | ||||
|     <appender name="FILE" class="ch.qos.logback.core.FileAppender"> | ||||
|         <file>logs/application.log</file> | ||||
|         <encoder> | ||||
|             <pattern>%date - [%level] - from %logger in %thread | ||||
|                 %n%message%n%xException%n</pattern> | ||||
|         </encoder> | ||||
|     </appender> | ||||
| 
 | ||||
|     <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender"> | ||||
|         <encoder> | ||||
|             <pattern> %logger{15} - %message%n%xException{5} | ||||
|             </pattern> | ||||
|         </encoder> | ||||
|     </appender> | ||||
| 
 | ||||
|     <logger name="org.apache.catalina.core" level="DEBUG" /> | ||||
|     <logger name="org.springframework" level="DEBUG" /> | ||||
|     <logger name="org.datavec" level="DEBUG" /> | ||||
|     <logger name="org.nd4j" level="INFO" /> | ||||
|     <logger name="opennlp.uima.util" level="OFF" /> | ||||
|     <logger name="org.apache.uima" level="OFF" /> | ||||
|     <logger name="org.cleartk" level="OFF" /> | ||||
|     <logger name="org.apache.spark" level="WARN" /> | ||||
| 
 | ||||
| 
 | ||||
|     <root level="ERROR"> | ||||
|         <appender-ref ref="STDOUT" /> | ||||
|         <appender-ref ref="FILE" /> | ||||
|     </root> | ||||
| 
 | ||||
| </configuration> | ||||
| @ -1,56 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.datavec</groupId> | ||||
|         <artifactId>datavec-data</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>datavec-geo</artifactId> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>org.datavec</groupId> | ||||
|             <artifactId>datavec-api</artifactId> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.maxmind.geoip2</groupId> | ||||
|             <artifactId>geoip2</artifactId> | ||||
|             <version>${geoip2.version}</version> | ||||
|         </dependency> | ||||
|     </dependencies> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
| @ -1,25 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.api.transform.geo; | ||||
| 
 | ||||
| public enum LocationType { | ||||
|     CITY, CITY_ID, CONTINENT, CONTINENT_ID, COUNTRY, COUNTRY_ID, COORDINATES, POSTAL_CODE, SUBDIVISIONS, SUBDIVISIONS_ID | ||||
| } | ||||
| @ -1,194 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.api.transform.reduce.geo; | ||||
| 
 | ||||
| import lombok.Getter; | ||||
| import org.datavec.api.transform.ReduceOp; | ||||
| import org.datavec.api.transform.metadata.ColumnMetaData; | ||||
| import org.datavec.api.transform.metadata.StringMetaData; | ||||
| import org.datavec.api.transform.ops.IAggregableReduceOp; | ||||
| import org.datavec.api.transform.reduce.AggregableColumnReduction; | ||||
| import org.datavec.api.transform.reduce.AggregableReductionUtils; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.nd4j.common.function.Supplier; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Collections; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class CoordinatesReduction implements AggregableColumnReduction { | ||||
|     public static final String DEFAULT_COLUMN_NAME = "CoordinatesReduction"; | ||||
| 
 | ||||
|     public final static String DEFAULT_DELIMITER = ":"; | ||||
|     protected String delimiter = DEFAULT_DELIMITER; | ||||
| 
 | ||||
|     private final List<String> columnNamesPostReduce; | ||||
| 
 | ||||
|     private final Supplier<IAggregableReduceOp<Writable, List<Writable>>> multiOp(final List<ReduceOp> ops) { | ||||
|         return new Supplier<IAggregableReduceOp<Writable, List<Writable>>>() { | ||||
|             @Override | ||||
|             public IAggregableReduceOp<Writable, List<Writable>> get() { | ||||
|                 return AggregableReductionUtils.reduceDoubleColumn(ops, false, null); | ||||
|             } | ||||
|         }; | ||||
|     } | ||||
| 
 | ||||
|     public CoordinatesReduction(String columnNamePostReduce, ReduceOp op) { | ||||
|         this(columnNamePostReduce, op, DEFAULT_DELIMITER); | ||||
|     } | ||||
| 
 | ||||
|     public CoordinatesReduction(List<String> columnNamePostReduce, List<ReduceOp> op) { | ||||
|         this(columnNamePostReduce, op, DEFAULT_DELIMITER); | ||||
|     } | ||||
| 
 | ||||
|     public CoordinatesReduction(String columnNamePostReduce, ReduceOp op, String delimiter) { | ||||
|         this(Collections.singletonList(columnNamePostReduce), Collections.singletonList(op), delimiter); | ||||
|     } | ||||
| 
 | ||||
|     public CoordinatesReduction(List<String> columnNamesPostReduce, List<ReduceOp> ops, String delimiter) { | ||||
|         this.columnNamesPostReduce = columnNamesPostReduce; | ||||
|         this.reducer = new CoordinateAggregableReduceOp(ops.size(), multiOp(ops), delimiter); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<String> getColumnsOutputName(String columnInputName) { | ||||
|         return columnNamesPostReduce; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<ColumnMetaData> getColumnOutputMetaData(List<String> newColumnName, ColumnMetaData columnInputMeta) { | ||||
|         List<ColumnMetaData> res = new ArrayList<>(newColumnName.size()); | ||||
|         for (String cn : newColumnName) | ||||
|             res.add(new StringMetaData((cn))); | ||||
|         return res; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Schema transform(Schema inputSchema) { | ||||
|         throw new UnsupportedOperationException(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setInputSchema(Schema inputSchema) { | ||||
|         throw new UnsupportedOperationException(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Schema getInputSchema() { | ||||
|         throw new UnsupportedOperationException(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String outputColumnName() { | ||||
|         throw new UnsupportedOperationException(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String[] outputColumnNames() { | ||||
|         throw new UnsupportedOperationException(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String[] columnNames() { | ||||
|         throw new UnsupportedOperationException(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String columnName() { | ||||
|         throw new UnsupportedOperationException(); | ||||
|     } | ||||
| 
 | ||||
|     private IAggregableReduceOp<Writable, List<Writable>> reducer; | ||||
| 
 | ||||
|     @Override | ||||
|     public IAggregableReduceOp<Writable, List<Writable>> reduceOp() { | ||||
|         return reducer; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public static class CoordinateAggregableReduceOp implements IAggregableReduceOp<Writable, List<Writable>> { | ||||
| 
 | ||||
| 
 | ||||
|         private int nOps; | ||||
|         private Supplier<IAggregableReduceOp<Writable, List<Writable>>> initialOpValue; | ||||
|         @Getter | ||||
|         private ArrayList<IAggregableReduceOp<Writable, List<Writable>>> perCoordinateOps; // of size coords() | ||||
|         private String delimiter; | ||||
| 
 | ||||
|         public CoordinateAggregableReduceOp(int n, Supplier<IAggregableReduceOp<Writable, List<Writable>>> initialOp, | ||||
|                         String delim) { | ||||
|             this.nOps = n; | ||||
|             this.perCoordinateOps = new ArrayList<>(); | ||||
|             this.initialOpValue = initialOp; | ||||
|             this.delimiter = delim; | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         public <W extends IAggregableReduceOp<Writable, List<Writable>>> void combine(W accu) { | ||||
|             if (accu instanceof CoordinateAggregableReduceOp) { | ||||
|                 CoordinateAggregableReduceOp accumulator = (CoordinateAggregableReduceOp) accu; | ||||
|                 for (int i = 0; i < Math.min(perCoordinateOps.size(), accumulator.getPerCoordinateOps().size()); i++) { | ||||
|                     perCoordinateOps.get(i).combine(accumulator.getPerCoordinateOps().get(i)); | ||||
|                 } // the rest is assumed identical | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         public void accept(Writable writable) { | ||||
|             String[] coordinates = writable.toString().split(delimiter); | ||||
|             for (int i = 0; i < coordinates.length; i++) { | ||||
|                 String coordinate = coordinates[i]; | ||||
|                 while (perCoordinateOps.size() < i + 1) { | ||||
|                     perCoordinateOps.add(initialOpValue.get()); | ||||
|                 } | ||||
|                 perCoordinateOps.get(i).accept(new DoubleWritable(Double.parseDouble(coordinate))); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         public List<Writable> get() { | ||||
|             List<StringBuilder> res = new ArrayList<>(nOps); | ||||
|             for (int i = 0; i < nOps; i++) { | ||||
|                 res.add(new StringBuilder()); | ||||
|             } | ||||
| 
 | ||||
|             for (int i = 0; i < perCoordinateOps.size(); i++) { | ||||
|                 List<Writable> resThisCoord = perCoordinateOps.get(i).get(); | ||||
|                 for (int j = 0; j < nOps; j++) { | ||||
|                     res.get(j).append(resThisCoord.get(j).toString()); | ||||
|                     if (i < perCoordinateOps.size() - 1) { | ||||
|                         res.get(j).append(delimiter); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             List<Writable> finalRes = new ArrayList<>(nOps); | ||||
|             for (StringBuilder sb : res) { | ||||
|                 finalRes.add(new Text(sb.toString())); | ||||
|             } | ||||
|             return finalRes; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,117 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.api.transform.transform.geo; | ||||
| 
 | ||||
| import org.datavec.api.transform.MathOp; | ||||
| import org.datavec.api.transform.metadata.ColumnMetaData; | ||||
| import org.datavec.api.transform.metadata.DoubleMetaData; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.transform.transform.BaseColumnsMathOpTransform; | ||||
| import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.nd4j.shade.jackson.annotation.JsonProperty; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class CoordinatesDistanceTransform extends BaseColumnsMathOpTransform { | ||||
| 
 | ||||
|     public final static String DEFAULT_DELIMITER = ":"; | ||||
|     protected String delimiter = DEFAULT_DELIMITER; | ||||
| 
 | ||||
|     public CoordinatesDistanceTransform(String newColumnName, String firstColumn, String secondColumn, | ||||
|                     String stdevColumn) { | ||||
|         this(newColumnName, firstColumn, secondColumn, stdevColumn, DEFAULT_DELIMITER); | ||||
|     } | ||||
| 
 | ||||
|     public CoordinatesDistanceTransform(@JsonProperty("newColumnName") String newColumnName, | ||||
|                     @JsonProperty("firstColumn") String firstColumn, @JsonProperty("secondColumn") String secondColumn, | ||||
|                     @JsonProperty("stdevColumn") String stdevColumn, @JsonProperty("delimiter") String delimiter) { | ||||
|         super(newColumnName, MathOp.Add /* dummy op */, | ||||
|                         stdevColumn != null ? new String[] {firstColumn, secondColumn, stdevColumn} | ||||
|                                         : new String[] {firstColumn, secondColumn}); | ||||
|         this.delimiter = delimiter; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     protected ColumnMetaData derivedColumnMetaData(String newColumnName, Schema inputSchema) { | ||||
|         return new DoubleMetaData(newColumnName); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     protected Writable doOp(Writable... input) { | ||||
|         String[] first = input[0].toString().split(delimiter); | ||||
|         String[] second = input[1].toString().split(delimiter); | ||||
|         String[] stdev = columns.length > 2 ? input[2].toString().split(delimiter) : null; | ||||
| 
 | ||||
|         double dist = 0; | ||||
|         for (int i = 0; i < first.length; i++) { | ||||
|             double d = Double.parseDouble(first[i]) - Double.parseDouble(second[i]); | ||||
|             double s = stdev != null ? Double.parseDouble(stdev[i]) : 1; | ||||
|             dist += (d * d) / (s * s); | ||||
|         } | ||||
|         return new DoubleWritable(Math.sqrt(dist)); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String toString() { | ||||
|         return "CoordinatesDistanceTransform(newColumnName=\"" + newColumnName + "\",columns=" | ||||
|                         + Arrays.toString(columns) + ",delimiter=" + delimiter + ")"; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Transform an object | ||||
|      * in to another object | ||||
|      * | ||||
|      * @param input the record to transform | ||||
|      * @return the transformed writable | ||||
|      */ | ||||
|     @Override | ||||
|     public Object map(Object input) { | ||||
|         List row = (List) input; | ||||
|         String[] first = row.get(0).toString().split(delimiter); | ||||
|         String[] second = row.get(1).toString().split(delimiter); | ||||
|         String[] stdev = columns.length > 2 ? row.get(2).toString().split(delimiter) : null; | ||||
| 
 | ||||
|         double dist = 0; | ||||
|         for (int i = 0; i < first.length; i++) { | ||||
|             double d = Double.parseDouble(first[i]) - Double.parseDouble(second[i]); | ||||
|             double s = stdev != null ? Double.parseDouble(stdev[i]) : 1; | ||||
|             dist += (d * d) / (s * s); | ||||
|         } | ||||
|         return Math.sqrt(dist); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Transform a sequence | ||||
|      * | ||||
|      * @param sequence | ||||
|      */ | ||||
|     @Override | ||||
|     public Object mapSequence(Object sequence) { | ||||
|         List<List> seq = (List<List>) sequence; | ||||
|         List<Double> ret = new ArrayList<>(); | ||||
|         for (Object step : seq) | ||||
|             ret.add((Double) map(step)); | ||||
|         return ret; | ||||
|     } | ||||
| } | ||||
| @ -1,70 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.api.transform.transform.geo; | ||||
| 
 | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.nd4j.common.base.Preconditions; | ||||
| import org.nd4j.common.util.ArchiveUtils; | ||||
| import org.slf4j.Logger; | ||||
| import org.slf4j.LoggerFactory; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.net.URL; | ||||
| 
 | ||||
| public class GeoIPFetcher { | ||||
|     protected static final Logger log = LoggerFactory.getLogger(GeoIPFetcher.class); | ||||
| 
 | ||||
|     /** Default directory for <a href="http://dev.maxmind.com/geoip/geoipupdate/">http://dev.maxmind.com/geoip/geoipupdate/</a> */ | ||||
|     public static final String GEOIP_DIR = "/usr/local/share/GeoIP/"; | ||||
|     public static final String GEOIP_DIR2 = System.getProperty("user.home") + "/.datavec-geoip"; | ||||
| 
 | ||||
|     public static final String CITY_DB = "GeoIP2-City.mmdb"; | ||||
|     public static final String CITY_LITE_DB = "GeoLite2-City.mmdb"; | ||||
| 
 | ||||
|     public static final String CITY_LITE_URL = | ||||
|                     "http://geolite.maxmind.com/download/geoip/database/GeoLite2-City.mmdb.gz"; | ||||
| 
 | ||||
|     public static synchronized File fetchCityDB() throws IOException { | ||||
|         File cityFile = new File(GEOIP_DIR, CITY_DB); | ||||
|         if (cityFile.isFile()) { | ||||
|             return cityFile; | ||||
|         } | ||||
|         cityFile = new File(GEOIP_DIR, CITY_LITE_DB); | ||||
|         if (cityFile.isFile()) { | ||||
|             return cityFile; | ||||
|         } | ||||
|         cityFile = new File(GEOIP_DIR2, CITY_LITE_DB); | ||||
|         if (cityFile.isFile()) { | ||||
|             return cityFile; | ||||
|         } | ||||
| 
 | ||||
|         log.info("Downloading GeoLite2 City database..."); | ||||
|         File archive = new File(GEOIP_DIR2, CITY_LITE_DB + ".gz"); | ||||
|         File dir = new File(GEOIP_DIR2); | ||||
|         dir.mkdirs(); | ||||
|         FileUtils.copyURLToFile(new URL(CITY_LITE_URL), archive); | ||||
|         ArchiveUtils.unzipFileTo(archive.getAbsolutePath(), dir.getAbsolutePath()); | ||||
|         Preconditions.checkState(cityFile.isFile(), "Error extracting files: expected city file does not exist after extraction"); | ||||
| 
 | ||||
|         return cityFile; | ||||
|     } | ||||
| } | ||||
| @ -1,43 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.api.transform.transform.geo; | ||||
| 
 | ||||
| import org.datavec.api.transform.geo.LocationType; | ||||
| import org.nd4j.shade.jackson.annotation.JsonProperty; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| 
 | ||||
| public class IPAddressToCoordinatesTransform extends IPAddressToLocationTransform { | ||||
| 
 | ||||
|     public IPAddressToCoordinatesTransform(@JsonProperty("columnName") String columnName) throws IOException { | ||||
|         this(columnName, DEFAULT_DELIMITER); | ||||
|     } | ||||
| 
 | ||||
|     public IPAddressToCoordinatesTransform(@JsonProperty("columnName") String columnName, | ||||
|                     @JsonProperty("delimiter") String delimiter) throws IOException { | ||||
|         super(columnName, LocationType.COORDINATES, delimiter); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String toString() { | ||||
|         return "IPAddressToCoordinatesTransform"; | ||||
|     } | ||||
| } | ||||
| @ -1,184 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.api.transform.transform.geo; | ||||
| 
 | ||||
| import com.maxmind.geoip2.DatabaseReader; | ||||
| import com.maxmind.geoip2.exception.GeoIp2Exception; | ||||
| import com.maxmind.geoip2.model.CityResponse; | ||||
| import com.maxmind.geoip2.record.Location; | ||||
| import com.maxmind.geoip2.record.Subdivision; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.datavec.api.transform.geo.LocationType; | ||||
| import org.datavec.api.transform.metadata.ColumnMetaData; | ||||
| import org.datavec.api.transform.metadata.StringMetaData; | ||||
| import org.datavec.api.transform.transform.BaseColumnTransform; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.nd4j.shade.jackson.annotation.JsonProperty; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.io.ObjectInputStream; | ||||
| import java.io.ObjectOutputStream; | ||||
| import java.net.InetAddress; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class IPAddressToLocationTransform extends BaseColumnTransform { | ||||
|     /** | ||||
|      * Name of the system property to use when configuring the GeoIP database file.<br> | ||||
|      * Most users don't need to set this - typically used for testing purposes.<br> | ||||
|      * Set with the full local path, like: "C:/datavec-geo/GeoIP2-City-Test.mmdb" | ||||
|      */ | ||||
|     public static final String GEOIP_FILE_PROPERTY = "org.datavec.geoip.file"; | ||||
| 
 | ||||
|     private static File database; | ||||
|     private static DatabaseReader reader; | ||||
| 
 | ||||
|     public final static String DEFAULT_DELIMITER = ":"; | ||||
|     protected String delimiter = DEFAULT_DELIMITER; | ||||
|     protected LocationType locationType; | ||||
| 
 | ||||
|     private static synchronized void init() throws IOException { | ||||
|         // A File object pointing to your GeoIP2 or GeoLite2 database: | ||||
|         // http://dev.maxmind.com/geoip/geoip2/geolite2/ | ||||
|         if (database == null) { | ||||
|             String s = System.getProperty(GEOIP_FILE_PROPERTY); | ||||
|             if(s != null && !s.isEmpty()){ | ||||
|                 //Use user-specified GEOIP file - mainly for testing purposes | ||||
|                 File f = new File(s); | ||||
|                 if(f.exists() && f.isFile()){ | ||||
|                     database = f; | ||||
|                 } else { | ||||
|                     log.warn("GeoIP file (system property {}) is set to \"{}\" but this is not a valid file, using default database", GEOIP_FILE_PROPERTY, s); | ||||
|                     database = GeoIPFetcher.fetchCityDB(); | ||||
|                 } | ||||
|             } else { | ||||
|                 database = GeoIPFetcher.fetchCityDB(); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // This creates the DatabaseReader object, which should be reused across lookups. | ||||
|         if (reader == null) { | ||||
|             reader = new DatabaseReader.Builder(database).build(); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public IPAddressToLocationTransform(String columnName) throws IOException { | ||||
|         this(columnName, LocationType.CITY); | ||||
|     } | ||||
| 
 | ||||
|     public IPAddressToLocationTransform(String columnName, LocationType locationType) throws IOException { | ||||
|         this(columnName, locationType, DEFAULT_DELIMITER); | ||||
|     } | ||||
| 
 | ||||
|     public IPAddressToLocationTransform(@JsonProperty("columnName") String columnName, | ||||
|                     @JsonProperty("delimiter") LocationType locationType, @JsonProperty("delimiter") String delimiter) | ||||
|                     throws IOException { | ||||
|         super(columnName); | ||||
|         this.delimiter = delimiter; | ||||
|         this.locationType = locationType; | ||||
|         init(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) { | ||||
|         return new StringMetaData(newName); //Output after transform: String (Text) | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Writable map(Writable columnWritable) { | ||||
|         try { | ||||
|             InetAddress ipAddress = InetAddress.getByName(columnWritable.toString()); | ||||
|             CityResponse response = reader.city(ipAddress); | ||||
|             String text = ""; | ||||
|             switch (locationType) { | ||||
|                 case CITY: | ||||
|                     text = response.getCity().getName(); | ||||
|                     break; | ||||
|                 case CITY_ID: | ||||
|                     text = response.getCity().getGeoNameId().toString(); | ||||
|                     break; | ||||
|                 case CONTINENT: | ||||
|                     text = response.getContinent().getName(); | ||||
|                     break; | ||||
|                 case CONTINENT_ID: | ||||
|                     text = response.getContinent().getGeoNameId().toString(); | ||||
|                     break; | ||||
|                 case COUNTRY: | ||||
|                     text = response.getCountry().getName(); | ||||
|                     break; | ||||
|                 case COUNTRY_ID: | ||||
|                     text = response.getCountry().getGeoNameId().toString(); | ||||
|                     break; | ||||
|                 case COORDINATES: | ||||
|                     Location location = response.getLocation(); | ||||
|                     text = location.getLatitude() + delimiter + location.getLongitude(); | ||||
|                     break; | ||||
|                 case POSTAL_CODE: | ||||
|                     text = response.getPostal().getCode(); | ||||
|                     break; | ||||
|                 case SUBDIVISIONS: | ||||
|                     for (Subdivision s : response.getSubdivisions()) { | ||||
|                         if (text.length() > 0) { | ||||
|                             text += delimiter; | ||||
|                         } | ||||
|                         text += s.getName(); | ||||
|                     } | ||||
|                     break; | ||||
|                 case SUBDIVISIONS_ID: | ||||
|                     for (Subdivision s : response.getSubdivisions()) { | ||||
|                         if (text.length() > 0) { | ||||
|                             text += delimiter; | ||||
|                         } | ||||
|                         text += s.getGeoNameId().toString(); | ||||
|                     } | ||||
|                     break; | ||||
|                 default: | ||||
|                     assert false; | ||||
|             } | ||||
|             if(text == null) | ||||
|                 text = ""; | ||||
|             return new Text(text); | ||||
|         } catch (GeoIp2Exception | IOException e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String toString() { | ||||
|         return "IPAddressToLocationTransform"; | ||||
|     } | ||||
| 
 | ||||
|     //Custom serialization methods, because GeoIP2 doesn't allow DatabaseReader objects to be serialized :( | ||||
|     private void writeObject(ObjectOutputStream out) throws IOException { | ||||
|         out.defaultWriteObject(); | ||||
|     } | ||||
| 
 | ||||
|     private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { | ||||
|         in.defaultReadObject(); | ||||
|         init(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Object map(Object input) { | ||||
|         return null; | ||||
|     } | ||||
| } | ||||
| @ -1,46 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| package org.datavec.api.transform; | ||||
| 
 | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.nd4j.common.tests.AbstractAssertTestsClass; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { | ||||
| 
 | ||||
|     @Override | ||||
|     protected Set<Class<?>> getExclusions() { | ||||
| 	    //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) | ||||
| 	    return new HashSet<>(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
| 	protected String getPackageName() { | ||||
|     	return "org.datavec.api.transform"; | ||||
| 	} | ||||
| 
 | ||||
| 	@Override | ||||
| 	protected Class<?> getBaseClass() { | ||||
|     	return BaseND4JTest.class; | ||||
| 	} | ||||
| } | ||||
| @ -1,80 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.api.transform.reduce; | ||||
| 
 | ||||
| import org.datavec.api.transform.ColumnType; | ||||
| import org.datavec.api.transform.ReduceOp; | ||||
| import org.datavec.api.transform.ops.IAggregableReduceOp; | ||||
| import org.datavec.api.transform.reduce.geo.CoordinatesReduction; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.Test; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| /** | ||||
|  * @author saudet | ||||
|  */ | ||||
| public class TestGeoReduction { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testCustomReductions() { | ||||
| 
 | ||||
|         List<List<Writable>> inputs = new ArrayList<>(); | ||||
|         inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("1#5"))); | ||||
|         inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("2#6"))); | ||||
|         inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("3#7"))); | ||||
|         inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("4#8"))); | ||||
| 
 | ||||
|         List<Writable> expected = Arrays.asList((Writable) new Text("someKey"), new Text("10.0#26.0")); | ||||
| 
 | ||||
|         Schema schema = new Schema.Builder().addColumnString("key").addColumnString("coord").build(); | ||||
| 
 | ||||
|         Reducer reducer = new Reducer.Builder(ReduceOp.Count).keyColumns("key") | ||||
|                         .customReduction("coord", new CoordinatesReduction("coordSum", ReduceOp.Sum, "#")).build(); | ||||
| 
 | ||||
|         reducer.setInputSchema(schema); | ||||
| 
 | ||||
|         IAggregableReduceOp<List<Writable>, List<Writable>> aggregableReduceOp = reducer.aggregableReducer(); | ||||
|         for (List<Writable> l : inputs) | ||||
|             aggregableReduceOp.accept(l); | ||||
|         List<Writable> out = aggregableReduceOp.get(); | ||||
| 
 | ||||
|         assertEquals(2, out.size()); | ||||
|         assertEquals(expected, out); | ||||
| 
 | ||||
|         //Check schema: | ||||
|         String[] expNames = new String[] {"key", "coordSum"}; | ||||
|         ColumnType[] expTypes = new ColumnType[] {ColumnType.String, ColumnType.String}; | ||||
|         Schema outSchema = reducer.transform(schema); | ||||
| 
 | ||||
|         assertEquals(2, outSchema.numColumns()); | ||||
|         for (int i = 0; i < 2; i++) { | ||||
|             assertEquals(expNames[i], outSchema.getName(i)); | ||||
|             assertEquals(expTypes[i], outSchema.getType(i)); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -1,153 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.api.transform.transform; | ||||
| 
 | ||||
| import org.datavec.api.transform.ColumnType; | ||||
| import org.datavec.api.transform.Transform; | ||||
| import org.datavec.api.transform.geo.LocationType; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform; | ||||
| import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform; | ||||
| import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform; | ||||
| import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.AfterClass; | ||||
| import org.junit.BeforeClass; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| 
 | ||||
| import java.io.*; | ||||
| import java.util.Arrays; | ||||
| import java.util.Collections; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| /** | ||||
|  * @author saudet | ||||
|  */ | ||||
| public class TestGeoTransforms { | ||||
| 
 | ||||
|     @BeforeClass | ||||
|     public static void beforeClass() throws Exception { | ||||
|         //Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing | ||||
|         File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile(); | ||||
|         System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath()); | ||||
|     } | ||||
| 
 | ||||
|     @AfterClass | ||||
|     public static void afterClass(){ | ||||
|         System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, ""); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testCoordinatesDistanceTransform() throws Exception { | ||||
|         Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev") | ||||
|                         .build(); | ||||
| 
 | ||||
|         Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|"); | ||||
|         transform.setInputSchema(schema); | ||||
| 
 | ||||
|         Schema out = transform.transform(schema); | ||||
|         assertEquals(4, out.numColumns()); | ||||
|         assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames()); | ||||
|         assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double), | ||||
|                         out.getColumnTypes()); | ||||
| 
 | ||||
|         assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)), | ||||
|                         transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10")))); | ||||
|         assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"), | ||||
|                         new DoubleWritable(Math.sqrt(160))), | ||||
|                         transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), | ||||
|                                         new Text("10|5")))); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testIPAddressToCoordinatesTransform() throws Exception { | ||||
|         Schema schema = new Schema.Builder().addColumnString("column").build(); | ||||
| 
 | ||||
|         Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER"); | ||||
|         transform.setInputSchema(schema); | ||||
| 
 | ||||
|         Schema out = transform.transform(schema); | ||||
| 
 | ||||
|         assertEquals(1, out.getColumnMetaData().size()); | ||||
|         assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); | ||||
| 
 | ||||
|         String in = "81.2.69.160"; | ||||
|         double latitude = 51.5142; | ||||
|         double longitude = -0.0931; | ||||
| 
 | ||||
|         List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in))); | ||||
|         assertEquals(1, writables.size()); | ||||
|         String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER"); | ||||
|         assertEquals(2, coordinates.length); | ||||
|         assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1); | ||||
|         assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1); | ||||
| 
 | ||||
|         //Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/ | ||||
|         ByteArrayOutputStream baos = new ByteArrayOutputStream(); | ||||
|         ObjectOutputStream oos = new ObjectOutputStream(baos); | ||||
|         oos.writeObject(transform); | ||||
| 
 | ||||
|         byte[] bytes = baos.toByteArray(); | ||||
| 
 | ||||
|         ByteArrayInputStream bais = new ByteArrayInputStream(bytes); | ||||
|         ObjectInputStream ois = new ObjectInputStream(bais); | ||||
| 
 | ||||
|         Transform deserialized = (Transform) ois.readObject(); | ||||
|         writables = deserialized.map(Collections.singletonList((Writable) new Text(in))); | ||||
|         assertEquals(1, writables.size()); | ||||
|         coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER"); | ||||
|         //System.out.println(Arrays.toString(coordinates)); | ||||
|         assertEquals(2, coordinates.length); | ||||
|         assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1); | ||||
|         assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testIPAddressToLocationTransform() throws Exception { | ||||
|         Schema schema = new Schema.Builder().addColumnString("column").build(); | ||||
|         LocationType[] locationTypes = LocationType.values(); | ||||
|         String in = "81.2.69.160"; | ||||
|         String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167", | ||||
|                         "51.5142:-0.0931", "", "England", "6269131"};    //Note: no postcode in this test DB for this record | ||||
| 
 | ||||
|         for (int i = 0; i < locationTypes.length; i++) { | ||||
|             LocationType locationType = locationTypes[i]; | ||||
|             String location = locations[i]; | ||||
| 
 | ||||
|             Transform transform = new IPAddressToLocationTransform("column", locationType); | ||||
|             transform.setInputSchema(schema); | ||||
| 
 | ||||
|             Schema out = transform.transform(schema); | ||||
| 
 | ||||
|             assertEquals(1, out.getColumnMetaData().size()); | ||||
|             assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); | ||||
| 
 | ||||
|             List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in))); | ||||
|             assertEquals(1, writables.size()); | ||||
|             assertEquals(location, writables.get(0).toString()); | ||||
|             //System.out.println(location); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -37,11 +37,7 @@ | ||||
|     <name>datavec-data</name> | ||||
| 
 | ||||
|     <modules> | ||||
|         <module>datavec-data-audio</module> | ||||
|         <module>datavec-data-codec</module> | ||||
|         <module>datavec-data-image</module> | ||||
|         <module>datavec-data-nlp</module> | ||||
|         <module>datavec-geo</module> | ||||
|     </modules> | ||||
| 
 | ||||
|     <dependencyManagement> | ||||
|  | ||||
| @ -1,64 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.datavec</groupId> | ||||
|         <artifactId>datavec-spark-inference-parent</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>datavec-spark-inference-client</artifactId> | ||||
| 
 | ||||
|     <name>datavec-spark-inference-client</name> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>org.datavec</groupId> | ||||
|             <artifactId>datavec-spark-inference-server_2.11</artifactId> | ||||
|             <version>1.0.0-SNAPSHOT</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.datavec</groupId> | ||||
|             <artifactId>datavec-spark-inference-model</artifactId> | ||||
|             <version>${project.parent.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.mashape.unirest</groupId> | ||||
|             <artifactId>unirest-java</artifactId> | ||||
|         </dependency> | ||||
|     </dependencies> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
| @ -1,292 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.inference.client; | ||||
| 
 | ||||
| 
 | ||||
| import com.mashape.unirest.http.ObjectMapper; | ||||
| import com.mashape.unirest.http.Unirest; | ||||
| import com.mashape.unirest.http.exceptions.UnirestException; | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.datavec.api.transform.TransformProcess; | ||||
| import org.datavec.image.transform.ImageTransformProcess; | ||||
| import org.datavec.spark.inference.model.model.*; | ||||
| import org.datavec.spark.inference.model.service.DataVecTransformService; | ||||
| import org.nd4j.shade.jackson.core.JsonProcessingException; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| 
 | ||||
| @AllArgsConstructor | ||||
| @Slf4j | ||||
| public class DataVecTransformClient implements DataVecTransformService { | ||||
|     private String url; | ||||
| 
 | ||||
|     static { | ||||
|         // Only one time | ||||
|         Unirest.setObjectMapper(new ObjectMapper() { | ||||
|             private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = | ||||
|                     new org.nd4j.shade.jackson.databind.ObjectMapper(); | ||||
| 
 | ||||
|             public <T> T readValue(String value, Class<T> valueType) { | ||||
|                 try { | ||||
|                     return jacksonObjectMapper.readValue(value, valueType); | ||||
|                 } catch (IOException e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             public String writeValue(Object value) { | ||||
|                 try { | ||||
|                     return jacksonObjectMapper.writeValueAsString(value); | ||||
|                 } catch (JsonProcessingException e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param transformProcess | ||||
|      */ | ||||
|     @Override | ||||
|     public void setCSVTransformProcess(TransformProcess transformProcess) { | ||||
|         try { | ||||
|             String s = transformProcess.toJson(); | ||||
|             Unirest.post(url + "/transformprocess").header("accept", "application/json") | ||||
|                     .header("Content-Type", "application/json").body(s).asJson(); | ||||
| 
 | ||||
|         } catch (UnirestException e) { | ||||
|             log.error("Error in setCSVTransformProcess()", e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public TransformProcess getCSVTransformProcess() { | ||||
|         try { | ||||
|             String s = Unirest.get(url + "/transformprocess").header("accept", "application/json") | ||||
|                     .header("Content-Type", "application/json").asString().getBody(); | ||||
|             return TransformProcess.fromJson(s); | ||||
|         } catch (UnirestException e) { | ||||
|             log.error("Error in getCSVTransformProcess()",e); | ||||
|         } | ||||
| 
 | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public ImageTransformProcess getImageTransformProcess() { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param transform | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public SingleCSVRecord transformIncremental(SingleCSVRecord transform) { | ||||
|         try { | ||||
|             SingleCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental") | ||||
|                     .header("accept", "application/json") | ||||
|                     .header("Content-Type", "application/json") | ||||
|                     .body(transform).asObject(SingleCSVRecord.class).getBody(); | ||||
|             return singleCsvRecord; | ||||
|         } catch (UnirestException e) { | ||||
|             log.error("Error in transformIncremental(SingleCSVRecord)",e); | ||||
|         } | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * @param batchCSVRecord | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) { | ||||
|         try { | ||||
|             SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json") | ||||
|                     .header("Content-Type", "application/json") | ||||
|                     .header(SEQUENCE_OR_NOT_HEADER,"TRUE") | ||||
|                     .body(batchCSVRecord) | ||||
|                     .asObject(SequenceBatchCSVRecord.class) | ||||
|                     .getBody(); | ||||
|             return batchCSVRecord1; | ||||
|         } catch (UnirestException e) { | ||||
|             log.error("",e); | ||||
|         } | ||||
| 
 | ||||
|         return null; | ||||
|     } | ||||
|     /** | ||||
|      * @param batchCSVRecord | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) { | ||||
|         try { | ||||
|             BatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json") | ||||
|                     .header("Content-Type", "application/json") | ||||
|                     .header(SEQUENCE_OR_NOT_HEADER,"FALSE") | ||||
|                     .body(batchCSVRecord) | ||||
|                     .asObject(BatchCSVRecord.class) | ||||
|                     .getBody(); | ||||
|             return batchCSVRecord1; | ||||
|         } catch (UnirestException e) { | ||||
|             log.error("Error in transform(BatchCSVRecord)", e); | ||||
|         } | ||||
| 
 | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param batchCSVRecord | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) { | ||||
|         try { | ||||
|             Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json") | ||||
|                     .header("Content-Type", "application/json").body(batchCSVRecord) | ||||
|                     .asObject(Base64NDArrayBody.class).getBody(); | ||||
|             return batchArray1; | ||||
|         } catch (UnirestException e) { | ||||
|             log.error("Error in transformArray(BatchCSVRecord)",e); | ||||
|         } | ||||
| 
 | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param singleCsvRecord | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) { | ||||
|         try { | ||||
|             Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray") | ||||
|                     .header("accept", "application/json").header("Content-Type", "application/json") | ||||
|                     .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody(); | ||||
|             return array; | ||||
|         } catch (UnirestException e) { | ||||
|             log.error("Error in transformArrayIncremental(SingleCSVRecord)",e); | ||||
|         } | ||||
| 
 | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param singleCsvRecord | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { | ||||
|         try { | ||||
|             Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray") | ||||
|                     .header("accept", "application/json") | ||||
|                     .header("Content-Type", "application/json") | ||||
|                     .header(SEQUENCE_OR_NOT_HEADER,"true") | ||||
|                     .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody(); | ||||
|             return array; | ||||
|         } catch (UnirestException e) { | ||||
|             log.error("Error in transformSequenceArrayIncremental",e); | ||||
|         } | ||||
| 
 | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param batchCSVRecord | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) { | ||||
|         try { | ||||
|             Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json") | ||||
|                     .header("Content-Type", "application/json") | ||||
|                     .header(SEQUENCE_OR_NOT_HEADER,"true") | ||||
|                     .body(batchCSVRecord) | ||||
|                     .asObject(Base64NDArrayBody.class).getBody(); | ||||
|             return batchArray1; | ||||
|         } catch (UnirestException e) { | ||||
|             log.error("Error in transformSequenceArray",e); | ||||
|         } | ||||
| 
 | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param batchCSVRecord | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) { | ||||
|         try { | ||||
|             SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform") | ||||
|                     .header("accept", "application/json") | ||||
|                     .header("Content-Type", "application/json") | ||||
|                     .header(SEQUENCE_OR_NOT_HEADER,"true") | ||||
|                     .body(batchCSVRecord) | ||||
|                     .asObject(SequenceBatchCSVRecord.class).getBody(); | ||||
|             return batchCSVRecord1; | ||||
|         } catch (UnirestException e) { | ||||
|             log.error("Error in transformSequence"); | ||||
|         } | ||||
| 
 | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param transform | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { | ||||
|         try { | ||||
|             SequenceBatchCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental") | ||||
|                     .header("accept", "application/json") | ||||
|                     .header("Content-Type", "application/json") | ||||
|                     .header(SEQUENCE_OR_NOT_HEADER,"true") | ||||
|                     .body(transform).asObject(SequenceBatchCSVRecord.class).getBody(); | ||||
|             return singleCsvRecord; | ||||
|         } catch (UnirestException e) { | ||||
|             log.error("Error in transformSequenceIncremental"); | ||||
|         } | ||||
|         return null; | ||||
|     } | ||||
| } | ||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user