Merge branch 'master' of https://github.com/eclipse/deeplearning4j into ag_github_workflows_1
This commit is contained in:
		
						commit
						4a06f39085
					
				| @ -1,77 +0,0 @@ | |||||||
| <?xml version="1.0" encoding="UTF-8"?> |  | ||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" |  | ||||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |  | ||||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> |  | ||||||
| 
 |  | ||||||
|     <modelVersion>4.0.0</modelVersion> |  | ||||||
| 
 |  | ||||||
|     <parent> |  | ||||||
|         <groupId>org.datavec</groupId> |  | ||||||
|         <artifactId>datavec-data</artifactId> |  | ||||||
|         <version>1.0.0-SNAPSHOT</version> |  | ||||||
|     </parent> |  | ||||||
| 
 |  | ||||||
|     <artifactId>datavec-data-audio</artifactId> |  | ||||||
| 
 |  | ||||||
|     <name>datavec-data-audio</name> |  | ||||||
| 
 |  | ||||||
|     <dependencies> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.datavec</groupId> |  | ||||||
|             <artifactId>datavec-api</artifactId> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.bytedeco</groupId> |  | ||||||
|             <artifactId>javacpp</artifactId> |  | ||||||
|             <version>${javacpp.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.bytedeco</groupId> |  | ||||||
|             <artifactId>javacv</artifactId> |  | ||||||
|             <version>${javacv.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.github.wendykierp</groupId> |  | ||||||
|             <artifactId>JTransforms</artifactId> |  | ||||||
|             <version>${jtransforms.version}</version> |  | ||||||
|             <classifier>with-dependencies</classifier> |  | ||||||
|         </dependency> |  | ||||||
|         <!-- Do not depend on FFmpeg by default due to licensing concerns. --> |  | ||||||
|         <!-- |  | ||||||
|                 <dependency> |  | ||||||
|                     <groupId>org.bytedeco</groupId> |  | ||||||
|                     <artifactId>ffmpeg-platform</artifactId> |  | ||||||
|                     <version>${ffmpeg.version}-${javacpp-presets.version}</version> |  | ||||||
|                 </dependency> |  | ||||||
|         --> |  | ||||||
|     </dependencies> |  | ||||||
| 
 |  | ||||||
|     <profiles> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-native</id> |  | ||||||
|         </profile> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-cuda-11.0</id> |  | ||||||
|         </profile> |  | ||||||
|     </profiles> |  | ||||||
| </project> |  | ||||||
| @ -1,329 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.datavec.audio.extension.NormalizedSampleAmplitudes; |  | ||||||
| import org.datavec.audio.extension.Spectrogram; |  | ||||||
| import org.datavec.audio.fingerprint.FingerprintManager; |  | ||||||
| import org.datavec.audio.fingerprint.FingerprintSimilarity; |  | ||||||
| import org.datavec.audio.fingerprint.FingerprintSimilarityComputer; |  | ||||||
| 
 |  | ||||||
| import java.io.FileInputStream; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.io.InputStream; |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| public class Wave implements Serializable { |  | ||||||
| 
 |  | ||||||
|     private static final long serialVersionUID = 1L; |  | ||||||
|     private WaveHeader waveHeader; |  | ||||||
|     private byte[] data; // little endian |  | ||||||
|     private byte[] fingerprint; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructor |  | ||||||
|      *  |  | ||||||
|      */ |  | ||||||
|     public Wave() { |  | ||||||
|         this.waveHeader = new WaveHeader(); |  | ||||||
|         this.data = new byte[0]; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructor |  | ||||||
|      *  |  | ||||||
|      * @param filename |  | ||||||
|      *            Wave file |  | ||||||
|      */ |  | ||||||
|     public Wave(String filename) { |  | ||||||
|         try { |  | ||||||
|             InputStream inputStream = new FileInputStream(filename); |  | ||||||
|             initWaveWithInputStream(inputStream); |  | ||||||
|             inputStream.close(); |  | ||||||
|         } catch (IOException e) { |  | ||||||
|             System.out.println(e.toString()); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructor |  | ||||||
|      *  |  | ||||||
|      * @param inputStream |  | ||||||
|      *            Wave file input stream |  | ||||||
|      */ |  | ||||||
|     public Wave(InputStream inputStream) { |  | ||||||
|         initWaveWithInputStream(inputStream); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructor |  | ||||||
|      *  |  | ||||||
|      * @param waveHeader |  | ||||||
|      * @param data |  | ||||||
|      */ |  | ||||||
|     public Wave(WaveHeader waveHeader, byte[] data) { |  | ||||||
|         this.waveHeader = waveHeader; |  | ||||||
|         this.data = data; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private void initWaveWithInputStream(InputStream inputStream) { |  | ||||||
|         // reads the first 44 bytes for header |  | ||||||
|         waveHeader = new WaveHeader(inputStream); |  | ||||||
| 
 |  | ||||||
|         if (waveHeader.isValid()) { |  | ||||||
|             // load data |  | ||||||
|             try { |  | ||||||
|                 data = new byte[inputStream.available()]; |  | ||||||
|                 inputStream.read(data); |  | ||||||
|             } catch (IOException e) { |  | ||||||
|                 System.err.println(e.toString()); |  | ||||||
|             } |  | ||||||
|             // end load data |  | ||||||
|         } else { |  | ||||||
|             System.err.println("Invalid Wave Header"); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Trim the wave data |  | ||||||
|      *  |  | ||||||
|      * @param leftTrimNumberOfSample |  | ||||||
|      *            Number of sample trimmed from beginning |  | ||||||
|      * @param rightTrimNumberOfSample |  | ||||||
|      *            Number of sample trimmed from ending |  | ||||||
|      */ |  | ||||||
|     public void trim(int leftTrimNumberOfSample, int rightTrimNumberOfSample) { |  | ||||||
| 
 |  | ||||||
|         long chunkSize = waveHeader.getChunkSize(); |  | ||||||
|         long subChunk2Size = waveHeader.getSubChunk2Size(); |  | ||||||
| 
 |  | ||||||
|         long totalTrimmed = leftTrimNumberOfSample + rightTrimNumberOfSample; |  | ||||||
| 
 |  | ||||||
|         if (totalTrimmed > subChunk2Size) { |  | ||||||
|             leftTrimNumberOfSample = (int) subChunk2Size; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         // update wav info |  | ||||||
|         chunkSize -= totalTrimmed; |  | ||||||
|         subChunk2Size -= totalTrimmed; |  | ||||||
| 
 |  | ||||||
|         if (chunkSize >= 0 && subChunk2Size >= 0) { |  | ||||||
|             waveHeader.setChunkSize(chunkSize); |  | ||||||
|             waveHeader.setSubChunk2Size(subChunk2Size); |  | ||||||
| 
 |  | ||||||
|             byte[] trimmedData = new byte[(int) subChunk2Size]; |  | ||||||
|             System.arraycopy(data, (int) leftTrimNumberOfSample, trimmedData, 0, (int) subChunk2Size); |  | ||||||
|             data = trimmedData; |  | ||||||
|         } else { |  | ||||||
|             System.err.println("Trim error: Negative length"); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Trim the wave data from beginning |  | ||||||
|      *  |  | ||||||
|      * @param numberOfSample |  | ||||||
|      *            numberOfSample trimmed from beginning |  | ||||||
|      */ |  | ||||||
|     public void leftTrim(int numberOfSample) { |  | ||||||
|         trim(numberOfSample, 0); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Trim the wave data from ending |  | ||||||
|      *  |  | ||||||
|      * @param numberOfSample |  | ||||||
|      *            numberOfSample trimmed from ending |  | ||||||
|      */ |  | ||||||
|     public void rightTrim(int numberOfSample) { |  | ||||||
|         trim(0, numberOfSample); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Trim the wave data |  | ||||||
|      *  |  | ||||||
|      * @param leftTrimSecond |  | ||||||
|      *            Seconds trimmed from beginning |  | ||||||
|      * @param rightTrimSecond |  | ||||||
|      *            Seconds trimmed from ending |  | ||||||
|      */ |  | ||||||
|     public void trim(double leftTrimSecond, double rightTrimSecond) { |  | ||||||
| 
 |  | ||||||
|         int sampleRate = waveHeader.getSampleRate(); |  | ||||||
|         int bitsPerSample = waveHeader.getBitsPerSample(); |  | ||||||
|         int channels = waveHeader.getChannels(); |  | ||||||
| 
 |  | ||||||
|         int leftTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * leftTrimSecond); |  | ||||||
|         int rightTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * rightTrimSecond); |  | ||||||
| 
 |  | ||||||
|         trim(leftTrimNumberOfSample, rightTrimNumberOfSample); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Trim the wave data from beginning |  | ||||||
|      *  |  | ||||||
|      * @param second |  | ||||||
|      *            Seconds trimmed from beginning |  | ||||||
|      */ |  | ||||||
|     public void leftTrim(double second) { |  | ||||||
|         trim(second, 0); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Trim the wave data from ending |  | ||||||
|      *  |  | ||||||
|      * @param second |  | ||||||
|      *            Seconds trimmed from ending |  | ||||||
|      */ |  | ||||||
|     public void rightTrim(double second) { |  | ||||||
|         trim(0, second); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the wave header |  | ||||||
|      *  |  | ||||||
|      * @return waveHeader |  | ||||||
|      */ |  | ||||||
|     public WaveHeader getWaveHeader() { |  | ||||||
|         return waveHeader; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the wave spectrogram |  | ||||||
|      *  |  | ||||||
|      * @return spectrogram |  | ||||||
|      */ |  | ||||||
|     public Spectrogram getSpectrogram() { |  | ||||||
|         return new Spectrogram(this); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the wave spectrogram |  | ||||||
|      *  |  | ||||||
|      * @param fftSampleSize	number of sample in fft, the value needed to be a number to power of 2 |  | ||||||
|      * @param overlapFactor	1/overlapFactor overlapping, e.g. 1/4=25% overlapping, 0 for no overlapping |  | ||||||
|      *  |  | ||||||
|      * @return spectrogram |  | ||||||
|      */ |  | ||||||
|     public Spectrogram getSpectrogram(int fftSampleSize, int overlapFactor) { |  | ||||||
|         return new Spectrogram(this, fftSampleSize, overlapFactor); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the wave data in bytes |  | ||||||
|      *  |  | ||||||
|      * @return wave data |  | ||||||
|      */ |  | ||||||
|     public byte[] getBytes() { |  | ||||||
|         return data; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Data byte size of the wave excluding header size |  | ||||||
|      *  |  | ||||||
|      * @return byte size of the wave |  | ||||||
|      */ |  | ||||||
|     public int size() { |  | ||||||
|         return data.length; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Length of the wave in second |  | ||||||
|      *  |  | ||||||
|      * @return length in second |  | ||||||
|      */ |  | ||||||
|     public float length() { |  | ||||||
|         return (float) waveHeader.getSubChunk2Size() / waveHeader.getByteRate(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Timestamp of the wave length |  | ||||||
|      *  |  | ||||||
|      * @return timestamp |  | ||||||
|      */ |  | ||||||
|     public String timestamp() { |  | ||||||
|         float totalSeconds = this.length(); |  | ||||||
|         float second = totalSeconds % 60; |  | ||||||
|         int minute = (int) totalSeconds / 60 % 60; |  | ||||||
|         int hour = (int) (totalSeconds / 3600); |  | ||||||
| 
 |  | ||||||
|         StringBuilder sb = new StringBuilder(); |  | ||||||
|         if (hour > 0) { |  | ||||||
|             sb.append(hour + ":"); |  | ||||||
|         } |  | ||||||
|         if (minute > 0) { |  | ||||||
|             sb.append(minute + ":"); |  | ||||||
|         } |  | ||||||
|         sb.append(second); |  | ||||||
| 
 |  | ||||||
|         return sb.toString(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the amplitudes of the wave samples (depends on the header) |  | ||||||
|      *  |  | ||||||
|      * @return amplitudes array (signed 16-bit) |  | ||||||
|      */ |  | ||||||
|     public short[] getSampleAmplitudes() { |  | ||||||
|         int bytePerSample = waveHeader.getBitsPerSample() / 8; |  | ||||||
|         int numSamples = data.length / bytePerSample; |  | ||||||
|         short[] amplitudes = new short[numSamples]; |  | ||||||
| 
 |  | ||||||
|         int pointer = 0; |  | ||||||
|         for (int i = 0; i < numSamples; i++) { |  | ||||||
|             short amplitude = 0; |  | ||||||
|             for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) { |  | ||||||
|                 // little endian |  | ||||||
|                 amplitude |= (short) ((data[pointer++] & 0xFF) << (byteNumber * 8)); |  | ||||||
|             } |  | ||||||
|             amplitudes[i] = amplitude; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return amplitudes; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public String toString() { |  | ||||||
|         StringBuilder sb = new StringBuilder(waveHeader.toString()); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("length: " + timestamp()); |  | ||||||
|         return sb.toString(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double[] getNormalizedAmplitudes() { |  | ||||||
|         NormalizedSampleAmplitudes amplitudes = new NormalizedSampleAmplitudes(this); |  | ||||||
|         return amplitudes.getNormalizedAmplitudes(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public byte[] getFingerprint() { |  | ||||||
|         if (fingerprint == null) { |  | ||||||
|             FingerprintManager fingerprintManager = new FingerprintManager(); |  | ||||||
|             fingerprint = fingerprintManager.extractFingerprint(this); |  | ||||||
|         } |  | ||||||
|         return fingerprint; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public FingerprintSimilarity getFingerprintSimilarity(Wave wave) { |  | ||||||
|         FingerprintSimilarityComputer fingerprintSimilarityComputer = |  | ||||||
|                         new FingerprintSimilarityComputer(this.getFingerprint(), wave.getFingerprint()); |  | ||||||
|         return fingerprintSimilarityComputer.getFingerprintsSimilarity(); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,98 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio; |  | ||||||
| 
 |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| 
 |  | ||||||
| import java.io.FileOutputStream; |  | ||||||
| import java.io.IOException; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class WaveFileManager { |  | ||||||
| 
 |  | ||||||
|     private Wave wave; |  | ||||||
| 
 |  | ||||||
|     public WaveFileManager() { |  | ||||||
|         wave = new Wave(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public WaveFileManager(Wave wave) { |  | ||||||
|         setWave(wave); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Save the wave file |  | ||||||
|      *  |  | ||||||
|      * @param filename |  | ||||||
|      *            filename to be saved |  | ||||||
|      *             |  | ||||||
|      * @see	Wave file saved |  | ||||||
|      */ |  | ||||||
|     public void saveWaveAsFile(String filename) { |  | ||||||
| 
 |  | ||||||
|         WaveHeader waveHeader = wave.getWaveHeader(); |  | ||||||
| 
 |  | ||||||
|         int byteRate = waveHeader.getByteRate(); |  | ||||||
|         int audioFormat = waveHeader.getAudioFormat(); |  | ||||||
|         int sampleRate = waveHeader.getSampleRate(); |  | ||||||
|         int bitsPerSample = waveHeader.getBitsPerSample(); |  | ||||||
|         int channels = waveHeader.getChannels(); |  | ||||||
|         long chunkSize = waveHeader.getChunkSize(); |  | ||||||
|         long subChunk1Size = waveHeader.getSubChunk1Size(); |  | ||||||
|         long subChunk2Size = waveHeader.getSubChunk2Size(); |  | ||||||
|         int blockAlign = waveHeader.getBlockAlign(); |  | ||||||
| 
 |  | ||||||
|         try { |  | ||||||
|             FileOutputStream fos = new FileOutputStream(filename); |  | ||||||
|             fos.write(WaveHeader.RIFF_HEADER.getBytes()); |  | ||||||
|             // little endian |  | ||||||
|             fos.write(new byte[] {(byte) (chunkSize), (byte) (chunkSize >> 8), (byte) (chunkSize >> 16), |  | ||||||
|                             (byte) (chunkSize >> 24)}); |  | ||||||
|             fos.write(WaveHeader.WAVE_HEADER.getBytes()); |  | ||||||
|             fos.write(WaveHeader.FMT_HEADER.getBytes()); |  | ||||||
|             fos.write(new byte[] {(byte) (subChunk1Size), (byte) (subChunk1Size >> 8), (byte) (subChunk1Size >> 16), |  | ||||||
|                             (byte) (subChunk1Size >> 24)}); |  | ||||||
|             fos.write(new byte[] {(byte) (audioFormat), (byte) (audioFormat >> 8)}); |  | ||||||
|             fos.write(new byte[] {(byte) (channels), (byte) (channels >> 8)}); |  | ||||||
|             fos.write(new byte[] {(byte) (sampleRate), (byte) (sampleRate >> 8), (byte) (sampleRate >> 16), |  | ||||||
|                             (byte) (sampleRate >> 24)}); |  | ||||||
|             fos.write(new byte[] {(byte) (byteRate), (byte) (byteRate >> 8), (byte) (byteRate >> 16), |  | ||||||
|                             (byte) (byteRate >> 24)}); |  | ||||||
|             fos.write(new byte[] {(byte) (blockAlign), (byte) (blockAlign >> 8)}); |  | ||||||
|             fos.write(new byte[] {(byte) (bitsPerSample), (byte) (bitsPerSample >> 8)}); |  | ||||||
|             fos.write(WaveHeader.DATA_HEADER.getBytes()); |  | ||||||
|             fos.write(new byte[] {(byte) (subChunk2Size), (byte) (subChunk2Size >> 8), (byte) (subChunk2Size >> 16), |  | ||||||
|                             (byte) (subChunk2Size >> 24)}); |  | ||||||
|             fos.write(wave.getBytes()); |  | ||||||
|             fos.close(); |  | ||||||
|         } catch (IOException e) { |  | ||||||
|             log.error("",e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public Wave getWave() { |  | ||||||
|         return wave; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setWave(Wave wave) { |  | ||||||
|         this.wave = wave; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,281 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio; |  | ||||||
| 
 |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| 
 |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.io.InputStream; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class WaveHeader { |  | ||||||
| 
 |  | ||||||
|     public static final String RIFF_HEADER = "RIFF"; |  | ||||||
|     public static final String WAVE_HEADER = "WAVE"; |  | ||||||
|     public static final String FMT_HEADER = "fmt "; |  | ||||||
|     public static final String DATA_HEADER = "data"; |  | ||||||
|     public static final int HEADER_BYTE_LENGTH = 44; // 44 bytes for header |  | ||||||
| 
 |  | ||||||
|     private boolean valid; |  | ||||||
|     private String chunkId; // 4 bytes |  | ||||||
|     private long chunkSize; // unsigned 4 bytes, little endian |  | ||||||
|     private String format; // 4 bytes |  | ||||||
|     private String subChunk1Id; // 4 bytes |  | ||||||
|     private long subChunk1Size; // unsigned 4 bytes, little endian |  | ||||||
|     private int audioFormat; // unsigned 2 bytes, little endian |  | ||||||
|     private int channels; // unsigned 2 bytes, little endian |  | ||||||
|     private long sampleRate; // unsigned 4 bytes, little endian |  | ||||||
|     private long byteRate; // unsigned 4 bytes, little endian |  | ||||||
|     private int blockAlign; // unsigned 2 bytes, little endian |  | ||||||
|     private int bitsPerSample; // unsigned 2 bytes, little endian |  | ||||||
|     private String subChunk2Id; // 4 bytes |  | ||||||
|     private long subChunk2Size; // unsigned 4 bytes, little endian |  | ||||||
| 
 |  | ||||||
|     public WaveHeader() { |  | ||||||
|         // init a 8k 16bit mono wav		 |  | ||||||
|         chunkSize = 36; |  | ||||||
|         subChunk1Size = 16; |  | ||||||
|         audioFormat = 1; |  | ||||||
|         channels = 1; |  | ||||||
|         sampleRate = 8000; |  | ||||||
|         byteRate = 16000; |  | ||||||
|         blockAlign = 2; |  | ||||||
|         bitsPerSample = 16; |  | ||||||
|         subChunk2Size = 0; |  | ||||||
|         valid = true; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public WaveHeader(InputStream inputStream) { |  | ||||||
|         valid = loadHeader(inputStream); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private boolean loadHeader(InputStream inputStream) { |  | ||||||
| 
 |  | ||||||
|         byte[] headerBuffer = new byte[HEADER_BYTE_LENGTH]; |  | ||||||
|         try { |  | ||||||
|             inputStream.read(headerBuffer); |  | ||||||
| 
 |  | ||||||
|             // read header |  | ||||||
|             int pointer = 0; |  | ||||||
|             chunkId = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++], |  | ||||||
|                             headerBuffer[pointer++]}); |  | ||||||
|             // little endian |  | ||||||
|             chunkSize = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 |  | ||||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 16 |  | ||||||
|                             | (long) (headerBuffer[pointer++] & 0xff << 24); |  | ||||||
|             format = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++], |  | ||||||
|                             headerBuffer[pointer++]}); |  | ||||||
|             subChunk1Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], |  | ||||||
|                             headerBuffer[pointer++], headerBuffer[pointer++]}); |  | ||||||
|             subChunk1Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 |  | ||||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 16 |  | ||||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 24; |  | ||||||
|             audioFormat = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8); |  | ||||||
|             channels = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8); |  | ||||||
|             sampleRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 |  | ||||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 16 |  | ||||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 24; |  | ||||||
|             byteRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 |  | ||||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 16 |  | ||||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 24; |  | ||||||
|             blockAlign = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8); |  | ||||||
|             bitsPerSample = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8); |  | ||||||
|             subChunk2Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], |  | ||||||
|                             headerBuffer[pointer++], headerBuffer[pointer++]}); |  | ||||||
|             subChunk2Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 |  | ||||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 16 |  | ||||||
|                             | (long) (headerBuffer[pointer++] & 0xff) << 24; |  | ||||||
|             // end read header |  | ||||||
| 
 |  | ||||||
|             // the inputStream should be closed outside this method |  | ||||||
| 
 |  | ||||||
|             // dis.close(); |  | ||||||
| 
 |  | ||||||
|         } catch (IOException e) { |  | ||||||
|             log.error("",e); |  | ||||||
|             return false; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         if (bitsPerSample != 8 && bitsPerSample != 16) { |  | ||||||
|             System.err.println("WaveHeader: only supports bitsPerSample 8 or 16"); |  | ||||||
|             return false; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         // check the format is support |  | ||||||
|         if (chunkId.toUpperCase().equals(RIFF_HEADER) && format.toUpperCase().equals(WAVE_HEADER) && audioFormat == 1) { |  | ||||||
|             return true; |  | ||||||
|         } else { |  | ||||||
|             System.err.println("WaveHeader: Unsupported header format"); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return false; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public boolean isValid() { |  | ||||||
|         return valid; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public String getChunkId() { |  | ||||||
|         return chunkId; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public long getChunkSize() { |  | ||||||
|         return chunkSize; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public String getFormat() { |  | ||||||
|         return format; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public String getSubChunk1Id() { |  | ||||||
|         return subChunk1Id; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public long getSubChunk1Size() { |  | ||||||
|         return subChunk1Size; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
    /**
     * @return the audio format code from the header (1 = uncompressed PCM, the only value the loader accepts)
     */
    public int getAudioFormat() {
        return audioFormat;
    }
| 
 |  | ||||||
    /**
     * @return the number of audio channels (1 = mono, 2 = stereo)
     */
    public int getChannels() {
        return channels;
    }
| 
 |  | ||||||
    /**
     * @return the sample rate in samples per second; the long field is narrowed to int here
     */
    public int getSampleRate() {
        return (int) sampleRate;
    }
| 
 |  | ||||||
    /**
     * @return the average byte rate (bytes consumed per second); the long field is narrowed to int here
     */
    public int getByteRate() {
        return (int) byteRate;
    }
| 
 |  | ||||||
    /**
     * @return the block alignment (bytes per sample frame across all channels)
     */
    public int getBlockAlign() {
        return blockAlign;
    }
| 
 |  | ||||||
    /**
     * @return bits per sample for a single channel; the loader only accepts 8 or 16
     */
    public int getBitsPerSample() {
        return bitsPerSample;
    }
| 
 |  | ||||||
    /**
     * @return the id of the second sub-chunk (the "data" chunk in a standard wave file)
     */
    public String getSubChunk2Id() {
        return subChunk2Id;
    }
| 
 |  | ||||||
    /**
     * @return the size in bytes of the second (data) sub-chunk, i.e. the raw audio payload
     */
    public long getSubChunk2Size() {
        return subChunk2Size;
    }
| 
 |  | ||||||
|     public void setSampleRate(int sampleRate) { |  | ||||||
|         int newSubChunk2Size = (int) (this.subChunk2Size * sampleRate / this.sampleRate); |  | ||||||
|         // if num bytes for each sample is even, the size of newSubChunk2Size also needed to be in even number |  | ||||||
|         if ((bitsPerSample / 8) % 2 == 0) { |  | ||||||
|             if (newSubChunk2Size % 2 != 0) { |  | ||||||
|                 newSubChunk2Size++; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         this.sampleRate = sampleRate; |  | ||||||
|         this.byteRate = sampleRate * bitsPerSample / 8; |  | ||||||
|         this.chunkSize = newSubChunk2Size + 36; |  | ||||||
|         this.subChunk2Size = newSubChunk2Size; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
    /**
     * Set the container chunk id. Plain assignment; no other header field is updated.
     *
     * @param chunkId the chunk id string
     */
    public void setChunkId(String chunkId) {
        this.chunkId = chunkId;
    }
| 
 |  | ||||||
    /**
     * Set the RIFF chunk size in bytes. Plain assignment; dependent fields are not recomputed.
     *
     * @param chunkSize the chunk size in bytes
     */
    public void setChunkSize(long chunkSize) {
        this.chunkSize = chunkSize;
    }
| 
 |  | ||||||
    /**
     * Set the format marker string.
     *
     * @param format the format marker (normally the WAVE marker)
     */
    public void setFormat(String format) {
        this.format = format;
    }
| 
 |  | ||||||
    /**
     * Set the id of the first (format) sub-chunk.
     *
     * @param subChunk1Id the sub-chunk id string
     */
    public void setSubChunk1Id(String subChunk1Id) {
        this.subChunk1Id = subChunk1Id;
    }
| 
 |  | ||||||
    /**
     * Set the size in bytes of the first (format) sub-chunk.
     *
     * @param subChunk1Size the sub-chunk size in bytes
     */
    public void setSubChunk1Size(long subChunk1Size) {
        this.subChunk1Size = subChunk1Size;
    }
| 
 |  | ||||||
    /**
     * Set the audio format code (1 = uncompressed PCM).
     *
     * @param audioFormat the audio format code
     */
    public void setAudioFormat(int audioFormat) {
        this.audioFormat = audioFormat;
    }
| 
 |  | ||||||
    /**
     * Set the number of channels. Plain assignment; byteRate/blockAlign are not recomputed.
     *
     * @param channels the number of audio channels
     */
    public void setChannels(int channels) {
        this.channels = channels;
    }
| 
 |  | ||||||
    /**
     * Set the average byte rate (bytes per second).
     *
     * @param byteRate the byte rate
     */
    public void setByteRate(long byteRate) {
        this.byteRate = byteRate;
    }
| 
 |  | ||||||
    /**
     * Set the block alignment (bytes per sample frame across all channels).
     *
     * @param blockAlign the block alignment in bytes
     */
    public void setBlockAlign(int blockAlign) {
        this.blockAlign = blockAlign;
    }
| 
 |  | ||||||
    /**
     * Set bits per sample for a single channel. Plain assignment; dependent fields are not recomputed.
     *
     * @param bitsPerSample bits per sample (loader supports 8 or 16)
     */
    public void setBitsPerSample(int bitsPerSample) {
        this.bitsPerSample = bitsPerSample;
    }
| 
 |  | ||||||
    /**
     * Set the id of the second (data) sub-chunk.
     *
     * @param subChunk2Id the sub-chunk id string
     */
    public void setSubChunk2Id(String subChunk2Id) {
        this.subChunk2Id = subChunk2Id;
    }
| 
 |  | ||||||
    /**
     * Set the size in bytes of the second (data) sub-chunk. Plain assignment; chunkSize is not recomputed.
     *
     * @param subChunk2Size the data sub-chunk size in bytes
     */
    public void setSubChunk2Size(long subChunk2Size) {
        this.subChunk2Size = subChunk2Size;
    }
| 
 |  | ||||||
|     public String toString() { |  | ||||||
| 
 |  | ||||||
|         StringBuilder sb = new StringBuilder(); |  | ||||||
|         sb.append("chunkId: " + chunkId); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("chunkSize: " + chunkSize); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("format: " + format); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("subChunk1Id: " + subChunk1Id); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("subChunk1Size: " + subChunk1Size); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("audioFormat: " + audioFormat); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("channels: " + channels); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("sampleRate: " + sampleRate); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("byteRate: " + byteRate); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("blockAlign: " + blockAlign); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("bitsPerSample: " + bitsPerSample); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("subChunk2Id: " + subChunk2Id); |  | ||||||
|         sb.append("\n"); |  | ||||||
|         sb.append("subChunk2Size: " + subChunk2Size); |  | ||||||
|         return sb.toString(); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,81 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.dsp; |  | ||||||
| 
 |  | ||||||
| import org.jtransforms.fft.DoubleFFT_1D; |  | ||||||
| 
 |  | ||||||
| public class FastFourierTransform { |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the frequency intensities |  | ||||||
|      * |  | ||||||
|      * @param amplitudes amplitudes of the signal. Format depends on value of complex |  | ||||||
|      * @param complex    if true, amplitudes is assumed to be complex interlaced (re = even, im = odd), if false amplitudes |  | ||||||
|      *                   are assumed to be real valued. |  | ||||||
|      * @return intensities of each frequency unit: mag[frequency_unit]=intensity |  | ||||||
|      */ |  | ||||||
|     public double[] getMagnitudes(double[] amplitudes, boolean complex) { |  | ||||||
| 
 |  | ||||||
|         final int sampleSize = amplitudes.length; |  | ||||||
|         final int nrofFrequencyBins = sampleSize / 2; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         // call the fft and transform the complex numbers |  | ||||||
|         if (complex) { |  | ||||||
|             DoubleFFT_1D fft = new DoubleFFT_1D(nrofFrequencyBins); |  | ||||||
|             fft.complexForward(amplitudes); |  | ||||||
|         } else { |  | ||||||
|             DoubleFFT_1D fft = new DoubleFFT_1D(sampleSize); |  | ||||||
|             fft.realForward(amplitudes); |  | ||||||
|             // amplitudes[1] contains re[sampleSize/2] or im[(sampleSize-1) / 2] (depending on whether sampleSize is odd or even) |  | ||||||
|             // Discard it as it is useless without the other part |  | ||||||
|             // im part dc bin is always 0 for real input |  | ||||||
|             amplitudes[1] = 0; |  | ||||||
|         } |  | ||||||
|         // end call the fft and transform the complex numbers |  | ||||||
| 
 |  | ||||||
|         // even indexes (0,2,4,6,...) are real parts |  | ||||||
|         // odd indexes (1,3,5,7,...) are img parts |  | ||||||
|         double[] mag = new double[nrofFrequencyBins]; |  | ||||||
|         for (int i = 0; i < nrofFrequencyBins; i++) { |  | ||||||
|             final int f = 2 * i; |  | ||||||
|             mag[i] = Math.sqrt(amplitudes[f] * amplitudes[f] + amplitudes[f + 1] * amplitudes[f + 1]); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return mag; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the frequency intensities. Backwards compatible with previous versions w.r.t to number of frequency bins. |  | ||||||
|      * Use getMagnitudes(amplitudes, true) to get all bins. |  | ||||||
|      * |  | ||||||
|      * @param amplitudes complex-valued signal to transform. Even indexes are real and odd indexes are img |  | ||||||
|      * @return intensities of each frequency unit: mag[frequency_unit]=intensity |  | ||||||
|      */ |  | ||||||
|     public double[] getMagnitudes(double[] amplitudes) { |  | ||||||
|         double[] magnitudes = getMagnitudes(amplitudes, true); |  | ||||||
| 
 |  | ||||||
|         double[] halfOfMagnitudes = new double[magnitudes.length/2]; |  | ||||||
|         System.arraycopy(magnitudes, 0,halfOfMagnitudes, 0, halfOfMagnitudes.length); |  | ||||||
|         return halfOfMagnitudes; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,66 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.dsp; |  | ||||||
| 
 |  | ||||||
public class LinearInterpolation {

    public LinearInterpolation() {

    }

    /**
     * Resample the given amplitudes from one sample rate to another by linear
     * interpolation between neighbouring samples (y = mx + c with delta-x of 1).
     *
     * @param oldSampleRate sample rate of the original samples
     * @param newSampleRate sample rate of the interpolated samples
     * @param samples       original samples
     * @return interpolated samples; the original array itself when both rates are equal
     */
    public short[] interpolate(int oldSampleRate, int newSampleRate, short[] samples) {

        if (oldSampleRate == newSampleRate) {
            return samples;
        }

        int targetLength = Math.round((float) samples.length / oldSampleRate * newSampleRate);
        float stretch = (float) targetLength / samples.length;
        short[] result = new short[targetLength];

        for (int idx = 0; idx < targetLength; idx++) {
            // map the target index back onto the source time axis
            float sourcePos = idx / stretch;
            int left = (int) sourcePos;
            // clamp the right neighbour to the last valid sample
            int right = Math.min(left + 1, samples.length - 1);

            // delta x between neighbours is 1, so the slope is simply the difference
            float slope = samples[right] - samples[left];
            float offset = sourcePos - left;

            result[idx] = (short) (slope * offset + samples[left]); // y = mx + c
        }

        return result;
    }
}
| @ -1,84 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.dsp; |  | ||||||
| 
 |  | ||||||
| public class Resampler { |  | ||||||
| 
 |  | ||||||
|     public Resampler() {} |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Do resampling. Currently the amplitude is stored by short such that maximum bitsPerSample is 16 (bytePerSample is 2) |  | ||||||
|      *  |  | ||||||
|      * @param sourceData	The source data in bytes |  | ||||||
|      * @param bitsPerSample	How many bits represents one sample (currently supports max. bitsPerSample=16)  |  | ||||||
|      * @param sourceRate	Sample rate of the source data |  | ||||||
|      * @param targetRate	Sample rate of the target data |  | ||||||
|      * @return re-sampled data |  | ||||||
|      */ |  | ||||||
|     public byte[] reSample(byte[] sourceData, int bitsPerSample, int sourceRate, int targetRate) { |  | ||||||
| 
 |  | ||||||
|         // make the bytes to amplitudes first |  | ||||||
|         int bytePerSample = bitsPerSample / 8; |  | ||||||
|         int numSamples = sourceData.length / bytePerSample; |  | ||||||
|         short[] amplitudes = new short[numSamples]; // 16 bit, use a short to store |  | ||||||
| 
 |  | ||||||
|         int pointer = 0; |  | ||||||
|         for (int i = 0; i < numSamples; i++) { |  | ||||||
|             short amplitude = 0; |  | ||||||
|             for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) { |  | ||||||
|                 // little endian |  | ||||||
|                 amplitude |= (short) ((sourceData[pointer++] & 0xFF) << (byteNumber * 8)); |  | ||||||
|             } |  | ||||||
|             amplitudes[i] = amplitude; |  | ||||||
|         } |  | ||||||
|         // end make the amplitudes |  | ||||||
| 
 |  | ||||||
|         // do interpolation |  | ||||||
|         LinearInterpolation reSample = new LinearInterpolation(); |  | ||||||
|         short[] targetSample = reSample.interpolate(sourceRate, targetRate, amplitudes); |  | ||||||
|         int targetLength = targetSample.length; |  | ||||||
|         // end do interpolation |  | ||||||
| 
 |  | ||||||
|         // TODO: Remove the high frequency signals with a digital filter, leaving a signal containing only half-sample-rated frequency information, but still sampled at a rate of target sample rate. Usually FIR is used |  | ||||||
| 
 |  | ||||||
|         // end resample the amplitudes |  | ||||||
| 
 |  | ||||||
|         // convert the amplitude to bytes |  | ||||||
|         byte[] bytes; |  | ||||||
|         if (bytePerSample == 1) { |  | ||||||
|             bytes = new byte[targetLength]; |  | ||||||
|             for (int i = 0; i < targetLength; i++) { |  | ||||||
|                 bytes[i] = (byte) targetSample[i]; |  | ||||||
|             } |  | ||||||
|         } else { |  | ||||||
|             // suppose bytePerSample==2 |  | ||||||
|             bytes = new byte[targetLength * 2]; |  | ||||||
|             for (int i = 0; i < targetSample.length; i++) { |  | ||||||
|                 // little endian			 |  | ||||||
|                 bytes[i * 2] = (byte) (targetSample[i] & 0xff); |  | ||||||
|                 bytes[i * 2 + 1] = (byte) ((targetSample[i] >> 8) & 0xff); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         // end convert the amplitude to bytes |  | ||||||
| 
 |  | ||||||
|         return bytes; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,95 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.dsp; |  | ||||||
| 
 |  | ||||||
public class WindowFunction {

    public static final int RECTANGULAR = 0;
    public static final int BARTLETT = 1;
    public static final int HANNING = 2;
    public static final int HAMMING = 3;
    public static final int BLACKMAN = 4;

    int windowType = 0; // defaults to rectangular window

    public WindowFunction() {}

    /**
     * Select the window by its numeric constant (one of the fields above).
     *
     * @param wt the window type constant
     */
    public void setWindowType(int wt) {
        windowType = wt;
    }

    /**
     * Select the window by name (case-insensitive); unrecognised names leave
     * the current selection unchanged.
     *
     * @param w window name: "RECTANGULAR", "BARTLETT", "HANNING", "HAMMING" or "BLACKMAN"
     */
    public void setWindowType(String w) {
        if (w.equalsIgnoreCase("RECTANGULAR"))
            windowType = RECTANGULAR;
        else if (w.equalsIgnoreCase("BARTLETT"))
            windowType = BARTLETT;
        else if (w.equalsIgnoreCase("HANNING"))
            windowType = HANNING;
        else if (w.equalsIgnoreCase("HAMMING"))
            windowType = HAMMING;
        else if (w.equalsIgnoreCase("BLACKMAN"))
            windowType = BLACKMAN;
    }

    /**
     * @return the currently selected window type constant
     */
    public int getWindowType() {
        return windowType;
    }

    /**
     * Generate a window.
     *
     * @param nSamples size of the window
     * @return window coefficients, one per sample index 0 .. nSamples - 1
     */
    public double[] generate(int nSamples) {
        // generate nSamples window function values
        // for index values 0 .. nSamples - 1
        int m = nSamples / 2;
        double r;
        double pi = Math.PI;
        double[] w = new double[nSamples];
        switch (windowType) {
            case BARTLETT: // Bartlett (triangular) window
                for (int n = 0; n < nSamples; n++)
                    // cast needed: Math.abs(n - m) / m was integer division, which
                    // collapsed the triangle into 0/1 steps instead of a linear ramp
                    w[n] = 1.0 - (double) Math.abs(n - m) / m;
                break;
            case HANNING: // Hanning window
                r = pi / (m + 1);
                for (int n = -m; n < m; n++)
                    w[m + n] = 0.5f + 0.5f * Math.cos(n * r);
                break;
            case HAMMING: // Hamming window
                r = pi / m;
                for (int n = -m; n < m; n++)
                    w[m + n] = 0.54f + 0.46f * Math.cos(n * r);
                break;
            case BLACKMAN: // Blackman window
                r = pi / m;
                for (int n = -m; n < m; n++)
                    w[m + n] = 0.42f + 0.5f * Math.cos(n * r) + 0.08f * Math.cos(2 * n * r);
                break;
            default: // Rectangular window function
                for (int n = 0; n < nSamples; n++)
                    w[n] = 1.0f;
        }
        return w;
    }
}
| @ -1,21 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.dsp; |  | ||||||
| @ -1,67 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.extension; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.datavec.audio.Wave; |  | ||||||
| 
 |  | ||||||
| public class NormalizedSampleAmplitudes { |  | ||||||
| 
 |  | ||||||
|     private Wave wave; |  | ||||||
|     private double[] normalizedAmplitudes; // normalizedAmplitudes[sampleNumber]=normalizedAmplitudeInTheFrame |  | ||||||
| 
 |  | ||||||
|     public NormalizedSampleAmplitudes(Wave wave) { |  | ||||||
|         this.wave = wave; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      *  |  | ||||||
|      * Get normalized amplitude of each frame |  | ||||||
|      *  |  | ||||||
|      * @return	array of normalized amplitudes(signed 16 bit): normalizedAmplitudes[frame]=amplitude |  | ||||||
|      */ |  | ||||||
|     public double[] getNormalizedAmplitudes() { |  | ||||||
| 
 |  | ||||||
|         if (normalizedAmplitudes == null) { |  | ||||||
| 
 |  | ||||||
|             boolean signed = true; |  | ||||||
| 
 |  | ||||||
|             // usually 8bit is unsigned |  | ||||||
|             if (wave.getWaveHeader().getBitsPerSample() == 8) { |  | ||||||
|                 signed = false; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             short[] amplitudes = wave.getSampleAmplitudes(); |  | ||||||
|             int numSamples = amplitudes.length; |  | ||||||
|             int maxAmplitude = 1 << (wave.getWaveHeader().getBitsPerSample() - 1); |  | ||||||
| 
 |  | ||||||
|             if (!signed) { // one more bit for unsigned value |  | ||||||
|                 maxAmplitude <<= 1; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             normalizedAmplitudes = new double[numSamples]; |  | ||||||
|             for (int i = 0; i < numSamples; i++) { |  | ||||||
|                 normalizedAmplitudes[i] = (double) amplitudes[i] / maxAmplitude; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         return normalizedAmplitudes; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,214 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.extension; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.datavec.audio.Wave; |  | ||||||
| import org.datavec.audio.dsp.FastFourierTransform; |  | ||||||
| import org.datavec.audio.dsp.WindowFunction; |  | ||||||
| 
 |  | ||||||
| public class Spectrogram { |  | ||||||
| 
 |  | ||||||
    public static final int SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE = 1024;
    public static final int SPECTROGRAM_DEFAULT_OVERLAP_FACTOR = 0; // 0 for no overlapping

    private Wave wave;
    private double[][] spectrogram; // relative (normalized) spectrogram, indexed [frame][frequencyUnit]
    private double[][] absoluteSpectrogram; // absolute spectrogram: raw FFT magnitudes per frame
    private int fftSampleSize; // number of samples per fft frame; the value needs to be a power of 2
    private int overlapFactor; // 1/overlapFactor overlapping, e.g. 1/4=25% overlapping
    private int numFrames; // number of frames of the spectrogram
    private int framesPerSecond; // frames per second of the spectrogram
    private int numFrequencyUnit; // number of y-axis units (frequency bins)
    private double unitFrequency; // frequency per y-axis unit
| 
 |  | ||||||
    /**
     * Build a spectrogram for the given wave using the default FFT frame size
     * (1024 samples) and no frame overlapping.
     *
     * @param wave the audio to analyse
     */
    public Spectrogram(Wave wave) {
        this.wave = wave;
        // default
        this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE;
        this.overlapFactor = SPECTROGRAM_DEFAULT_OVERLAP_FACTOR;
        buildSpectrogram();
    }
| 
 |  | ||||||
    /**
     * Build a spectrogram for the given wave.
     *
     * @param wave          the audio to analyse
     * @param fftSampleSize number of samples per fft frame; must be a power of 2,
     *                      otherwise the default (1024) is used instead
     * @param overlapFactor 1/overlapFactor overlapping, e.g. 1/4=25% overlapping, 0 for no overlapping
     */
    public Spectrogram(Wave wave, int fftSampleSize, int overlapFactor) {
        this.wave = wave;

        // a power of two has exactly one set bit
        if (Integer.bitCount(fftSampleSize) == 1) {
            this.fftSampleSize = fftSampleSize;
        } else {
            System.err.print("The input number must be a power of 2");
            this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE;
        }

        this.overlapFactor = overlapFactor;

        buildSpectrogram();
    }
| 
 |  | ||||||
    /**
     * Build spectrogram.
     * <p>
     * Pipeline: (1) optionally expand the raw samples so consecutive frames overlap,
     * (2) apply a Hamming window to each frame, (3) run an FFT per frame to obtain
     * absolute magnitudes, and (4) log-normalize the magnitudes into the 0..1 range
     * stored in {@code spectrogram}.
     */
    private void buildSpectrogram() {

        short[] amplitudes = wave.getSampleAmplitudes();
        int numSamples = amplitudes.length;

        int pointer = 0;
        // overlapping: after every full frame, rewind the read index by backSamples so
        // adjacent frames share (overlapFactor-1)/overlapFactor of their samples
        if (overlapFactor > 1) {
            int numOverlappedSamples = numSamples * overlapFactor;
            int backSamples = fftSampleSize * (overlapFactor - 1) / overlapFactor;
            short[] overlapAmp = new short[numOverlappedSamples];
            pointer = 0;
            for (int i = 0; i < amplitudes.length; i++) {
                overlapAmp[pointer++] = amplitudes[i];
                if (pointer % fftSampleSize == 0) {
                    // overlap
                    i -= backSamples;
                }
            }
            numSamples = numOverlappedSamples;
            amplitudes = overlapAmp;
        }
        // end overlapping

        numFrames = numSamples / fftSampleSize;
        // integer cast truncates; wave.length() is presumably the duration in seconds
        // (TODO confirm against Wave)
        framesPerSecond = (int) (numFrames / wave.length());

        // set signals for fft: window each frame with a Hamming window
        WindowFunction window = new WindowFunction();
        window.setWindowType("Hamming");
        double[] win = window.generate(fftSampleSize);

        double[][] signals = new double[numFrames][];
        for (int f = 0; f < numFrames; f++) {
            signals[f] = new double[fftSampleSize];
            int startSample = f * fftSampleSize;
            for (int n = 0; n < fftSampleSize; n++) {
                signals[f][n] = amplitudes[startSample + n] * win[n];
            }
        }
        // end set signals for fft

        absoluteSpectrogram = new double[numFrames][];
        // for each frame in signals, do fft on it
        FastFourierTransform fft = new FastFourierTransform();
        for (int i = 0; i < numFrames; i++) {
            absoluteSpectrogram[i] = fft.getMagnitudes(signals[i], false);
        }

        if (absoluteSpectrogram.length > 0) {

            numFrequencyUnit = absoluteSpectrogram[0].length;
            unitFrequency = (double) wave.getWaveHeader().getSampleRate() / 2 / numFrequencyUnit; // frequency could be caught within the half of nSamples according to Nyquist theory

            // normalization of absoluteSpectrogram
            spectrogram = new double[numFrames][numFrequencyUnit];

            // set max and min amplitudes
            // NOTE(review): maxAmp starts at Double.MIN_VALUE (the smallest POSITIVE
            // double, not the most negative), and the else-if means a value that raises
            // the max is never also considered for the min. Both are harmless while the
            // FFT magnitudes are non-negative — confirm before reusing elsewhere.
            double maxAmp = Double.MIN_VALUE;
            double minAmp = Double.MAX_VALUE;
            for (int i = 0; i < numFrames; i++) {
                for (int j = 0; j < numFrequencyUnit; j++) {
                    if (absoluteSpectrogram[i][j] > maxAmp) {
                        maxAmp = absoluteSpectrogram[i][j];
                    } else if (absoluteSpectrogram[i][j] < minAmp) {
                        minAmp = absoluteSpectrogram[i][j];
                    }
                }
            }
            // end set max and min amplitudes

            // normalization
            // avoiding divided by zero
            double minValidAmp = 0.00000000001F;
            if (minAmp == 0) {
                minAmp = minValidAmp;
            }

            double diff = Math.log10(maxAmp / minAmp); // perceptual difference
            for (int i = 0; i < numFrames; i++) {
                for (int j = 0; j < numFrequencyUnit; j++) {
                    if (absoluteSpectrogram[i][j] < minValidAmp) {
                        spectrogram[i][j] = 0;
                    } else {
                        // log scaling maps [minAmp, maxAmp] onto [0, 1]
                        spectrogram[i][j] = (Math.log10(absoluteSpectrogram[i][j] / minAmp)) / diff;
                    }
                }
            }
            // end normalization
        }
    }
| 
 |  | ||||||
    /**
     * Get spectrogram: spectrogram[time][frequency]=intensity
     *
     * @return logarithm-normalized spectrogram with intensities in the 0..1 range;
     *         may be null if the wave produced no frames (it is only assigned when
     *         at least one frame exists)
     */
    public double[][] getNormalizedSpectrogramData() {
        return spectrogram;
    }
| 
 |  | ||||||
    /**
     * Get spectrogram: spectrogram[time][frequency]=intensity
     *
     * @return absolute (un-normalized) FFT magnitude spectrogram
     */
    public double[][] getAbsoluteSpectrogramData() {
        return absoluteSpectrogram;
    }
| 
 |  | ||||||
    /** @return number of FFT frames in the spectrogram (samples / fftSampleSize, after overlap expansion) */
    public int getNumFrames() {
        return numFrames;
    }
| 
 |  | ||||||
    /** @return frames per second of audio, truncated to an integer */
    public int getFramesPerSecond() {
        return framesPerSecond;
    }
| 
 |  | ||||||
    /** @return number of frequency bins per frame */
    public int getNumFrequencyUnit() {
        return numFrequencyUnit;
    }
| 
 |  | ||||||
    /** @return frequency span of each bin: sampleRate / 2 / numFrequencyUnit (Nyquist half-band divided evenly) */
    public double getUnitFrequency() {
        return unitFrequency;
    }
| 
 |  | ||||||
    /** @return number of samples per FFT frame (a power of 2) */
    public int getFftSampleSize() {
        return fftSampleSize;
    }
| 
 |  | ||||||
    /** @return overlap factor: frames overlap by 1/overlapFactor; 0 means no overlap */
    public int getOverlapFactor() {
        return overlapFactor;
    }
| } |  | ||||||
| @ -1,272 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.fingerprint; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
import lombok.extern.slf4j.Slf4j;
import org.datavec.audio.Wave;
import org.datavec.audio.WaveHeader;
import org.datavec.audio.dsp.Resampler;
import org.datavec.audio.extension.Spectrogram;
import org.datavec.audio.processor.TopManyPointsProcessorChain;
import org.datavec.audio.properties.FingerprintProperties;

import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class FingerprintManager { |  | ||||||
| 
 |  | ||||||
|     private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance(); |  | ||||||
|     private int sampleSizePerFrame = fingerprintProperties.getSampleSizePerFrame(); |  | ||||||
|     private int overlapFactor = fingerprintProperties.getOverlapFactor(); |  | ||||||
|     private int numRobustPointsPerFrame = fingerprintProperties.getNumRobustPointsPerFrame(); |  | ||||||
|     private int numFilterBanks = fingerprintProperties.getNumFilterBanks(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructor |  | ||||||
|      */ |  | ||||||
|     public FingerprintManager() { |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Extract fingerprint from Wave object |  | ||||||
|      *  |  | ||||||
|      * @param wave	Wave Object to be extracted fingerprint |  | ||||||
|      * @return fingerprint in bytes |  | ||||||
|      */ |  | ||||||
|     public byte[] extractFingerprint(Wave wave) { |  | ||||||
| 
 |  | ||||||
|         int[][] coordinates; // coordinates[x][0..3]=y0..y3 |  | ||||||
|         byte[] fingerprint = new byte[0]; |  | ||||||
| 
 |  | ||||||
|         // resample to target rate |  | ||||||
|         Resampler resampler = new Resampler(); |  | ||||||
|         int sourceRate = wave.getWaveHeader().getSampleRate(); |  | ||||||
|         int targetRate = fingerprintProperties.getSampleRate(); |  | ||||||
| 
 |  | ||||||
|         byte[] resampledWaveData = resampler.reSample(wave.getBytes(), wave.getWaveHeader().getBitsPerSample(), |  | ||||||
|                         sourceRate, targetRate); |  | ||||||
| 
 |  | ||||||
|         // update the wave header |  | ||||||
|         WaveHeader resampledWaveHeader = wave.getWaveHeader(); |  | ||||||
|         resampledWaveHeader.setSampleRate(targetRate); |  | ||||||
| 
 |  | ||||||
|         // make resampled wave |  | ||||||
|         Wave resampledWave = new Wave(resampledWaveHeader, resampledWaveData); |  | ||||||
|         // end resample to target rate |  | ||||||
| 
 |  | ||||||
|         // get spectrogram's data |  | ||||||
|         Spectrogram spectrogram = resampledWave.getSpectrogram(sampleSizePerFrame, overlapFactor); |  | ||||||
|         double[][] spectorgramData = spectrogram.getNormalizedSpectrogramData(); |  | ||||||
| 
 |  | ||||||
|         List<Integer>[] pointsLists = getRobustPointList(spectorgramData); |  | ||||||
|         int numFrames = pointsLists.length; |  | ||||||
| 
 |  | ||||||
|         // prepare fingerprint bytes |  | ||||||
|         coordinates = new int[numFrames][numRobustPointsPerFrame]; |  | ||||||
| 
 |  | ||||||
|         for (int x = 0; x < numFrames; x++) { |  | ||||||
|             if (pointsLists[x].size() == numRobustPointsPerFrame) { |  | ||||||
|                 Iterator<Integer> pointsListsIterator = pointsLists[x].iterator(); |  | ||||||
|                 for (int y = 0; y < numRobustPointsPerFrame; y++) { |  | ||||||
|                     coordinates[x][y] = pointsListsIterator.next(); |  | ||||||
|                 } |  | ||||||
|             } else { |  | ||||||
|                 // use -1 to fill the empty byte |  | ||||||
|                 for (int y = 0; y < numRobustPointsPerFrame; y++) { |  | ||||||
|                     coordinates[x][y] = -1; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         // end make fingerprint |  | ||||||
| 
 |  | ||||||
|         // for each valid coordinate, append with its intensity |  | ||||||
|         List<Byte> byteList = new LinkedList<Byte>(); |  | ||||||
|         for (int i = 0; i < numFrames; i++) { |  | ||||||
|             for (int j = 0; j < numRobustPointsPerFrame; j++) { |  | ||||||
|                 if (coordinates[i][j] != -1) { |  | ||||||
|                     // first 2 bytes is x |  | ||||||
|                     byteList.add((byte) (i >> 8)); |  | ||||||
|                     byteList.add((byte) i); |  | ||||||
| 
 |  | ||||||
|                     // next 2 bytes is y |  | ||||||
|                     int y = coordinates[i][j]; |  | ||||||
|                     byteList.add((byte) (y >> 8)); |  | ||||||
|                     byteList.add((byte) y); |  | ||||||
| 
 |  | ||||||
|                     // next 4 bytes is intensity |  | ||||||
|                     int intensity = (int) (spectorgramData[i][y] * Integer.MAX_VALUE); // spectorgramData is ranged from 0~1 |  | ||||||
|                     byteList.add((byte) (intensity >> 24)); |  | ||||||
|                     byteList.add((byte) (intensity >> 16)); |  | ||||||
|                     byteList.add((byte) (intensity >> 8)); |  | ||||||
|                     byteList.add((byte) intensity); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         // end for each valid coordinate, append with its intensity |  | ||||||
| 
 |  | ||||||
|         fingerprint = new byte[byteList.size()]; |  | ||||||
|         Iterator<Byte> byteListIterator = byteList.iterator(); |  | ||||||
|         int pointer = 0; |  | ||||||
|         while (byteListIterator.hasNext()) { |  | ||||||
|             fingerprint[pointer++] = byteListIterator.next(); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return fingerprint; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get bytes from fingerprint file |  | ||||||
|      *  |  | ||||||
|      * @param fingerprintFile	fingerprint filename |  | ||||||
|      * @return fingerprint in bytes |  | ||||||
|      */ |  | ||||||
|     public byte[] getFingerprintFromFile(String fingerprintFile) { |  | ||||||
|         byte[] fingerprint = null; |  | ||||||
|         try { |  | ||||||
|             InputStream fis = new FileInputStream(fingerprintFile); |  | ||||||
|             fingerprint = getFingerprintFromInputStream(fis); |  | ||||||
|             fis.close(); |  | ||||||
|         } catch (IOException e) { |  | ||||||
|             log.error("",e); |  | ||||||
|         } |  | ||||||
|         return fingerprint; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get bytes from fingerprint inputstream |  | ||||||
|      *  |  | ||||||
|      * @param inputStream	fingerprint inputstream |  | ||||||
|      * @return fingerprint in bytes |  | ||||||
|      */ |  | ||||||
|     public byte[] getFingerprintFromInputStream(InputStream inputStream) { |  | ||||||
|         byte[] fingerprint = null; |  | ||||||
|         try { |  | ||||||
|             fingerprint = new byte[inputStream.available()]; |  | ||||||
|             inputStream.read(fingerprint); |  | ||||||
|         } catch (IOException e) { |  | ||||||
|             log.error("",e); |  | ||||||
|         } |  | ||||||
|         return fingerprint; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Save fingerprint to a file |  | ||||||
|      *  |  | ||||||
|      * @param fingerprint	fingerprint bytes |  | ||||||
|      * @param filename		fingerprint filename |  | ||||||
|      * @see	FingerprintManager file saved |  | ||||||
|      */ |  | ||||||
|     public void saveFingerprintAsFile(byte[] fingerprint, String filename) { |  | ||||||
| 
 |  | ||||||
|         FileOutputStream fileOutputStream; |  | ||||||
|         try { |  | ||||||
|             fileOutputStream = new FileOutputStream(filename); |  | ||||||
|             fileOutputStream.write(fingerprint); |  | ||||||
|             fileOutputStream.close(); |  | ||||||
|         } catch (IOException e) { |  | ||||||
|             log.error("",e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // robustLists[x]=y1,y2,y3,... |  | ||||||
|     private List<Integer>[] getRobustPointList(double[][] spectrogramData) { |  | ||||||
| 
 |  | ||||||
|         int numX = spectrogramData.length; |  | ||||||
|         int numY = spectrogramData[0].length; |  | ||||||
| 
 |  | ||||||
|         double[][] allBanksIntensities = new double[numX][numY]; |  | ||||||
|         int bandwidthPerBank = numY / numFilterBanks; |  | ||||||
| 
 |  | ||||||
|         for (int b = 0; b < numFilterBanks; b++) { |  | ||||||
| 
 |  | ||||||
|             double[][] bankIntensities = new double[numX][bandwidthPerBank]; |  | ||||||
| 
 |  | ||||||
|             for (int i = 0; i < numX; i++) { |  | ||||||
|                 System.arraycopy(spectrogramData[i], b * bandwidthPerBank, bankIntensities[i], 0, bandwidthPerBank); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             // get the most robust point in each filter bank |  | ||||||
|             TopManyPointsProcessorChain processorChain = new TopManyPointsProcessorChain(bankIntensities, 1); |  | ||||||
|             double[][] processedIntensities = processorChain.getIntensities(); |  | ||||||
| 
 |  | ||||||
|             for (int i = 0; i < numX; i++) { |  | ||||||
|                 System.arraycopy(processedIntensities[i], 0, allBanksIntensities[i], b * bandwidthPerBank, |  | ||||||
|                                 bandwidthPerBank); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         List<int[]> robustPointList = new LinkedList<int[]>(); |  | ||||||
| 
 |  | ||||||
|         // find robust points |  | ||||||
|         for (int i = 0; i < allBanksIntensities.length; i++) { |  | ||||||
|             for (int j = 0; j < allBanksIntensities[i].length; j++) { |  | ||||||
|                 if (allBanksIntensities[i][j] > 0) { |  | ||||||
| 
 |  | ||||||
|                     int[] point = new int[] {i, j}; |  | ||||||
|                     //System.out.println(i+","+frequency); |  | ||||||
|                     robustPointList.add(point); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         // end find robust points |  | ||||||
| 
 |  | ||||||
|         List<Integer>[] robustLists = new LinkedList[spectrogramData.length]; |  | ||||||
|         for (int i = 0; i < robustLists.length; i++) { |  | ||||||
|             robustLists[i] = new LinkedList<>(); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         // robustLists[x]=y1,y2,y3,... |  | ||||||
|         for (int[] coor : robustPointList) { |  | ||||||
|             robustLists[coor[0]].add(coor[1]); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         // return the list per frame |  | ||||||
|         return robustLists; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Number of frames in a fingerprint |  | ||||||
|      * Each frame lengths 8 bytes |  | ||||||
|      * Usually there is more than one point in each frame, so it cannot simply divide the bytes length by 8 |  | ||||||
|      * Last 8 byte of thisFingerprint is the last frame of this wave |  | ||||||
|      * First 2 byte of the last 8 byte is the x position of this wave, i.e. (number_of_frames-1) of this wave	  |  | ||||||
|      *  |  | ||||||
|      * @param fingerprint	fingerprint bytes |  | ||||||
|      * @return number of frames of the fingerprint |  | ||||||
|      */ |  | ||||||
|     public static int getNumFrames(byte[] fingerprint) { |  | ||||||
| 
 |  | ||||||
|         if (fingerprint.length < 8) { |  | ||||||
|             return 0; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         // get the last x-coordinate (length-8&length-7)bytes from fingerprint |  | ||||||
|         return ((fingerprint[fingerprint.length - 8] & 0xff) << 8 | (fingerprint[fingerprint.length - 7] & 0xff)) + 1; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,107 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.fingerprint; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.datavec.audio.properties.FingerprintProperties; |  | ||||||
| 
 |  | ||||||
| public class FingerprintSimilarity { |  | ||||||
| 
 |  | ||||||
|     private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance(); |  | ||||||
|     private int mostSimilarFramePosition; |  | ||||||
|     private float score; |  | ||||||
|     private float similarity; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructor |  | ||||||
|      */ |  | ||||||
|     public FingerprintSimilarity() { |  | ||||||
|         mostSimilarFramePosition = Integer.MIN_VALUE; |  | ||||||
|         score = -1; |  | ||||||
|         similarity = -1; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the most similar position in terms of frame number |  | ||||||
|      *  |  | ||||||
|      * @return most similar frame position |  | ||||||
|      */ |  | ||||||
|     public int getMostSimilarFramePosition() { |  | ||||||
|         return mostSimilarFramePosition; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Set the most similar position in terms of frame number |  | ||||||
|      *  |  | ||||||
|      * @param mostSimilarFramePosition |  | ||||||
|      */ |  | ||||||
|     public void setMostSimilarFramePosition(int mostSimilarFramePosition) { |  | ||||||
|         this.mostSimilarFramePosition = mostSimilarFramePosition; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the similarity of the fingerprints |  | ||||||
|      * similarity from 0~1, which 0 means no similar feature is found and 1 means in average there is at least one match in every frame |  | ||||||
|      *  |  | ||||||
|      * @return fingerprints similarity |  | ||||||
|      */ |  | ||||||
|     public float getSimilarity() { |  | ||||||
|         return similarity; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Set the similarity of the fingerprints |  | ||||||
|      *  |  | ||||||
|      * @param similarity similarity |  | ||||||
|      */ |  | ||||||
|     public void setSimilarity(float similarity) { |  | ||||||
|         this.similarity = similarity; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the similarity score of the fingerprints |  | ||||||
|      * Number of features found in the fingerprints per frame  |  | ||||||
|      *  |  | ||||||
|      * @return fingerprints similarity score |  | ||||||
|      */ |  | ||||||
|     public float getScore() { |  | ||||||
|         return score; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Set the similarity score of the fingerprints |  | ||||||
|      *  |  | ||||||
|      * @param score |  | ||||||
|      */ |  | ||||||
|     public void setScore(float score) { |  | ||||||
|         this.score = score; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the most similar position in terms of time in second |  | ||||||
|      *  |  | ||||||
|      * @return most similar starting time |  | ||||||
|      */ |  | ||||||
|     public float getsetMostSimilarTimePosition() { |  | ||||||
|         return (float) mostSimilarFramePosition / fingerprintProperties.getNumRobustPointsPerFrame() |  | ||||||
|                         / fingerprintProperties.getFps(); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,134 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.fingerprint; |  | ||||||
| 
 |  | ||||||
| import java.util.HashMap; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class FingerprintSimilarityComputer { |  | ||||||
| 
 |  | ||||||
|     private FingerprintSimilarity fingerprintSimilarity; |  | ||||||
|     byte[] fingerprint1, fingerprint2; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructor, ready to compute the similarity of two fingerprints |  | ||||||
|      *  |  | ||||||
|      * @param fingerprint1 |  | ||||||
|      * @param fingerprint2 |  | ||||||
|      */ |  | ||||||
|     public FingerprintSimilarityComputer(byte[] fingerprint1, byte[] fingerprint2) { |  | ||||||
| 
 |  | ||||||
|         this.fingerprint1 = fingerprint1; |  | ||||||
|         this.fingerprint2 = fingerprint2; |  | ||||||
| 
 |  | ||||||
|         fingerprintSimilarity = new FingerprintSimilarity(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get fingerprint similarity of inout fingerprints |  | ||||||
|      *  |  | ||||||
|      * @return fingerprint similarity object |  | ||||||
|      */ |  | ||||||
|     public FingerprintSimilarity getFingerprintsSimilarity() { |  | ||||||
|         HashMap<Integer, Integer> offset_Score_Table = new HashMap<>(); // offset_Score_Table<offset,count> |  | ||||||
|         int numFrames; |  | ||||||
|         float score = 0; |  | ||||||
|         int mostSimilarFramePosition = Integer.MIN_VALUE; |  | ||||||
| 
 |  | ||||||
|         // one frame may contain several points, use the shorter one be the denominator |  | ||||||
|         if (fingerprint1.length > fingerprint2.length) { |  | ||||||
|             numFrames = FingerprintManager.getNumFrames(fingerprint2); |  | ||||||
|         } else { |  | ||||||
|             numFrames = FingerprintManager.getNumFrames(fingerprint1); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         // get the pairs |  | ||||||
|         PairManager pairManager = new PairManager(); |  | ||||||
|         HashMap<Integer, List<Integer>> this_Pair_PositionList_Table = |  | ||||||
|                         pairManager.getPair_PositionList_Table(fingerprint1); |  | ||||||
|         HashMap<Integer, List<Integer>> compareWave_Pair_PositionList_Table = |  | ||||||
|                         pairManager.getPair_PositionList_Table(fingerprint2); |  | ||||||
| 
 |  | ||||||
|         for (Integer compareWaveHashNumber : compareWave_Pair_PositionList_Table.keySet()) { |  | ||||||
|             // if the compareWaveHashNumber doesn't exist in both tables, no need to compare |  | ||||||
|             if (!this_Pair_PositionList_Table.containsKey(compareWaveHashNumber) |  | ||||||
|                             || !compareWave_Pair_PositionList_Table.containsKey(compareWaveHashNumber)) { |  | ||||||
|                 continue; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             // for each compare hash number, get the positions |  | ||||||
|             List<Integer> wavePositionList = this_Pair_PositionList_Table.get(compareWaveHashNumber); |  | ||||||
|             List<Integer> compareWavePositionList = compareWave_Pair_PositionList_Table.get(compareWaveHashNumber); |  | ||||||
| 
 |  | ||||||
|             for (Integer thisPosition : wavePositionList) { |  | ||||||
|                 for (Integer compareWavePosition : compareWavePositionList) { |  | ||||||
|                     int offset = thisPosition - compareWavePosition; |  | ||||||
|                     if (offset_Score_Table.containsKey(offset)) { |  | ||||||
|                         offset_Score_Table.put(offset, offset_Score_Table.get(offset) + 1); |  | ||||||
|                     } else { |  | ||||||
|                         offset_Score_Table.put(offset, 1); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         // map rank |  | ||||||
|         MapRank mapRank = new MapRankInteger(offset_Score_Table, false); |  | ||||||
| 
 |  | ||||||
|         // get the most similar positions and scores	 |  | ||||||
|         List<Integer> orderedKeyList = mapRank.getOrderedKeyList(100, true); |  | ||||||
|         if (orderedKeyList.size() > 0) { |  | ||||||
|             int key = orderedKeyList.get(0); |  | ||||||
|             // get the highest score position |  | ||||||
|             mostSimilarFramePosition = key; |  | ||||||
|             score = offset_Score_Table.get(key); |  | ||||||
| 
 |  | ||||||
|             // accumulate the scores from neighbours |  | ||||||
|             if (offset_Score_Table.containsKey(key - 1)) { |  | ||||||
|                 score += offset_Score_Table.get(key - 1) / 2; |  | ||||||
|             } |  | ||||||
|             if (offset_Score_Table.containsKey(key + 1)) { |  | ||||||
|                 score += offset_Score_Table.get(key + 1) / 2; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         /* |  | ||||||
|         Iterator<Integer> orderedKeyListIterator=orderedKeyList.iterator(); |  | ||||||
|         while (orderedKeyListIterator.hasNext()){ |  | ||||||
|         	int offset=orderedKeyListIterator.next();					 |  | ||||||
|         	System.out.println(offset+": "+offset_Score_Table.get(offset)); |  | ||||||
|         } |  | ||||||
|         */ |  | ||||||
| 
 |  | ||||||
|         score /= numFrames; |  | ||||||
|         float similarity = score; |  | ||||||
|         // similarity >1 means in average there is at least one match in every frame |  | ||||||
|         if (similarity > 1) { |  | ||||||
|             similarity = 1; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         fingerprintSimilarity.setMostSimilarFramePosition(mostSimilarFramePosition); |  | ||||||
|         fingerprintSimilarity.setScore(score); |  | ||||||
|         fingerprintSimilarity.setSimilarity(similarity); |  | ||||||
| 
 |  | ||||||
|         return fingerprintSimilarity; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,27 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.fingerprint; |  | ||||||
| 
 |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
/**
 * Ranks the keys of a map by their associated values.
 */
public interface MapRank {
    /**
     * Returns the map's keys ordered by their values.
     *
     * @param numKeys    number of keys requested
     * @param sharpLimit if true, return exactly numKeys keys; otherwise keep
     *                   returning keys whose values tie with the cut-off key's value
     * @return list of keys ordered by value
     */
    public List getOrderedKeyList(int numKeys, boolean sharpLimit);
}
| @ -1,179 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.fingerprint; |  | ||||||
| 
 |  | ||||||
| import java.util.*; |  | ||||||
| import java.util.Map.Entry; |  | ||||||
| 
 |  | ||||||
| public class MapRankDouble implements MapRank { |  | ||||||
| 
 |  | ||||||
|     private Map map; |  | ||||||
|     private boolean acsending = true; |  | ||||||
| 
 |  | ||||||
|     public MapRankDouble(Map<?, Double> map, boolean acsending) { |  | ||||||
|         this.map = map; |  | ||||||
|         this.acsending = acsending; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public List getOrderedKeyList(int numKeys, boolean sharpLimit) { // if sharp limited, will return sharp numKeys, otherwise will return until the values not equals the exact key's value |  | ||||||
| 
 |  | ||||||
|         Set mapEntrySet = map.entrySet(); |  | ||||||
|         List keyList = new LinkedList(); |  | ||||||
| 
 |  | ||||||
|         // if the numKeys is larger than map size, limit it |  | ||||||
|         if (numKeys > map.size()) { |  | ||||||
|             numKeys = map.size(); |  | ||||||
|         } |  | ||||||
|         // end if the numKeys is larger than map size, limit it |  | ||||||
| 
 |  | ||||||
|         if (map.size() > 0) { |  | ||||||
|             double[] array = new double[map.size()]; |  | ||||||
|             int count = 0; |  | ||||||
| 
 |  | ||||||
|             // get the pass values |  | ||||||
|             Iterator<Entry> mapIterator = mapEntrySet.iterator(); |  | ||||||
|             while (mapIterator.hasNext()) { |  | ||||||
|                 Entry entry = mapIterator.next(); |  | ||||||
|                 array[count++] = (Double) entry.getValue(); |  | ||||||
|             } |  | ||||||
|             // end get the pass values |  | ||||||
| 
 |  | ||||||
|             int targetindex; |  | ||||||
|             if (acsending) { |  | ||||||
|                 targetindex = numKeys; |  | ||||||
|             } else { |  | ||||||
|                 targetindex = array.length - numKeys; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             double passValue = getOrderedValue(array, targetindex); // this value is the value of the numKey-th element |  | ||||||
|             // get the passed keys and values |  | ||||||
|             Map passedMap = new HashMap(); |  | ||||||
|             List<Double> valueList = new LinkedList<Double>(); |  | ||||||
|             mapIterator = mapEntrySet.iterator(); |  | ||||||
| 
 |  | ||||||
|             while (mapIterator.hasNext()) { |  | ||||||
|                 Entry entry = mapIterator.next(); |  | ||||||
|                 double value = (Double) entry.getValue(); |  | ||||||
|                 if ((acsending && value <= passValue) || (!acsending && value >= passValue)) { |  | ||||||
|                     passedMap.put(entry.getKey(), value); |  | ||||||
|                     valueList.add(value); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             // end get the passed keys and values |  | ||||||
| 
 |  | ||||||
|             // sort the value list |  | ||||||
|             Double[] listArr = new Double[valueList.size()]; |  | ||||||
|             valueList.toArray(listArr); |  | ||||||
|             Arrays.sort(listArr); |  | ||||||
|             // end sort the value list |  | ||||||
| 
 |  | ||||||
|             // get the list of keys |  | ||||||
|             int resultCount = 0; |  | ||||||
|             int index; |  | ||||||
|             if (acsending) { |  | ||||||
|                 index = 0; |  | ||||||
|             } else { |  | ||||||
|                 index = listArr.length - 1; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             if (!sharpLimit) { |  | ||||||
|                 numKeys = listArr.length; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             while (true) { |  | ||||||
|                 double targetValue = (Double) listArr[index]; |  | ||||||
|                 Iterator<Entry> passedMapIterator = passedMap.entrySet().iterator(); |  | ||||||
|                 while (passedMapIterator.hasNext()) { |  | ||||||
|                     Entry entry = passedMapIterator.next(); |  | ||||||
|                     if ((Double) entry.getValue() == targetValue) { |  | ||||||
|                         keyList.add(entry.getKey()); |  | ||||||
|                         passedMapIterator.remove(); |  | ||||||
|                         resultCount++; |  | ||||||
|                         break; |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 if (acsending) { |  | ||||||
|                     index++; |  | ||||||
|                 } else { |  | ||||||
|                     index--; |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 if (resultCount >= numKeys) { |  | ||||||
|                     break; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             // end get the list of keys |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return keyList; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private double getOrderedValue(double[] array, int index) { |  | ||||||
|         locate(array, 0, array.length - 1, index); |  | ||||||
|         return array[index]; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // sort the partitions by quick sort, and locate the target index |  | ||||||
|     private void locate(double[] array, int left, int right, int index) { |  | ||||||
| 
 |  | ||||||
|         int mid = (left + right) / 2; |  | ||||||
|         //System.out.println(left+" to "+right+" ("+mid+")"); |  | ||||||
| 
 |  | ||||||
|         if (right == left) { |  | ||||||
|             //System.out.println("* "+array[targetIndex]); |  | ||||||
|             //result=array[targetIndex]; |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         if (left < right) { |  | ||||||
|             double s = array[mid]; |  | ||||||
|             int i = left - 1; |  | ||||||
|             int j = right + 1; |  | ||||||
| 
 |  | ||||||
|             while (true) { |  | ||||||
|                 while (array[++i] < s); |  | ||||||
|                 while (array[--j] > s); |  | ||||||
|                 if (i >= j) |  | ||||||
|                     break; |  | ||||||
|                 swap(array, i, j); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             //System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right); |  | ||||||
| 
 |  | ||||||
|             if (i > index) { |  | ||||||
|                 // the target index in the left partition |  | ||||||
|                 //System.out.println("left partition"); |  | ||||||
|                 locate(array, left, i - 1, index); |  | ||||||
|             } else { |  | ||||||
|                 // the target index in the right partition |  | ||||||
|                 //System.out.println("right partition"); |  | ||||||
|                 locate(array, j + 1, right, index); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private void swap(double[] array, int i, int j) { |  | ||||||
|         double t = array[i]; |  | ||||||
|         array[i] = array[j]; |  | ||||||
|         array[j] = t; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,179 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.fingerprint; |  | ||||||
| 
 |  | ||||||
| import java.util.*; |  | ||||||
| import java.util.Map.Entry; |  | ||||||
| 
 |  | ||||||
| public class MapRankInteger implements MapRank { |  | ||||||
| 
 |  | ||||||
|     private Map map; |  | ||||||
|     private boolean acsending = true; |  | ||||||
| 
 |  | ||||||
|     public MapRankInteger(Map<?, Integer> map, boolean acsending) { |  | ||||||
|         this.map = map; |  | ||||||
|         this.acsending = acsending; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public List getOrderedKeyList(int numKeys, boolean sharpLimit) { // if sharp limited, will return sharp numKeys, otherwise will return until the values not equals the exact key's value |  | ||||||
| 
 |  | ||||||
|         Set mapEntrySet = map.entrySet(); |  | ||||||
|         List keyList = new LinkedList(); |  | ||||||
| 
 |  | ||||||
|         // if the numKeys is larger than map size, limit it |  | ||||||
|         if (numKeys > map.size()) { |  | ||||||
|             numKeys = map.size(); |  | ||||||
|         } |  | ||||||
|         // end if the numKeys is larger than map size, limit it |  | ||||||
| 
 |  | ||||||
|         if (map.size() > 0) { |  | ||||||
|             int[] array = new int[map.size()]; |  | ||||||
|             int count = 0; |  | ||||||
| 
 |  | ||||||
|             // get the pass values |  | ||||||
|             Iterator<Entry> mapIterator = mapEntrySet.iterator(); |  | ||||||
|             while (mapIterator.hasNext()) { |  | ||||||
|                 Entry entry = mapIterator.next(); |  | ||||||
|                 array[count++] = (Integer) entry.getValue(); |  | ||||||
|             } |  | ||||||
|             // end get the pass values |  | ||||||
| 
 |  | ||||||
|             int targetindex; |  | ||||||
|             if (acsending) { |  | ||||||
|                 targetindex = numKeys; |  | ||||||
|             } else { |  | ||||||
|                 targetindex = array.length - numKeys; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             int passValue = getOrderedValue(array, targetindex); // this value is the value of the numKey-th element |  | ||||||
|             // get the passed keys and values |  | ||||||
|             Map passedMap = new HashMap(); |  | ||||||
|             List<Integer> valueList = new LinkedList<Integer>(); |  | ||||||
|             mapIterator = mapEntrySet.iterator(); |  | ||||||
| 
 |  | ||||||
|             while (mapIterator.hasNext()) { |  | ||||||
|                 Entry entry = mapIterator.next(); |  | ||||||
|                 int value = (Integer) entry.getValue(); |  | ||||||
|                 if ((acsending && value <= passValue) || (!acsending && value >= passValue)) { |  | ||||||
|                     passedMap.put(entry.getKey(), value); |  | ||||||
|                     valueList.add(value); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             // end get the passed keys and values |  | ||||||
| 
 |  | ||||||
|             // sort the value list |  | ||||||
|             Integer[] listArr = new Integer[valueList.size()]; |  | ||||||
|             valueList.toArray(listArr); |  | ||||||
|             Arrays.sort(listArr); |  | ||||||
|             // end sort the value list |  | ||||||
| 
 |  | ||||||
|             // get the list of keys |  | ||||||
|             int resultCount = 0; |  | ||||||
|             int index; |  | ||||||
|             if (acsending) { |  | ||||||
|                 index = 0; |  | ||||||
|             } else { |  | ||||||
|                 index = listArr.length - 1; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             if (!sharpLimit) { |  | ||||||
|                 numKeys = listArr.length; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             while (true) { |  | ||||||
|                 int targetValue = (Integer) listArr[index]; |  | ||||||
|                 Iterator<Entry> passedMapIterator = passedMap.entrySet().iterator(); |  | ||||||
|                 while (passedMapIterator.hasNext()) { |  | ||||||
|                     Entry entry = passedMapIterator.next(); |  | ||||||
|                     if ((Integer) entry.getValue() == targetValue) { |  | ||||||
|                         keyList.add(entry.getKey()); |  | ||||||
|                         passedMapIterator.remove(); |  | ||||||
|                         resultCount++; |  | ||||||
|                         break; |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 if (acsending) { |  | ||||||
|                     index++; |  | ||||||
|                 } else { |  | ||||||
|                     index--; |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 if (resultCount >= numKeys) { |  | ||||||
|                     break; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             // end get the list of keys |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return keyList; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private int getOrderedValue(int[] array, int index) { |  | ||||||
|         locate(array, 0, array.length - 1, index); |  | ||||||
|         return array[index]; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // sort the partitions by quick sort, and locate the target index |  | ||||||
|     private void locate(int[] array, int left, int right, int index) { |  | ||||||
| 
 |  | ||||||
|         int mid = (left + right) / 2; |  | ||||||
|         //System.out.println(left+" to "+right+" ("+mid+")"); |  | ||||||
| 
 |  | ||||||
|         if (right == left) { |  | ||||||
|             //System.out.println("* "+array[targetIndex]); |  | ||||||
|             //result=array[targetIndex]; |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         if (left < right) { |  | ||||||
|             int s = array[mid]; |  | ||||||
|             int i = left - 1; |  | ||||||
|             int j = right + 1; |  | ||||||
| 
 |  | ||||||
|             while (true) { |  | ||||||
|                 while (array[++i] < s); |  | ||||||
|                 while (array[--j] > s); |  | ||||||
|                 if (i >= j) |  | ||||||
|                     break; |  | ||||||
|                 swap(array, i, j); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             //System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right); |  | ||||||
| 
 |  | ||||||
|             if (i > index) { |  | ||||||
|                 // the target index in the left partition |  | ||||||
|                 //System.out.println("left partition"); |  | ||||||
|                 locate(array, left, i - 1, index); |  | ||||||
|             } else { |  | ||||||
|                 // the target index in the right partition |  | ||||||
|                 //System.out.println("right partition"); |  | ||||||
|                 locate(array, j + 1, right, index); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private void swap(int[] array, int i, int j) { |  | ||||||
|         int t = array[i]; |  | ||||||
|         array[i] = array[j]; |  | ||||||
|         array[j] = t; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,230 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.fingerprint; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.datavec.audio.properties.FingerprintProperties; |  | ||||||
| 
 |  | ||||||
| import java.util.HashMap; |  | ||||||
| import java.util.LinkedList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
/**
 * Builds hashed pairs of robust points from a fingerprint byte array.
 *
 * <p>Each pair links an anchor point to a nearby target point within a target
 * zone; the pair is hashed into an int so fingerprints can be compared by
 * looking up matching pair hashcodes. Behaves differently for reference
 * (indexing) and sample (query) fingerprints — see the two constructors.
 */
public class PairManager {

    // pairing parameters, all derived from the shared FingerprintProperties singleton
    FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
    private int numFilterBanks = fingerprintProperties.getNumFilterBanks();
    private int bandwidthPerBank = fingerprintProperties.getNumFrequencyUnits() / numFilterBanks;
    private int anchorPointsIntervalLength = fingerprintProperties.getAnchorPointsIntervalLength();
    private int numAnchorPointsPerInterval = fingerprintProperties.getNumAnchorPointsPerInterval();
    private int maxTargetZoneDistance = fingerprintProperties.getMaxTargetZoneDistance();
    private int numFrequencyUnits = fingerprintProperties.getNumFrequencyUnits();

    private int maxPairs;               // maximum number of pairs formed per anchor point
    private boolean isReferencePairing; // true: reference (indexing) mode; false: sample (query) mode
    private HashMap<Integer, Boolean> stopPairTable = new HashMap<>(); // pair hashcodes to skip during sample pairing

    /**
     * Constructor — defaults to reference pairing with the reference pair limit.
     */
    public PairManager() {
        maxPairs = fingerprintProperties.getRefMaxActivePairs();
        isReferencePairing = true;
    }

    /**
     * Constructor, number of pairs of robust points depends on the parameter isReferencePairing
     * no. of pairs of reference and sample can be different due to environmental influence of source   
     * @param isReferencePairing true for reference pairing, false for sample pairing
     */
    public PairManager(boolean isReferencePairing) {
        if (isReferencePairing) {
            maxPairs = fingerprintProperties.getRefMaxActivePairs();
        } else {
            maxPairs = fingerprintProperties.getSampleMaxActivePairs();
        }
        this.isReferencePairing = isReferencePairing;
    }

    /**
     * Get a pair-positionList table
     * It's a hash map which the key is the hashed pair, and the value is list of positions
     * That means the table stores the positions which have the same hashed pair
     *  
     * @param fingerprint	fingerprint bytes
     * @return pair-positionList HashMap
     */
    public HashMap<Integer, List<Integer>> getPair_PositionList_Table(byte[] fingerprint) {

        List<int[]> pairPositionList = getPairPositionList(fingerprint);

        // table to store pair:pos,pos,pos,...;pair2:pos,pos,pos,....
        HashMap<Integer, List<Integer>> pair_positionList_table = new HashMap<>();

        // get all pair_positions from list, use a table to collect the data group by pair hashcode
        for (int[] pair_position : pairPositionList) {
            // pair_position[0] is the pair hashcode, pair_position[1] the anchor frame position

            // group by pair-hashcode, i.e.: <pair,List<position>>
            if (pair_positionList_table.containsKey(pair_position[0])) {
                pair_positionList_table.get(pair_position[0]).add(pair_position[1]);
            } else {
                List<Integer> positionList = new LinkedList<>();
                positionList.add(pair_position[1]);
                pair_positionList_table.put(pair_position[0], positionList);
            }
            // end group by pair-hashcode, i.e.: <pair,List<position>>
        }
        // end get all pair_positions from list, use a table to collect the data group by pair hashcode

        return pair_positionList_table;
    }

    // Builds all pairs; each returned element is int[0]=pair_hashcode, int[1]=position (anchor frame x)
    private List<int[]> getPairPositionList(byte[] fingerprint) {

        int numFrames = FingerprintManager.getNumFrames(fingerprint);

        // table for paired frames: counts pairs formed per anchor interval so the
        // numAnchorPointsPerInterval budget can be enforced during reference pairing
        byte[] pairedFrameTable = new byte[numFrames / anchorPointsIntervalLength + 1]; // each second has numAnchorPointsPerSecond pairs only
        // end table for paired frames

        List<int[]> pairList = new LinkedList<>();
        // coordinates ordered by descending intensity, so the strongest points anchor first
        List<int[]> sortedCoordinateList = getSortedCoordinateList(fingerprint);

        for (int[] anchorPoint : sortedCoordinateList) {
            int anchorX = anchorPoint[0];
            int anchorY = anchorPoint[1];
            int numPairs = 0;

            for (int[] aSortedCoordinateList : sortedCoordinateList) {

                // per-anchor pair budget exhausted
                if (numPairs >= maxPairs) {
                    break;
                }

                // reference mode only: stop once this anchor's interval hit its point budget
                if (isReferencePairing && pairedFrameTable[anchorX
                                / anchorPointsIntervalLength] >= numAnchorPointsPerInterval) {
                    break;
                }

                int targetX = aSortedCoordinateList[0];
                int targetY = aSortedCoordinateList[1];

                // a point never pairs with itself
                if (anchorX == targetX && anchorY == targetY) {
                    continue;
                }

                // pair up the points
                int x1, y1, x2, y2; // x2 always >= x1
                if (targetX >= anchorX) {
                    x2 = targetX;
                    y2 = targetY;
                    x1 = anchorX;
                    y1 = anchorY;
                } else {
                    x2 = anchorX;
                    y2 = anchorY;
                    x1 = targetX;
                    y1 = targetY;
                }

                // check target zone
                if ((x2 - x1) > maxTargetZoneDistance) {
                    continue;
                }
                // end check target zone

                // check filter bank zone
                if (!(y1 / bandwidthPerBank == y2 / bandwidthPerBank)) {
                    continue; // same filter bank should have equal value
                }
                // end check filter bank zone

                // encode (time delta, both frequencies) into a single int hashcode
                int pairHashcode = (x2 - x1) * numFrequencyUnits * numFrequencyUnits + y2 * numFrequencyUnits + y1;

                // stop list applied on sample pairing only
                if (!isReferencePairing && stopPairTable.containsKey(pairHashcode)) {
                    numPairs++; // no reservation
                    continue; // escape this point only
                }
                // end stop list applied on sample pairing only

                // pass all rules
                pairList.add(new int[] {pairHashcode, anchorX});
                pairedFrameTable[anchorX / anchorPointsIntervalLength]++;
                numPairs++;
                // end pair up the points
            }
        }

        return pairList;
    }

    // Decodes the fingerprint into (x, y) coordinates ordered by descending intensity.
    private List<int[]> getSortedCoordinateList(byte[] fingerprint) {
        // each point data is 8 bytes  
        // first 2 bytes is x
        // next 2 bytes is y
        // next 4 bytes is intensity

        // get all intensities
        int numCoordinates = fingerprint.length / 8;
        int[] intensities = new int[numCoordinates];
        for (int i = 0; i < numCoordinates; i++) {
            int pointer = i * 8 + 4;
            // big-endian 4-byte intensity
            int intensity = (fingerprint[pointer] & 0xff) << 24 | (fingerprint[pointer + 1] & 0xff) << 16
                            | (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff);
            intensities[i] = intensity;
        }

        QuickSortIndexPreserved quicksort = new QuickSortIndexPreserved(intensities);
        int[] sortIndexes = quicksort.getSortIndexes();

        // iterate sort indexes backwards so the list is in descending intensity order
        List<int[]> sortedCoordinateList = new LinkedList<>();
        for (int i = sortIndexes.length - 1; i >= 0; i--) {
            int pointer = sortIndexes[i] * 8;
            int x = (fingerprint[pointer] & 0xff) << 8 | (fingerprint[pointer + 1] & 0xff);
            int y = (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff);
            sortedCoordinateList.add(new int[] {x, y});
        }
        return sortedCoordinateList;
    }

    /**
     * Convert hashed pair to bytes
     *
     * NOTE(review): only the low 16 bits of pairHashcode survive this
     * two-byte round-trip with pairBytesToHashcode, while the hashcode
     * computed in getPairPositionList can exceed 16 bits — confirm callers
     * rely only on the truncated value.
     *  
     * @param pairHashcode hashed pair
     * @return byte array
     */
    public static byte[] pairHashcodeToBytes(int pairHashcode) {
        return new byte[] {(byte) (pairHashcode >> 8), (byte) pairHashcode};
    }

    /**
     * Convert bytes to hashed pair
     *  
     * @param pairBytes two bytes produced by pairHashcodeToBytes
     * @return hashed pair
     */
    public static int pairBytesToHashcode(byte[] pairBytes) {
        return (pairBytes[0] & 0xFF) << 8 | (pairBytes[1] & 0xFF);
    }
}
| @ -1,25 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.fingerprint; |  | ||||||
| 
 |  | ||||||
/**
 * Base class for quick-sort implementations that sort indirectly:
 * instead of returning sorted data, they return the element indexes
 * in sorted order.
 */
public abstract class QuickSort {
    /**
     * @return indexes of the underlying array's elements in sorted order
     */
    public abstract int[] getSortIndexes();
}
| @ -1,79 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.fingerprint; |  | ||||||
| 
 |  | ||||||
/**
 * Index-preserving quicksort for {@code double[]} data: sorts an array of
 * indexes by the values they reference, leaving the value array itself
 * unmodified.
 */
public class QuickSortDouble extends QuickSort {

    private int[] indexes; // permutation of 0..array.length-1; reordered by sort()
    private double[] array; // values being compared; held by reference, never reordered

    /**
     * Creates a sorter over the given values.
     *
     * @param array values whose ascending order (as indexes) will be computed;
     *              kept by reference, not copied
     */
    public QuickSortDouble(double[] array) {
        this.array = array;
        indexes = new int[array.length];
        for (int i = 0; i < indexes.length; i++) {
            indexes[i] = i;
        }
    }

    /**
     * Sorts and returns the index permutation.
     *
     * @return indexes of the value array's elements, in ascending value order
     */
    public int[] getSortIndexes() {
        sort();
        return indexes;
    }

    private void sort() {
        quicksort(array, indexes, 0, indexes.length - 1);
    }

    // quicksort a[left] to a[right]; only the indexes array is permuted
    private void quicksort(double[] a, int[] indexes, int left, int right) {
        if (right <= left)
            return;
        int i = partition(a, indexes, left, right);
        quicksort(a, indexes, left, i - 1);
        quicksort(a, indexes, i + 1, right);
    }

    // partition a[left] to a[right], assumes left < right
    private int partition(double[] a, int[] indexes, int left, int right) {
        int i = left - 1;
        int j = right;
        while (true) {
            while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
            while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
                if (j == left)
                    break; // don't go out-of-bounds
            }
            if (i >= j)
                break; // check if pointers cross
            swap(a, indexes, i, j); // swap two elements into place
        }
        swap(a, indexes, i, right); // swap with partition element
        return i;
    }

    // exchange indexes[i] and indexes[j]; the value array a is untouched
    private void swap(double[] a, int[] indexes, int i, int j) {
        int swap = indexes[i];
        indexes[i] = indexes[j];
        indexes[j] = swap;
    }

}
| @ -1,43 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.fingerprint; |  | ||||||
| 
 |  | ||||||
| public class QuickSortIndexPreserved { |  | ||||||
| 
 |  | ||||||
|     private QuickSort quickSort; |  | ||||||
| 
 |  | ||||||
|     public QuickSortIndexPreserved(int[] array) { |  | ||||||
|         quickSort = new QuickSortInteger(array); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public QuickSortIndexPreserved(double[] array) { |  | ||||||
|         quickSort = new QuickSortDouble(array); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public QuickSortIndexPreserved(short[] array) { |  | ||||||
|         quickSort = new QuickSortShort(array); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public int[] getSortIndexes() { |  | ||||||
|         return quickSort.getSortIndexes(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,79 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.fingerprint; |  | ||||||
| 
 |  | ||||||
/**
 * Index-preserving quicksort for {@code int[]} data: sorts an array of indexes
 * by the values they reference, leaving the value array itself unmodified.
 */
public class QuickSortInteger extends QuickSort {

    private int[] indexes; // permutation of 0..array.length-1; reordered by sort()
    private int[] array; // values being compared; held by reference, never reordered

    /**
     * Creates a sorter over the given values.
     *
     * @param array values whose ascending order (as indexes) will be computed;
     *              kept by reference, not copied
     */
    public QuickSortInteger(int[] array) {
        this.array = array;
        indexes = new int[array.length];
        for (int i = 0; i < indexes.length; i++) {
            indexes[i] = i;
        }
    }

    /**
     * Sorts and returns the index permutation.
     *
     * @return indexes of the value array's elements, in ascending value order
     */
    public int[] getSortIndexes() {
        sort();
        return indexes;
    }

    private void sort() {
        quicksort(array, indexes, 0, indexes.length - 1);
    }

    // quicksort a[left] to a[right]; only the indexes array is permuted
    private void quicksort(int[] a, int[] indexes, int left, int right) {
        if (right <= left)
            return;
        int i = partition(a, indexes, left, right);
        quicksort(a, indexes, left, i - 1);
        quicksort(a, indexes, i + 1, right);
    }

    // partition a[left] to a[right], assumes left < right
    private int partition(int[] a, int[] indexes, int left, int right) {
        int i = left - 1;
        int j = right;
        while (true) {
            while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
            while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
                if (j == left)
                    break; // don't go out-of-bounds
            }
            if (i >= j)
                break; // check if pointers cross
            swap(a, indexes, i, j); // swap two elements into place
        }
        swap(a, indexes, i, right); // swap with partition element
        return i;
    }

    // exchange indexes[i] and indexes[j]; the value array a is untouched
    private void swap(int[] a, int[] indexes, int i, int j) {
        int swap = indexes[i];
        indexes[i] = indexes[j];
        indexes[j] = swap;
    }

}
| @ -1,79 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.fingerprint; |  | ||||||
| 
 |  | ||||||
/**
 * Index-preserving quicksort for {@code short[]} data: sorts an array of
 * indexes by the values they reference, leaving the value array itself
 * unmodified.
 */
public class QuickSortShort extends QuickSort {

    private int[] indexes; // permutation of 0..array.length-1; reordered by sort()
    private short[] array; // values being compared; held by reference, never reordered

    /**
     * Creates a sorter over the given values.
     *
     * @param array values whose ascending order (as indexes) will be computed;
     *              kept by reference, not copied
     */
    public QuickSortShort(short[] array) {
        this.array = array;
        indexes = new int[array.length];
        for (int i = 0; i < indexes.length; i++) {
            indexes[i] = i;
        }
    }

    /**
     * Sorts and returns the index permutation.
     *
     * @return indexes of the value array's elements, in ascending value order
     */
    public int[] getSortIndexes() {
        sort();
        return indexes;
    }

    private void sort() {
        quicksort(array, indexes, 0, indexes.length - 1);
    }

    // quicksort a[left] to a[right]; only the indexes array is permuted
    private void quicksort(short[] a, int[] indexes, int left, int right) {
        if (right <= left)
            return;
        int i = partition(a, indexes, left, right);
        quicksort(a, indexes, left, i - 1);
        quicksort(a, indexes, i + 1, right);
    }

    // partition a[left] to a[right], assumes left < right
    private int partition(short[] a, int[] indexes, int left, int right) {
        int i = left - 1;
        int j = right;
        while (true) {
            while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel
            while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap
                if (j == left)
                    break; // don't go out-of-bounds
            }
            if (i >= j)
                break; // check if pointers cross
            swap(a, indexes, i, j); // swap two elements into place
        }
        swap(a, indexes, i, right); // swap with partition element
        return i;
    }

    // exchange indexes[i] and indexes[j]; the value array a is untouched
    private void swap(short[] a, int[] indexes, int i, int j) {
        int swap = indexes[i];
        indexes[i] = indexes[j];
        indexes[j] = swap;
    }

}
| @ -1,51 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.formats.input; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.formats.input.BaseInputFormat; |  | ||||||
| import org.datavec.api.records.reader.RecordReader; |  | ||||||
| import org.datavec.api.split.InputSplit; |  | ||||||
| import org.datavec.audio.recordreader.WavFileRecordReader; |  | ||||||
| 
 |  | ||||||
| import java.io.IOException; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * |  | ||||||
|  * Wave file input format |  | ||||||
|  * |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  */ |  | ||||||
| public class WavInputFormat extends BaseInputFormat { |  | ||||||
|     @Override |  | ||||||
|     public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException { |  | ||||||
|         return createReader(split); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public RecordReader createReader(InputSplit split) throws IOException, InterruptedException { |  | ||||||
|         RecordReader waveRecordReader = new WavFileRecordReader(); |  | ||||||
|         waveRecordReader.initialize(split); |  | ||||||
|         return waveRecordReader; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,36 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.formats.output; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.exceptions.DataVecException; |  | ||||||
| import org.datavec.api.formats.output.OutputFormat; |  | ||||||
| import org.datavec.api.records.writer.RecordWriter; |  | ||||||
| 
 |  | ||||||
/**
 * Output format for wave audio data.
 *
 * @author Adam Gibson
 */
public class WaveOutputFormat implements OutputFormat {
    /**
     * NOTE(review): this is a stub — it always returns {@code null} rather than a
     * usable writer, so callers must null-check the result. Confirm whether a real
     * writer was ever intended here or whether throwing would be more appropriate.
     */
    @Override
    public RecordWriter createWriter(Configuration conf) throws DataVecException {
        return null;
    }
}
| @ -1,139 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.processor; |  | ||||||
| 
 |  | ||||||
/**
 * Utility operations on {@code double[]} data: locating the positions of the
 * minimum/maximum values and selecting the n-th ordered value via a
 * quickselect-style partial sort.
 */
public class ArrayRankDouble {

    /**
     * Get the index position of the maximum value in the given array.
     *
     * @param array an array; must be non-empty (index 0 is returned for an empty array)
     * @return index of the max value in array
     */
    public int getMaxValueIndex(double[] array) {

        int index = 0;
        // Start from negative infinity: the previous Integer.MIN_VALUE sentinel
        // returned wrong indexes when every element was below ~-2.1e9.
        double max = Double.NEGATIVE_INFINITY;

        for (int i = 0; i < array.length; i++) {
            if (array[i] > max) {
                max = array[i];
                index = i;
            }
        }

        return index;
    }

    /**
     * Get the index position of the minimum value in the given array.
     *
     * @param array an array; must be non-empty (index 0 is returned for an empty array)
     * @return index of the min value in array
     */
    public int getMinValueIndex(double[] array) {

        int index = 0;
        // Positive-infinity sentinel for the same reason as getMaxValueIndex:
        // the old Integer.MAX_VALUE bound broke for all-large-positive inputs.
        double min = Double.POSITIVE_INFINITY;

        for (int i = 0; i < array.length; i++) {
            if (array[i] < min) {
                min = array[i];
                index = i;
            }
        }

        return index;
    }

    /**
     * Get the n-th value in the array after sorting.
     *
     * <p>Note: this partially reorders {@code array} in place (quickselect);
     * pass a copy if the original ordering must be preserved.
     *
     * @param array an array; must be non-empty
     * @param n position in array (values above {@code array.length} are clamped)
     * @param ascending is ascending order or not
     * @return value at the n-th position of the (conceptually sorted) array
     */
    public double getNthOrderedValue(double[] array, int n, boolean ascending) {

        if (n > array.length) {
            n = array.length;
        }

        int targetindex;
        if (ascending) {
            // NOTE(review): historical behavior — index n is the (n+1)-th smallest
            // (0-based), asymmetric with the descending branch. Preserved as-is
            // because existing callers may depend on it; confirm before changing.
            targetindex = n;
        } else {
            targetindex = array.length - n;
        }

        // Clamp: with ascending == true and n == array.length the arithmetic above
        // pointed one past the end and threw ArrayIndexOutOfBoundsException.
        if (targetindex >= array.length) {
            targetindex = array.length - 1;
        }

        // this value is the value of the numKey-th element
        return getOrderedValue(array, targetindex);
    }

    // Quickselect entry point: after locate() returns, array[index] holds the
    // value that a full ascending sort would place at that index.
    private double getOrderedValue(double[] array, int index) {
        locate(array, 0, array.length - 1, index);
        return array[index];
    }

    // sort the partitions by quick sort, and recurse only into the partition
    // containing the target index
    private void locate(double[] array, int left, int right, int index) {

        int mid = (left + right) / 2;

        if (right == left) {
            return; // single element: already in place
        }

        if (left < right) {
            double s = array[mid]; // pivot value
            int i = left - 1;
            int j = right + 1;

            // Hoare-style partition around s
            while (true) {
                while (array[++i] < s);
                while (array[--j] > s);
                if (i >= j)
                    break;
                swap(array, i, j);
            }

            if (i > index) {
                // the target index is in the left partition
                locate(array, left, i - 1, index);
            } else {
                // the target index is in the right partition
                locate(array, j + 1, right, index);
            }
        }
    }

    // exchange array[i] and array[j]
    private void swap(double[] array, int i, int j) {
        double t = array[i];
        array[i] = array[j];
        array[j] = t;
    }
}
| @ -1,28 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.processor; |  | ||||||
| 
 |  | ||||||
/**
 * A processing stage over a 2-D intensity matrix. Implementations receive the
 * matrix at construction time, transform it in {@link #execute()}, and expose
 * the result via {@link #getIntensities()}.
 */
public interface IntensityProcessor {

    /** Runs the processing step over the intensities supplied at construction. */
    void execute();

    /**
     * Returns the processed intensity matrix.
     *
     * @return processed intensities (valid after {@link #execute()} has run)
     */
    double[][] getIntensities();
}
| @ -1,48 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.processor; |  | ||||||
| 
 |  | ||||||
| import java.util.LinkedList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class ProcessorChain { |  | ||||||
| 
 |  | ||||||
|     private double[][] intensities; |  | ||||||
|     List<IntensityProcessor> processorList = new LinkedList<IntensityProcessor>(); |  | ||||||
| 
 |  | ||||||
|     public ProcessorChain(double[][] intensities) { |  | ||||||
|         this.intensities = intensities; |  | ||||||
|         RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, 1); |  | ||||||
|         processorList.add(robustProcessor); |  | ||||||
|         process(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private void process() { |  | ||||||
|         for (IntensityProcessor processor : processorList) { |  | ||||||
|             processor.execute(); |  | ||||||
|             intensities = processor.getIntensities(); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double[][] getIntensities() { |  | ||||||
|         return intensities; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,61 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.processor; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| public class RobustIntensityProcessor implements IntensityProcessor { |  | ||||||
| 
 |  | ||||||
|     private double[][] intensities; |  | ||||||
|     private int numPointsPerFrame; |  | ||||||
| 
 |  | ||||||
|     public RobustIntensityProcessor(double[][] intensities, int numPointsPerFrame) { |  | ||||||
|         this.intensities = intensities; |  | ||||||
|         this.numPointsPerFrame = numPointsPerFrame; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void execute() { |  | ||||||
| 
 |  | ||||||
|         int numX = intensities.length; |  | ||||||
|         int numY = intensities[0].length; |  | ||||||
|         double[][] processedIntensities = new double[numX][numY]; |  | ||||||
| 
 |  | ||||||
|         for (int i = 0; i < numX; i++) { |  | ||||||
|             double[] tmpArray = new double[numY]; |  | ||||||
|             System.arraycopy(intensities[i], 0, tmpArray, 0, numY); |  | ||||||
| 
 |  | ||||||
|             // pass value is the last some elements in sorted array	 |  | ||||||
|             ArrayRankDouble arrayRankDouble = new ArrayRankDouble(); |  | ||||||
|             double passValue = arrayRankDouble.getNthOrderedValue(tmpArray, numPointsPerFrame, false); |  | ||||||
| 
 |  | ||||||
|             // only passed elements will be assigned a value |  | ||||||
|             for (int j = 0; j < numY; j++) { |  | ||||||
|                 if (intensities[i][j] >= passValue) { |  | ||||||
|                     processedIntensities[i][j] = intensities[i][j]; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         intensities = processedIntensities; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double[][] getIntensities() { |  | ||||||
|         return intensities; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,49 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.processor; |  | ||||||
| 
 |  | ||||||
| import java.util.LinkedList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| public class TopManyPointsProcessorChain { |  | ||||||
| 
 |  | ||||||
|     private double[][] intensities; |  | ||||||
|     List<IntensityProcessor> processorList = new LinkedList<>(); |  | ||||||
| 
 |  | ||||||
|     public TopManyPointsProcessorChain(double[][] intensities, int numPoints) { |  | ||||||
|         this.intensities = intensities; |  | ||||||
|         RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, numPoints); |  | ||||||
|         processorList.add(robustProcessor); |  | ||||||
|         process(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private void process() { |  | ||||||
|         for (IntensityProcessor processor : processorList) { |  | ||||||
|             processor.execute(); |  | ||||||
|             intensities = processor.getIntensities(); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double[][] getIntensities() { |  | ||||||
|         return intensities; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,121 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.properties; |  | ||||||
| 
 |  | ||||||
/**
 * Global configuration for audio fingerprinting: frame sizes, frequency band
 * limits, pairing limits, and quantities derived from them (target sample rate,
 * frames per second, frequency-unit count). Exposed as a lazily created
 * singleton via {@link #getInstance()}.
 */
public class FingerprintProperties {

    // Singleton instance. 'volatile' is required for correct double-checked
    // locking: without it another thread may observe a reference to a
    // partially constructed instance under the Java Memory Model.
    protected static volatile FingerprintProperties instance = null;

    private int numRobustPointsPerFrame = 4; // number of points in each frame, i.e. top 4 intensities in fingerprint
    private int sampleSizePerFrame = 2048; // number of audio samples in a frame, it is suggested to be the FFT Size
    private int overlapFactor = 4; // 8 means each move 1/8 nSample length. 1 means no overlap, better 1,2,4,8 ... 32
    private int numFilterBanks = 4;

    private int upperBoundedFrequency = 1500; // low pass
    private int lowerBoundedFrequency = 400; // high pass
    private int fps = 5; // in order to have 5fps with 2048 sampleSizePerFrame, wave's sample rate need to be 10240 (sampleSizePerFrame*fps)
    private int sampleRate = sampleSizePerFrame * fps; // the audio's sample rate needed to resample to this in order to fit the sampleSizePerFrame and fps
    private int numFramesInOneSecond = overlapFactor * fps; // since the overlap factor affects the actual number of fps, so this value is used to evaluate how many frames in one second eventually

    private int refMaxActivePairs = 1; // max. active pairs per anchor point for reference songs
    private int sampleMaxActivePairs = 10; // max. active pairs per anchor point for sample clip
    private int numAnchorPointsPerInterval = 10;
    private int anchorPointsIntervalLength = 4; // in frames (5fps,4 overlap per second)
    private int maxTargetZoneDistance = 4; // in frame (5fps,4 overlap per second)

    private int numFrequencyUnits = (upperBoundedFrequency - lowerBoundedFrequency + 1) / fps + 1; // num frequency units

    /**
     * Returns the shared configuration instance, creating it on first use.
     * Thread-safe via double-checked locking on a volatile field.
     *
     * @return the singleton {@code FingerprintProperties}
     */
    public static FingerprintProperties getInstance() {
        if (instance == null) {
            synchronized (FingerprintProperties.class) {
                if (instance == null) {
                    instance = new FingerprintProperties();
                }
            }
        }
        return instance;
    }

    /** @return number of robust (top-intensity) points kept per frame */
    public int getNumRobustPointsPerFrame() {
        return numRobustPointsPerFrame;
    }

    /** @return number of audio samples per frame (suggested FFT size) */
    public int getSampleSizePerFrame() {
        return sampleSizePerFrame;
    }

    /** @return frame overlap factor (N means frames advance 1/N of a frame) */
    public int getOverlapFactor() {
        return overlapFactor;
    }

    /** @return number of filter banks used when selecting robust points */
    public int getNumFilterBanks() {
        return numFilterBanks;
    }

    /** @return upper frequency bound in Hz (low-pass cutoff) */
    public int getUpperBoundedFrequency() {
        return upperBoundedFrequency;
    }

    /** @return lower frequency bound in Hz (high-pass cutoff) */
    public int getLowerBoundedFrequency() {
        return lowerBoundedFrequency;
    }

    /** @return nominal frames per second (before overlap is applied) */
    public int getFps() {
        return fps;
    }

    /** @return max active pairs per anchor point for reference songs */
    public int getRefMaxActivePairs() {
        return refMaxActivePairs;
    }

    /** @return max active pairs per anchor point for sample clips */
    public int getSampleMaxActivePairs() {
        return sampleMaxActivePairs;
    }

    /** @return number of anchor points per interval */
    public int getNumAnchorPointsPerInterval() {
        return numAnchorPointsPerInterval;
    }

    /** @return anchor-point interval length, in frames */
    public int getAnchorPointsIntervalLength() {
        return anchorPointsIntervalLength;
    }

    /** @return max distance between paired points, in frames */
    public int getMaxTargetZoneDistance() {
        return maxTargetZoneDistance;
    }

    /** @return number of discrete frequency units in the bounded band */
    public int getNumFrequencyUnits() {
        return numFrequencyUnits;
    }

    /**
     * @return exclusive upper bound for pair hash codes, derived from the
     *         target-zone distance and frequency-unit count
     */
    public int getMaxPossiblePairHashcode() {
        return maxTargetZoneDistance * numFrequencyUnits * numFrequencyUnits + numFrequencyUnits * numFrequencyUnits
                        + numFrequencyUnits;
    }

    /** @return target sample rate (sampleSizePerFrame * fps) */
    public int getSampleRate() {
        return sampleRate;
    }

    /** @return effective number of frames per second after overlap (overlapFactor * fps) */
    public int getNumFramesInOneSecond() {
        return numFramesInOneSecond;
    }
}
| @ -1,225 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.recordreader; |  | ||||||
| 
 |  | ||||||
| import org.apache.commons.io.FileUtils; |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.records.Record; |  | ||||||
| import org.datavec.api.records.metadata.RecordMetaData; |  | ||||||
| import org.datavec.api.records.reader.BaseRecordReader; |  | ||||||
| import org.datavec.api.split.BaseInputSplit; |  | ||||||
| import org.datavec.api.split.InputSplit; |  | ||||||
| import org.datavec.api.split.InputStreamInputSplit; |  | ||||||
| import org.datavec.api.writable.DoubleWritable; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| 
 |  | ||||||
| import java.io.DataInputStream; |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.io.InputStream; |  | ||||||
| import java.net.URI; |  | ||||||
| import java.nio.file.Path; |  | ||||||
| import java.nio.file.Paths; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Collections; |  | ||||||
| import java.util.Iterator; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
/**
 * Base audio file loader: resolves the file(s) or stream behind an
 * {@link InputSplit} and delegates decoding of each source to
 * {@link #loadData(File, InputStream)}, implemented by subclasses.
 *
 * @author Adam Gibson
 */
public abstract class BaseAudioRecordReader extends BaseRecordReader {
    // File iterator built from a BaseInputSplit; null for stream-based splits.
    private Iterator<File> iter;
    // Single pre-built record for InputStreamInputSplit; null for file-based splits.
    private List<Writable> record;
    // True once the single stream-based record has been returned by next().
    // NOTE(review): name presumably carried over from the image record readers — confirm.
    private boolean hitImage = false;
    // Whether to append a label writable (index of the parent directory name
    // within 'labels') when initializing from a stream split.
    private boolean appendLabel = false;
    // Known label names; a label is encoded as its index in this list.
    private List<String> labels = new ArrayList<>();
    private Configuration conf;
    protected InputSplit inputSplit;

    public BaseAudioRecordReader() {}

    public BaseAudioRecordReader(boolean appendLabel, List<String> labels) {
        this.appendLabel = appendLabel;
        this.labels = labels;
    }

    public BaseAudioRecordReader(List<String> labels) {
        this.labels = labels;
    }

    public BaseAudioRecordReader(boolean appendLabel) {
        this.appendLabel = appendLabel;
    }

    /**
     * Decodes one audio source into a list of writables.
     * Callers pass exactly one non-null argument (file or stream).
     *
     * @param file        audio file to read, or null when reading from a stream
     * @param inputStream stream to read, or null when reading from a file
     * @return the decoded record
     * @throws IOException if the source cannot be read or decoded
     */
    protected abstract List<Writable> loadData(File file, InputStream inputStream) throws IOException;

    @Override
    public void initialize(InputSplit split) throws IOException, InterruptedException {
        inputSplit = split;
        if (split instanceof BaseInputSplit) {
            URI[] locations = split.locations();
            if (locations != null && locations.length >= 1) {
                if (locations.length > 1) {
                    // Multiple locations: flatten directories recursively into one file list.
                    List<File> allFiles = new ArrayList<>();
                    for (URI location : locations) {
                        // NOTE(review): this local 'iter' shadows the field of the same name.
                        File iter = new File(location);
                        if (iter.isDirectory()) {
                            // Recurse into the directory, collecting every file (no extension filter).
                            Iterator<File> allFiles2 = FileUtils.iterateFiles(iter, null, true);
                            while (allFiles2.hasNext())
                                allFiles.add(allFiles2.next());
                        }

                        else
                            allFiles.add(iter);
                    }

                    iter = allFiles.iterator();
                } else {
                    // Single location: iterate its directory tree, or just that one file.
                    File curr = new File(locations[0]);
                    if (curr.isDirectory())
                        iter = FileUtils.iterateFiles(curr, null, true);
                    else
                        iter = Collections.singletonList(curr).iterator();
                }
            }
        }


        else if (split instanceof InputStreamInputSplit) {
            // Stream-based split: only the label (if requested) is materialized here,
            // and the stream is closed immediately.
            // NOTE(review): the stream's audio data never reaches loadData() on this
            // path — confirm this is intended.
            record = new ArrayList<>();
            InputStreamInputSplit split2 = (InputStreamInputSplit) split;
            InputStream is = split2.getIs();
            URI[] locations = split2.locations();
            if (appendLabel) {
                // The label is the index of the parent directory name in the labels list.
                Path path = Paths.get(locations[0]);
                String parent = path.getParent().toString();
                record.add(new DoubleWritable(labels.indexOf(parent)));
            }

            is.close();
        }

    }

    @Override
    public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
        // Pull appendLabel/labels from the configuration, then defer to initialize(split).
        this.conf = conf;
        this.appendLabel = conf.getBoolean(APPEND_LABEL, false);
        this.labels = new ArrayList<>(conf.getStringCollection(LABELS));
        initialize(split);
    }

    @Override
    public List<Writable> next() {
        if (iter != null) {
            // File-based split: decode the next file.
            File next = iter.next();
            invokeListeners(next);
            try {
                return loadData(next, null);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        } else if (record != null) {
            // Stream-based split: serve the single pre-built record once.
            hitImage = true;
            return record;
        }

        throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
    }

    @Override
    public boolean hasNext() {
        if (iter != null) {
            return iter.hasNext();
        } else if (record != null) {
            // The stream-based record is available exactly once.
            return !hitImage;
        }
        throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
    }


    @Override
    public void close() throws IOException {
        // No resources held here; streams are closed during initialize().
    }

    @Override
    public void setConf(Configuration conf) {
        this.conf = conf;
    }

    @Override
    public Configuration getConf() {
        return conf;
    }

    @Override
    public List<String> getLabels() {
        // Not supported by this reader; the 'labels' field is only used for encoding.
        return null;
    }


    @Override
    public void reset() {
        // Re-run initialization against the stored split to restart iteration.
        if (inputSplit == null)
            throw new UnsupportedOperationException("Cannot reset without first initializing");
        try {
            initialize(inputSplit);
        } catch (Exception e) {
            throw new RuntimeException("Error during LineRecordReader reset", e);
        }
    }

    @Override
    public boolean resetSupported(){
        // Reset is possible only if we were initialized and the split itself supports it.
        if(inputSplit == null){
            return false;
        }
        return inputSplit.resetSupported();
    }

    @Override
    public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
        // Decode directly from the given stream (the URI is only used for listeners).
        invokeListeners(uri);
        try {
            return loadData(null, dataInputStream);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public Record nextRecord() {
        // No metadata is attached to records produced by this reader.
        return new org.datavec.api.records.impl.Record(next(), null);
    }

    @Override
    public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        throw new UnsupportedOperationException("Loading from metadata not yet implemented");
    }

    @Override
    public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
        throw new UnsupportedOperationException("Loading from metadata not yet implemented");
    }
}
| @ -1,76 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.recordreader; |  | ||||||
| 
 |  | ||||||
| import org.bytedeco.javacv.FFmpegFrameGrabber; |  | ||||||
| import org.bytedeco.javacv.Frame; |  | ||||||
| import org.datavec.api.writable.FloatWritable; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.io.InputStream; |  | ||||||
| import java.nio.FloatBuffer; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| import static org.bytedeco.ffmpeg.global.avutil.AV_SAMPLE_FMT_FLT; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Native audio file loader using FFmpeg. |  | ||||||
|  * |  | ||||||
|  * @author saudet |  | ||||||
|  */ |  | ||||||
| public class NativeAudioRecordReader extends BaseAudioRecordReader { |  | ||||||
| 
 |  | ||||||
|     public NativeAudioRecordReader() {} |  | ||||||
| 
 |  | ||||||
|     public NativeAudioRecordReader(boolean appendLabel, List<String> labels) { |  | ||||||
|         super(appendLabel, labels); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public NativeAudioRecordReader(List<String> labels) { |  | ||||||
|         super(labels); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public NativeAudioRecordReader(boolean appendLabel) { |  | ||||||
|         super(appendLabel); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     protected List<Writable> loadData(File file, InputStream inputStream) throws IOException { |  | ||||||
|         List<Writable> ret = new ArrayList<>(); |  | ||||||
|         try (FFmpegFrameGrabber grabber = inputStream != null ? new FFmpegFrameGrabber(inputStream) |  | ||||||
|                         : new FFmpegFrameGrabber(file.getAbsolutePath())) { |  | ||||||
|             grabber.setSampleFormat(AV_SAMPLE_FMT_FLT); |  | ||||||
|             grabber.start(); |  | ||||||
|             Frame frame; |  | ||||||
|             while ((frame = grabber.grab()) != null) { |  | ||||||
|                 while (frame.samples != null && frame.samples[0].hasRemaining()) { |  | ||||||
|                     for (int i = 0; i < frame.samples.length; i++) { |  | ||||||
|                         ret.add(new FloatWritable(((FloatBuffer) frame.samples[i]).get())); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,57 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio.recordreader; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.util.RecordUtils; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.datavec.audio.Wave; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.io.InputStream; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Wav file loader |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  */ |  | ||||||
| public class WavFileRecordReader extends BaseAudioRecordReader { |  | ||||||
| 
 |  | ||||||
|     public WavFileRecordReader() {} |  | ||||||
| 
 |  | ||||||
|     public WavFileRecordReader(boolean appendLabel, List<String> labels) { |  | ||||||
|         super(appendLabel, labels); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public WavFileRecordReader(List<String> labels) { |  | ||||||
|         super(labels); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public WavFileRecordReader(boolean appendLabel) { |  | ||||||
|         super(appendLabel); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     protected List<Writable> loadData(File file, InputStream inputStream) throws IOException { |  | ||||||
|         Wave wave = inputStream != null ? new Wave(inputStream) : new Wave(file.getAbsolutePath()); |  | ||||||
|         return RecordUtils.toRecord(wave.getNormalizedAmplitudes()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,51 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| package org.datavec.audio; |  | ||||||
| 
 |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.nd4j.common.tests.AbstractAssertTestsClass; |  | ||||||
| import org.nd4j.common.tests.BaseND4JTest; |  | ||||||
| 
 |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { |  | ||||||
| 
 |  | ||||||
| 	@Override |  | ||||||
| 	public long getTimeoutMilliseconds() { |  | ||||||
| 		return 60000; |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected Set<Class<?>> getExclusions() { |  | ||||||
| 	    //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) |  | ||||||
| 	    return new HashSet<>(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
| 	protected String getPackageName() { |  | ||||||
|     	return "org.datavec.audio"; |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	@Override |  | ||||||
| 	protected Class<?> getBaseClass() { |  | ||||||
|     	return BaseND4JTest.class; |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @ -1,68 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio; |  | ||||||
| 
 |  | ||||||
| import org.bytedeco.javacv.FFmpegFrameRecorder; |  | ||||||
| import org.bytedeco.javacv.Frame; |  | ||||||
| import org.datavec.api.records.reader.RecordReader; |  | ||||||
| import org.datavec.api.split.FileSplit; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.datavec.audio.recordreader.NativeAudioRecordReader; |  | ||||||
| import org.junit.Ignore; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.common.tests.BaseND4JTest; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.nio.ShortBuffer; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_VORBIS; |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| import static org.junit.Assert.assertTrue; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * @author saudet |  | ||||||
|  */ |  | ||||||
| public class AudioReaderTest extends BaseND4JTest { |  | ||||||
|     @Ignore |  | ||||||
|     @Test |  | ||||||
|     public void testNativeAudioReader() throws Exception { |  | ||||||
|         File tempFile = File.createTempFile("testNativeAudioReader", ".ogg"); |  | ||||||
|         FFmpegFrameRecorder recorder = new FFmpegFrameRecorder(tempFile, 2); |  | ||||||
|         recorder.setAudioCodec(AV_CODEC_ID_VORBIS); |  | ||||||
|         recorder.setSampleRate(44100); |  | ||||||
|         recorder.start(); |  | ||||||
|         Frame audioFrame = new Frame(); |  | ||||||
|         ShortBuffer audioBuffer = ShortBuffer.allocate(64 * 1024); |  | ||||||
|         audioFrame.sampleRate = 44100; |  | ||||||
|         audioFrame.audioChannels = 2; |  | ||||||
|         audioFrame.samples = new ShortBuffer[] {audioBuffer}; |  | ||||||
|         recorder.record(audioFrame); |  | ||||||
|         recorder.stop(); |  | ||||||
|         recorder.release(); |  | ||||||
| 
 |  | ||||||
|         RecordReader reader = new NativeAudioRecordReader(); |  | ||||||
|         reader.initialize(new FileSplit(tempFile)); |  | ||||||
|         assertTrue(reader.hasNext()); |  | ||||||
|         List<Writable> record = reader.next(); |  | ||||||
|         assertEquals(audioBuffer.limit(), record.size()); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,69 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.audio; |  | ||||||
| 
 |  | ||||||
| import org.datavec.audio.dsp.FastFourierTransform; |  | ||||||
| import org.junit.Assert; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.common.tests.BaseND4JTest; |  | ||||||
| 
 |  | ||||||
| public class TestFastFourierTransform extends BaseND4JTest { |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testFastFourierTransformComplex() { |  | ||||||
|         FastFourierTransform fft = new FastFourierTransform(); |  | ||||||
|         double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6}; |  | ||||||
|         double[] frequencies = fft.getMagnitudes(amplitudes); |  | ||||||
| 
 |  | ||||||
|         Assert.assertEquals(2, frequencies.length); |  | ||||||
|         Assert.assertArrayEquals(new double[] {21.335, 18.513}, frequencies, 0.005); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testFastFourierTransformComplexLong() { |  | ||||||
|         FastFourierTransform fft = new FastFourierTransform(); |  | ||||||
|         double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6}; |  | ||||||
|         double[] frequencies = fft.getMagnitudes(amplitudes, true); |  | ||||||
| 
 |  | ||||||
|         Assert.assertEquals(4, frequencies.length); |  | ||||||
|         Assert.assertArrayEquals(new double[] {21.335, 18.5132, 14.927, 7.527}, frequencies, 0.005); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testFastFourierTransformReal() { |  | ||||||
|         FastFourierTransform fft = new FastFourierTransform(); |  | ||||||
|         double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6}; |  | ||||||
|         double[] frequencies = fft.getMagnitudes(amplitudes, false); |  | ||||||
| 
 |  | ||||||
|         Assert.assertEquals(4, frequencies.length); |  | ||||||
|         Assert.assertArrayEquals(new double[] {28.8, 2.107, 14.927, 19.874}, frequencies, 0.005); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testFastFourierTransformRealOddSize() { |  | ||||||
|         FastFourierTransform fft = new FastFourierTransform(); |  | ||||||
|         double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5}; |  | ||||||
|         double[] frequencies = fft.getMagnitudes(amplitudes, false); |  | ||||||
| 
 |  | ||||||
|         Assert.assertEquals(3, frequencies.length); |  | ||||||
|         Assert.assertArrayEquals(new double[] {24.2, 3.861, 16.876}, frequencies, 0.005); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,71 +0,0 @@ | |||||||
| <?xml version="1.0" encoding="UTF-8"?> |  | ||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" |  | ||||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |  | ||||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> |  | ||||||
| 
 |  | ||||||
|     <modelVersion>4.0.0</modelVersion> |  | ||||||
| 
 |  | ||||||
|     <parent> |  | ||||||
|         <groupId>org.datavec</groupId> |  | ||||||
|         <artifactId>datavec-data</artifactId> |  | ||||||
|         <version>1.0.0-SNAPSHOT</version> |  | ||||||
|     </parent> |  | ||||||
| 
 |  | ||||||
|     <artifactId>datavec-data-codec</artifactId> |  | ||||||
| 
 |  | ||||||
|     <name>datavec-data-codec</name> |  | ||||||
| 
 |  | ||||||
|     <dependencies> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.datavec</groupId> |  | ||||||
|             <artifactId>datavec-api</artifactId> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.datavec</groupId> |  | ||||||
|             <artifactId>datavec-data-image</artifactId> |  | ||||||
|             <version>${project.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.jcodec</groupId> |  | ||||||
|             <artifactId>jcodec</artifactId> |  | ||||||
|             <version>0.1.5</version> |  | ||||||
|         </dependency> |  | ||||||
|         <!-- Do not depend on FFmpeg by default due to licensing concerns. --> |  | ||||||
|         <!-- |  | ||||||
|                 <dependency> |  | ||||||
|                     <groupId>org.bytedeco</groupId> |  | ||||||
|                     <artifactId>ffmpeg-platform</artifactId> |  | ||||||
|                     <version>${ffmpeg.version}-${javacpp-presets.version}</version> |  | ||||||
|                 </dependency> |  | ||||||
|         --> |  | ||||||
|     </dependencies> |  | ||||||
| 
 |  | ||||||
|     <profiles> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-native</id> |  | ||||||
|         </profile> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-cuda-11.0</id> |  | ||||||
|         </profile> |  | ||||||
|     </profiles> |  | ||||||
| </project> |  | ||||||
| @ -1,41 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.codec.format.input; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.formats.input.BaseInputFormat; |  | ||||||
| import org.datavec.api.records.reader.RecordReader; |  | ||||||
| import org.datavec.api.split.InputSplit; |  | ||||||
| import org.datavec.codec.reader.CodecRecordReader; |  | ||||||
| 
 |  | ||||||
| import java.io.IOException; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  */ |  | ||||||
| public class CodecInputFormat extends BaseInputFormat { |  | ||||||
|     @Override |  | ||||||
|     public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException { |  | ||||||
|         RecordReader reader = new CodecRecordReader(); |  | ||||||
|         reader.initialize(conf, split); |  | ||||||
|         return reader; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,144 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.codec.reader; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.records.SequenceRecord; |  | ||||||
| import org.datavec.api.records.metadata.RecordMetaData; |  | ||||||
| import org.datavec.api.records.metadata.RecordMetaDataURI; |  | ||||||
| import org.datavec.api.records.reader.SequenceRecordReader; |  | ||||||
| import org.datavec.api.records.reader.impl.FileRecordReader; |  | ||||||
| import org.datavec.api.split.InputSplit; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| 
 |  | ||||||
| import java.io.DataInputStream; |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.io.InputStream; |  | ||||||
| import java.net.URI; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Collections; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public abstract class BaseCodecRecordReader extends FileRecordReader implements SequenceRecordReader { |  | ||||||
|     protected int startFrame = 0; |  | ||||||
|     protected int numFrames = -1; |  | ||||||
|     protected int totalFrames = -1; |  | ||||||
|     protected double framesPerSecond = -1; |  | ||||||
|     protected double videoLength = -1; |  | ||||||
|     protected int rows = 28, cols = 28; |  | ||||||
|     protected boolean ravel = false; |  | ||||||
| 
 |  | ||||||
|     public final static String NAME_SPACE = "org.datavec.codec.reader"; |  | ||||||
|     public final static String ROWS = NAME_SPACE + ".rows"; |  | ||||||
|     public final static String COLUMNS = NAME_SPACE + ".columns"; |  | ||||||
|     public final static String START_FRAME = NAME_SPACE + ".startframe"; |  | ||||||
|     public final static String TOTAL_FRAMES = NAME_SPACE + ".frames"; |  | ||||||
|     public final static String TIME_SLICE = NAME_SPACE + ".time"; |  | ||||||
|     public final static String RAVEL = NAME_SPACE + ".ravel"; |  | ||||||
|     public final static String VIDEO_DURATION = NAME_SPACE + ".duration"; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<List<Writable>> sequenceRecord() { |  | ||||||
|         URI next = locationsIterator.next(); |  | ||||||
| 
 |  | ||||||
|         try (InputStream s = streamCreatorFn.apply(next)){ |  | ||||||
|             return loadData(null, s); |  | ||||||
|         } catch (IOException e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<List<Writable>> sequenceRecord(URI uri, DataInputStream dataInputStream) throws IOException { |  | ||||||
|         return loadData(null, dataInputStream); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     protected abstract List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException { |  | ||||||
|         setConf(conf); |  | ||||||
|         initialize(split); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<Writable> next() { |  | ||||||
|         throw new UnsupportedOperationException("next() not supported for CodecRecordReader (use: sequenceRecord)"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException { |  | ||||||
|         throw new UnsupportedOperationException("record(URI,DataInputStream) not supported for CodecRecordReader"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setConf(Configuration conf) { |  | ||||||
|         super.setConf(conf); |  | ||||||
|         startFrame = conf.getInt(START_FRAME, 0); |  | ||||||
|         numFrames = conf.getInt(TOTAL_FRAMES, -1); |  | ||||||
|         rows = conf.getInt(ROWS, 28); |  | ||||||
|         cols = conf.getInt(COLUMNS, 28); |  | ||||||
|         framesPerSecond = conf.getFloat(TIME_SLICE, -1); |  | ||||||
|         videoLength = conf.getFloat(VIDEO_DURATION, -1); |  | ||||||
|         ravel = conf.getBoolean(RAVEL, false); |  | ||||||
|         totalFrames = conf.getInt(TOTAL_FRAMES, -1); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Configuration getConf() { |  | ||||||
|         return super.getConf(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public SequenceRecord nextSequence() { |  | ||||||
|         URI next = locationsIterator.next(); |  | ||||||
| 
 |  | ||||||
|         List<List<Writable>> list; |  | ||||||
|         try (InputStream s = streamCreatorFn.apply(next)){ |  | ||||||
|             list = loadData(null, s); |  | ||||||
|         } catch (IOException e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
|         return new org.datavec.api.records.impl.SequenceRecord(list, |  | ||||||
|                         new RecordMetaDataURI(next, CodecRecordReader.class)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public SequenceRecord loadSequenceFromMetaData(RecordMetaData recordMetaData) throws IOException { |  | ||||||
|         return loadSequenceFromMetaData(Collections.singletonList(recordMetaData)).get(0); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<SequenceRecord> loadSequenceFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException { |  | ||||||
|         List<SequenceRecord> out = new ArrayList<>(); |  | ||||||
|         for (RecordMetaData meta : recordMetaDatas) { |  | ||||||
|             try (InputStream s = streamCreatorFn.apply(meta.getURI())){ |  | ||||||
|                 List<List<Writable>> list = loadData(null, s); |  | ||||||
|                 out.add(new org.datavec.api.records.impl.SequenceRecord(list, meta)); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return out; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,138 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.codec.reader; |  | ||||||
| 
 |  | ||||||
| import org.apache.commons.compress.utils.IOUtils; |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.util.ndarray.RecordConverter; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.datavec.image.loader.ImageLoader; |  | ||||||
| import org.jcodec.api.FrameGrab; |  | ||||||
| import org.jcodec.api.JCodecException; |  | ||||||
| import org.jcodec.common.ByteBufferSeekableByteChannel; |  | ||||||
| import org.jcodec.common.NIOUtils; |  | ||||||
| import org.jcodec.common.SeekableByteChannel; |  | ||||||
| 
 |  | ||||||
| import java.awt.image.BufferedImage; |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.io.InputStream; |  | ||||||
| import java.lang.reflect.Field; |  | ||||||
| import java.nio.ByteBuffer; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class CodecRecordReader extends BaseCodecRecordReader { |  | ||||||
| 
 |  | ||||||
|     private ImageLoader imageLoader; |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setConf(Configuration conf) { |  | ||||||
|         super.setConf(conf); |  | ||||||
|         imageLoader = new ImageLoader(rows, cols); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException { |  | ||||||
|         SeekableByteChannel seekableByteChannel; |  | ||||||
|         if (inputStream != null) { |  | ||||||
|             //Reading video from DataInputStream: Need data from this stream in a SeekableByteChannel |  | ||||||
|             //Approach used here: load entire video into memory -> ByteBufferSeekableByteChanel |  | ||||||
|             byte[] data = IOUtils.toByteArray(inputStream); |  | ||||||
|             ByteBuffer bb = ByteBuffer.wrap(data); |  | ||||||
|             seekableByteChannel = new FixedByteBufferSeekableByteChannel(bb); |  | ||||||
|         } else { |  | ||||||
|             seekableByteChannel = NIOUtils.readableFileChannel(file); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         List<List<Writable>> record = new ArrayList<>(); |  | ||||||
| 
 |  | ||||||
|         if (numFrames >= 1) { |  | ||||||
|             FrameGrab fg; |  | ||||||
|             try { |  | ||||||
|                 fg = new FrameGrab(seekableByteChannel); |  | ||||||
|                 if (startFrame != 0) |  | ||||||
|                     fg.seekToFramePrecise(startFrame); |  | ||||||
|             } catch (JCodecException e) { |  | ||||||
|                 throw new RuntimeException(e); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             for (int i = startFrame; i < startFrame + numFrames; i++) { |  | ||||||
|                 try { |  | ||||||
|                     BufferedImage grab = fg.getFrame(); |  | ||||||
|                     if (ravel) |  | ||||||
|                         record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab))); |  | ||||||
|                     else |  | ||||||
|                         record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab))); |  | ||||||
| 
 |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } else { |  | ||||||
|             if (framesPerSecond < 1) |  | ||||||
|                 throw new IllegalStateException("No frames or frame time intervals specified"); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|             else { |  | ||||||
|                 for (double i = 0; i < videoLength; i += framesPerSecond) { |  | ||||||
|                     try { |  | ||||||
|                         BufferedImage grab = FrameGrab.getFrame(seekableByteChannel, i); |  | ||||||
|                         if (ravel) |  | ||||||
|                             record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab))); |  | ||||||
|                         else |  | ||||||
|                             record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab))); |  | ||||||
| 
 |  | ||||||
|                     } catch (Exception e) { |  | ||||||
|                         throw new RuntimeException(e); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return record; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** Ugly workaround to a bug in JCodec: https://github.com/jcodec/jcodec/issues/24 */ |  | ||||||
|     private static class FixedByteBufferSeekableByteChannel extends ByteBufferSeekableByteChannel { |  | ||||||
|         private ByteBuffer backing; |  | ||||||
| 
 |  | ||||||
|         public FixedByteBufferSeekableByteChannel(ByteBuffer backing) { |  | ||||||
|             super(backing); |  | ||||||
|             try { |  | ||||||
|                 Field f = this.getClass().getSuperclass().getDeclaredField("maxPos"); |  | ||||||
|                 f.setAccessible(true); |  | ||||||
|                 f.set(this, backing.limit()); |  | ||||||
|             } catch (Exception e) { |  | ||||||
|                 throw new RuntimeException(e); |  | ||||||
|             } |  | ||||||
|             this.backing = backing; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         @Override |  | ||||||
|         public int read(ByteBuffer dst) throws IOException { |  | ||||||
|             if (!backing.hasRemaining()) |  | ||||||
|                 return -1; |  | ||||||
|             return super.read(dst); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,82 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.codec.reader; |  | ||||||
| 
 |  | ||||||
| import org.bytedeco.javacv.FFmpegFrameGrabber; |  | ||||||
| import org.bytedeco.javacv.Frame; |  | ||||||
| import org.bytedeco.javacv.OpenCVFrameConverter; |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.util.ndarray.RecordConverter; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.datavec.image.loader.NativeImageLoader; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.io.InputStream; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class NativeCodecRecordReader extends BaseCodecRecordReader { |  | ||||||
| 
 |  | ||||||
|     private OpenCVFrameConverter.ToMat converter; |  | ||||||
|     private NativeImageLoader imageLoader; |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setConf(Configuration conf) { |  | ||||||
|         super.setConf(conf); |  | ||||||
|         converter = new OpenCVFrameConverter.ToMat(); |  | ||||||
|         imageLoader = new NativeImageLoader(rows, cols); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException { |  | ||||||
|         List<List<Writable>> record = new ArrayList<>(); |  | ||||||
| 
 |  | ||||||
|         try (FFmpegFrameGrabber fg = |  | ||||||
|                         inputStream != null ? new FFmpegFrameGrabber(inputStream) : new FFmpegFrameGrabber(file)) { |  | ||||||
|             if (numFrames >= 1) { |  | ||||||
|                 fg.start(); |  | ||||||
|                 if (startFrame != 0) |  | ||||||
|                     fg.setFrameNumber(startFrame); |  | ||||||
| 
 |  | ||||||
|                 for (int i = startFrame; i < startFrame + numFrames; i++) { |  | ||||||
|                     Frame grab = fg.grabImage(); |  | ||||||
|                     record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab)))); |  | ||||||
|                 } |  | ||||||
|             } else { |  | ||||||
|                 if (framesPerSecond < 1) |  | ||||||
|                     throw new IllegalStateException("No frames or frame time intervals specified"); |  | ||||||
|                 else { |  | ||||||
|                     fg.start(); |  | ||||||
| 
 |  | ||||||
|                     for (double i = 0; i < videoLength; i += framesPerSecond) { |  | ||||||
|                         fg.setTimestamp(Math.round(i * 1000000L)); |  | ||||||
|                         Frame grab = fg.grabImage(); |  | ||||||
|                         record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab)))); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return record; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,46 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| package org.datavec.codec.reader; |  | ||||||
| 
 |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.nd4j.common.tests.AbstractAssertTestsClass; |  | ||||||
| import org.nd4j.common.tests.BaseND4JTest; |  | ||||||
| 
 |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected Set<Class<?>> getExclusions() { |  | ||||||
| 	    //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) |  | ||||||
| 	    return new HashSet<>(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
| 	protected String getPackageName() { |  | ||||||
|     	return "org.datavec.codec.reader"; |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	@Override |  | ||||||
| 	protected Class<?> getBaseClass() { |  | ||||||
|     	return BaseND4JTest.class; |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @ -1,212 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.codec.reader; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.records.SequenceRecord; |  | ||||||
| import org.datavec.api.records.metadata.RecordMetaData; |  | ||||||
| import org.datavec.api.records.reader.SequenceRecordReader; |  | ||||||
| import org.datavec.api.split.FileSplit; |  | ||||||
| import org.datavec.api.writable.ArrayWritable; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.junit.Ignore; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.common.io.ClassPathResource; |  | ||||||
| 
 |  | ||||||
| import java.io.DataInputStream; |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.FileInputStream; |  | ||||||
| import java.util.Iterator; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| import static org.junit.Assert.assertTrue; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  */ |  | ||||||
| public class CodecReaderTest { |  | ||||||
|     @Test |  | ||||||
|     public void testCodecReader() throws Exception { |  | ||||||
|         File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); |  | ||||||
|         SequenceRecordReader reader = new CodecRecordReader(); |  | ||||||
|         Configuration conf = new Configuration(); |  | ||||||
|         conf.set(CodecRecordReader.RAVEL, "true"); |  | ||||||
|         conf.set(CodecRecordReader.START_FRAME, "160"); |  | ||||||
|         conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); |  | ||||||
|         conf.set(CodecRecordReader.ROWS, "80"); |  | ||||||
|         conf.set(CodecRecordReader.COLUMNS, "46"); |  | ||||||
|         reader.initialize(new FileSplit(file)); |  | ||||||
|         reader.setConf(conf); |  | ||||||
|         assertTrue(reader.hasNext()); |  | ||||||
|         List<List<Writable>> record = reader.sequenceRecord(); |  | ||||||
|         //        System.out.println(record.size()); |  | ||||||
| 
 |  | ||||||
|         Iterator<List<Writable>> it = record.iterator(); |  | ||||||
|         List<Writable> first = it.next(); |  | ||||||
|         //        System.out.println(first); |  | ||||||
| 
 |  | ||||||
|         //Expected size: 80x46x3 |  | ||||||
|         assertEquals(1, first.size()); |  | ||||||
|         assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testCodecReaderMeta() throws Exception { |  | ||||||
|         File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); |  | ||||||
|         SequenceRecordReader reader = new CodecRecordReader(); |  | ||||||
|         Configuration conf = new Configuration(); |  | ||||||
|         conf.set(CodecRecordReader.RAVEL, "true"); |  | ||||||
|         conf.set(CodecRecordReader.START_FRAME, "160"); |  | ||||||
|         conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); |  | ||||||
|         conf.set(CodecRecordReader.ROWS, "80"); |  | ||||||
|         conf.set(CodecRecordReader.COLUMNS, "46"); |  | ||||||
|         reader.initialize(new FileSplit(file)); |  | ||||||
|         reader.setConf(conf); |  | ||||||
|         assertTrue(reader.hasNext()); |  | ||||||
|         List<List<Writable>> record = reader.sequenceRecord(); |  | ||||||
|         assertEquals(500, record.size()); //500 frames |  | ||||||
| 
 |  | ||||||
|         reader.reset(); |  | ||||||
|         SequenceRecord seqR = reader.nextSequence(); |  | ||||||
|         assertEquals(record, seqR.getSequenceRecord()); |  | ||||||
|         RecordMetaData meta = seqR.getMetaData(); |  | ||||||
|         //        System.out.println(meta); |  | ||||||
|         assertTrue(meta.getURI().toString().endsWith(file.getName())); |  | ||||||
| 
 |  | ||||||
|         SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta); |  | ||||||
|         assertEquals(seqR, fromMeta); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testViaDataInputStream() throws Exception { |  | ||||||
| 
 |  | ||||||
|         File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); |  | ||||||
|         SequenceRecordReader reader = new CodecRecordReader(); |  | ||||||
|         Configuration conf = new Configuration(); |  | ||||||
|         conf.set(CodecRecordReader.RAVEL, "true"); |  | ||||||
|         conf.set(CodecRecordReader.START_FRAME, "160"); |  | ||||||
|         conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); |  | ||||||
|         conf.set(CodecRecordReader.ROWS, "80"); |  | ||||||
|         conf.set(CodecRecordReader.COLUMNS, "46"); |  | ||||||
| 
 |  | ||||||
|         Configuration conf2 = new Configuration(conf); |  | ||||||
| 
 |  | ||||||
|         reader.initialize(new FileSplit(file)); |  | ||||||
|         reader.setConf(conf); |  | ||||||
|         assertTrue(reader.hasNext()); |  | ||||||
|         List<List<Writable>> expected = reader.sequenceRecord(); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         SequenceRecordReader reader2 = new CodecRecordReader(); |  | ||||||
|         reader2.setConf(conf2); |  | ||||||
| 
 |  | ||||||
|         DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file)); |  | ||||||
|         List<List<Writable>> actual = reader2.sequenceRecord(null, dataInputStream); |  | ||||||
| 
 |  | ||||||
|         assertEquals(expected, actual); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Ignore |  | ||||||
|     @Test |  | ||||||
|     public void testNativeCodecReader() throws Exception { |  | ||||||
|         File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); |  | ||||||
|         SequenceRecordReader reader = new NativeCodecRecordReader(); |  | ||||||
|         Configuration conf = new Configuration(); |  | ||||||
|         conf.set(CodecRecordReader.RAVEL, "true"); |  | ||||||
|         conf.set(CodecRecordReader.START_FRAME, "160"); |  | ||||||
|         conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); |  | ||||||
|         conf.set(CodecRecordReader.ROWS, "80"); |  | ||||||
|         conf.set(CodecRecordReader.COLUMNS, "46"); |  | ||||||
|         reader.initialize(new FileSplit(file)); |  | ||||||
|         reader.setConf(conf); |  | ||||||
|         assertTrue(reader.hasNext()); |  | ||||||
|         List<List<Writable>> record = reader.sequenceRecord(); |  | ||||||
|         //        System.out.println(record.size()); |  | ||||||
| 
 |  | ||||||
|         Iterator<List<Writable>> it = record.iterator(); |  | ||||||
|         List<Writable> first = it.next(); |  | ||||||
|         //        System.out.println(first); |  | ||||||
| 
 |  | ||||||
|         //Expected size: 80x46x3 |  | ||||||
|         assertEquals(1, first.size()); |  | ||||||
|         assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Ignore |  | ||||||
|     @Test |  | ||||||
|     public void testNativeCodecReaderMeta() throws Exception { |  | ||||||
|         File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); |  | ||||||
|         SequenceRecordReader reader = new NativeCodecRecordReader(); |  | ||||||
|         Configuration conf = new Configuration(); |  | ||||||
|         conf.set(CodecRecordReader.RAVEL, "true"); |  | ||||||
|         conf.set(CodecRecordReader.START_FRAME, "160"); |  | ||||||
|         conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); |  | ||||||
|         conf.set(CodecRecordReader.ROWS, "80"); |  | ||||||
|         conf.set(CodecRecordReader.COLUMNS, "46"); |  | ||||||
|         reader.initialize(new FileSplit(file)); |  | ||||||
|         reader.setConf(conf); |  | ||||||
|         assertTrue(reader.hasNext()); |  | ||||||
|         List<List<Writable>> record = reader.sequenceRecord(); |  | ||||||
|         assertEquals(500, record.size()); //500 frames |  | ||||||
| 
 |  | ||||||
|         reader.reset(); |  | ||||||
|         SequenceRecord seqR = reader.nextSequence(); |  | ||||||
|         assertEquals(record, seqR.getSequenceRecord()); |  | ||||||
|         RecordMetaData meta = seqR.getMetaData(); |  | ||||||
|         //        System.out.println(meta); |  | ||||||
|         assertTrue(meta.getURI().toString().endsWith("fire_lowres.mp4")); |  | ||||||
| 
 |  | ||||||
|         SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta); |  | ||||||
|         assertEquals(seqR, fromMeta); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Ignore |  | ||||||
|     @Test |  | ||||||
|     public void testNativeViaDataInputStream() throws Exception { |  | ||||||
| 
 |  | ||||||
|         File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); |  | ||||||
|         SequenceRecordReader reader = new NativeCodecRecordReader(); |  | ||||||
|         Configuration conf = new Configuration(); |  | ||||||
|         conf.set(CodecRecordReader.RAVEL, "true"); |  | ||||||
|         conf.set(CodecRecordReader.START_FRAME, "160"); |  | ||||||
|         conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); |  | ||||||
|         conf.set(CodecRecordReader.ROWS, "80"); |  | ||||||
|         conf.set(CodecRecordReader.COLUMNS, "46"); |  | ||||||
| 
 |  | ||||||
|         Configuration conf2 = new Configuration(conf); |  | ||||||
| 
 |  | ||||||
|         reader.initialize(new FileSplit(file)); |  | ||||||
|         reader.setConf(conf); |  | ||||||
|         assertTrue(reader.hasNext()); |  | ||||||
|         List<List<Writable>> expected = reader.sequenceRecord(); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         SequenceRecordReader reader2 = new NativeCodecRecordReader(); |  | ||||||
|         reader2.setConf(conf2); |  | ||||||
| 
 |  | ||||||
|         DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file)); |  | ||||||
|         List<List<Writable>> actual = reader2.sequenceRecord(null, dataInputStream); |  | ||||||
| 
 |  | ||||||
|         assertEquals(expected, actual); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,77 +0,0 @@ | |||||||
| <?xml version="1.0" encoding="UTF-8"?> |  | ||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" |  | ||||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |  | ||||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> |  | ||||||
| 
 |  | ||||||
|     <modelVersion>4.0.0</modelVersion> |  | ||||||
| 
 |  | ||||||
|     <parent> |  | ||||||
|         <groupId>org.datavec</groupId> |  | ||||||
|         <artifactId>datavec-data</artifactId> |  | ||||||
|         <version>1.0.0-SNAPSHOT</version> |  | ||||||
|     </parent> |  | ||||||
| 
 |  | ||||||
|     <artifactId>datavec-data-nlp</artifactId> |  | ||||||
| 
 |  | ||||||
|     <name>datavec-data-nlp</name> |  | ||||||
| 
 |  | ||||||
|     <properties> |  | ||||||
|         <cleartk.version>2.0.0</cleartk.version> |  | ||||||
|     </properties> |  | ||||||
| 
 |  | ||||||
|     <dependencies> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.datavec</groupId> |  | ||||||
|             <artifactId>datavec-api</artifactId> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.datavec</groupId> |  | ||||||
|             <artifactId>datavec-local</artifactId> |  | ||||||
|             <version>${project.version}</version> |  | ||||||
|             <scope>test</scope> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.apache.commons</groupId> |  | ||||||
|             <artifactId>commons-lang3</artifactId> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.cleartk</groupId> |  | ||||||
|             <artifactId>cleartk-snowball</artifactId> |  | ||||||
|             <version>${cleartk.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.cleartk</groupId> |  | ||||||
|             <artifactId>cleartk-opennlp-tools</artifactId> |  | ||||||
|             <version>${cleartk.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|     </dependencies> |  | ||||||
| 
 |  | ||||||
|     <profiles> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-native</id> |  | ||||||
|         </profile> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-cuda-11.0</id> |  | ||||||
|         </profile> |  | ||||||
|     </profiles> |  | ||||||
| </project> |  | ||||||
| @ -1,237 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.annotator; |  | ||||||
| 
 |  | ||||||
| import opennlp.tools.postag.POSModel; |  | ||||||
| import opennlp.tools.postag.POSTaggerME; |  | ||||||
| import opennlp.uima.postag.POSModelResource; |  | ||||||
| import opennlp.uima.postag.POSModelResourceImpl; |  | ||||||
| import opennlp.uima.util.AnnotationComboIterator; |  | ||||||
| import opennlp.uima.util.AnnotationIteratorPair; |  | ||||||
| import opennlp.uima.util.AnnotatorUtil; |  | ||||||
| import opennlp.uima.util.UimaUtil; |  | ||||||
| import org.apache.uima.UimaContext; |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngineDescription; |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngineProcessException; |  | ||||||
| import org.apache.uima.cas.CAS; |  | ||||||
| import org.apache.uima.cas.Feature; |  | ||||||
| import org.apache.uima.cas.Type; |  | ||||||
| import org.apache.uima.cas.TypeSystem; |  | ||||||
| import org.apache.uima.cas.text.AnnotationFS; |  | ||||||
| import org.apache.uima.fit.component.CasAnnotator_ImplBase; |  | ||||||
| import org.apache.uima.fit.factory.AnalysisEngineFactory; |  | ||||||
| import org.apache.uima.fit.factory.ExternalResourceFactory; |  | ||||||
| import org.apache.uima.resource.ResourceAccessException; |  | ||||||
| import org.apache.uima.resource.ResourceInitializationException; |  | ||||||
| import org.apache.uima.util.Level; |  | ||||||
| import org.apache.uima.util.Logger; |  | ||||||
| import org.cleartk.token.type.Sentence; |  | ||||||
| import org.cleartk.token.type.Token; |  | ||||||
| import org.datavec.nlp.movingwindow.Util; |  | ||||||
| 
 |  | ||||||
| import java.util.Iterator; |  | ||||||
| import java.util.LinkedList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| public class PoStagger extends CasAnnotator_ImplBase { |  | ||||||
| 
 |  | ||||||
|     static { |  | ||||||
|         //UIMA logging |  | ||||||
|         Util.disableLogging(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private POSTaggerME posTagger; |  | ||||||
| 
 |  | ||||||
|     private Type sentenceType; |  | ||||||
| 
 |  | ||||||
|     private Type tokenType; |  | ||||||
| 
 |  | ||||||
|     private Feature posFeature; |  | ||||||
| 
 |  | ||||||
|     private Feature probabilityFeature; |  | ||||||
| 
 |  | ||||||
|     private UimaContext context; |  | ||||||
| 
 |  | ||||||
|     private Logger logger; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Initializes a new instance. |  | ||||||
|      *  |  | ||||||
|      * Note: Use {@link #initialize(org.apache.uima.UimaContext) } to initialize this instance. Not use the |  | ||||||
|      * constructor. |  | ||||||
|      */ |  | ||||||
|     public PoStagger() { |  | ||||||
|         // must not be implemented ! |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Initializes the current instance with the given context. |  | ||||||
|      * |  | ||||||
|      * Note: Do all initialization in this method, do not use the constructor. |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public void initialize(UimaContext context) throws ResourceInitializationException { |  | ||||||
| 
 |  | ||||||
|         super.initialize(context); |  | ||||||
| 
 |  | ||||||
|         this.context = context; |  | ||||||
| 
 |  | ||||||
|         this.logger = context.getLogger(); |  | ||||||
| 
 |  | ||||||
|         if (this.logger.isLoggable(Level.INFO)) { |  | ||||||
|             this.logger.log(Level.INFO, "Initializing the OpenNLP " + "Part of Speech annotator."); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         POSModel model; |  | ||||||
| 
 |  | ||||||
|         try { |  | ||||||
|             POSModelResource modelResource = (POSModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER); |  | ||||||
| 
 |  | ||||||
|             model = modelResource.getModel(); |  | ||||||
|         } catch (ResourceAccessException e) { |  | ||||||
|             throw new ResourceInitializationException(e); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         Integer beamSize = AnnotatorUtil.getOptionalIntegerParameter(context, UimaUtil.BEAM_SIZE_PARAMETER); |  | ||||||
| 
 |  | ||||||
|         if (beamSize == null) |  | ||||||
|             beamSize = POSTaggerME.DEFAULT_BEAM_SIZE; |  | ||||||
| 
 |  | ||||||
|         this.posTagger = new POSTaggerME(model, beamSize, 0); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Initializes the type system. |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException { |  | ||||||
| 
 |  | ||||||
|         // sentence type |  | ||||||
|         this.sentenceType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem, |  | ||||||
|                         UimaUtil.SENTENCE_TYPE_PARAMETER); |  | ||||||
| 
 |  | ||||||
|         // token type |  | ||||||
|         this.tokenType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem, |  | ||||||
|                         UimaUtil.TOKEN_TYPE_PARAMETER); |  | ||||||
| 
 |  | ||||||
|         // pos feature |  | ||||||
|         this.posFeature = AnnotatorUtil.getRequiredFeatureParameter(this.context, this.tokenType, |  | ||||||
|                         UimaUtil.POS_FEATURE_PARAMETER, CAS.TYPE_NAME_STRING); |  | ||||||
| 
 |  | ||||||
|         this.probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(this.context, this.tokenType, |  | ||||||
|                         UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Performs pos-tagging on the given tcas object. |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public synchronized void process(CAS tcas) { |  | ||||||
| 
 |  | ||||||
|         final AnnotationComboIterator comboIterator = |  | ||||||
|                         new AnnotationComboIterator(tcas, this.sentenceType, this.tokenType); |  | ||||||
| 
 |  | ||||||
|         for (AnnotationIteratorPair annotationIteratorPair : comboIterator) { |  | ||||||
| 
 |  | ||||||
|             final List<AnnotationFS> sentenceTokenAnnotationList = new LinkedList<AnnotationFS>(); |  | ||||||
| 
 |  | ||||||
|             final List<String> sentenceTokenList = new LinkedList<String>(); |  | ||||||
| 
 |  | ||||||
|             for (AnnotationFS tokenAnnotation : annotationIteratorPair.getSubIterator()) { |  | ||||||
| 
 |  | ||||||
|                 sentenceTokenAnnotationList.add(tokenAnnotation); |  | ||||||
| 
 |  | ||||||
|                 sentenceTokenList.add(tokenAnnotation.getCoveredText()); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             final List<String> posTags = this.posTagger.tag(sentenceTokenList); |  | ||||||
| 
 |  | ||||||
|             double posProbabilities[] = null; |  | ||||||
| 
 |  | ||||||
|             if (this.probabilityFeature != null) { |  | ||||||
|                 posProbabilities = this.posTagger.probs(); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             final Iterator<String> posTagIterator = posTags.iterator(); |  | ||||||
|             final Iterator<AnnotationFS> sentenceTokenIterator = sentenceTokenAnnotationList.iterator(); |  | ||||||
| 
 |  | ||||||
|             int index = 0; |  | ||||||
|             while (posTagIterator.hasNext() && sentenceTokenIterator.hasNext()) { |  | ||||||
|                 final String posTag = posTagIterator.next(); |  | ||||||
|                 final AnnotationFS tokenAnnotation = sentenceTokenIterator.next(); |  | ||||||
| 
 |  | ||||||
|                 tokenAnnotation.setStringValue(this.posFeature, posTag); |  | ||||||
| 
 |  | ||||||
|                 if (posProbabilities != null) { |  | ||||||
|                     tokenAnnotation.setDoubleValue(this.posFeature, posProbabilities[index]); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 index++; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             // log tokens with pos |  | ||||||
|             if (this.logger.isLoggable(Level.FINER)) { |  | ||||||
| 
 |  | ||||||
|                 final StringBuilder sentenceWithPos = new StringBuilder(); |  | ||||||
| 
 |  | ||||||
|                 sentenceWithPos.append("\""); |  | ||||||
| 
 |  | ||||||
|                 for (final Iterator<AnnotationFS> it = sentenceTokenAnnotationList.iterator(); it.hasNext();) { |  | ||||||
|                     final AnnotationFS token = it.next(); |  | ||||||
|                     sentenceWithPos.append(token.getCoveredText()); |  | ||||||
|                     sentenceWithPos.append('\\'); |  | ||||||
|                     sentenceWithPos.append(token.getStringValue(this.posFeature)); |  | ||||||
|                     sentenceWithPos.append(' '); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 // delete last whitespace |  | ||||||
|                 if (sentenceWithPos.length() > 1) // not 0 because it contains already the " char |  | ||||||
|                     sentenceWithPos.setLength(sentenceWithPos.length() - 1); |  | ||||||
| 
 |  | ||||||
|                 sentenceWithPos.append("\""); |  | ||||||
| 
 |  | ||||||
|                 this.logger.log(Level.FINER, sentenceWithPos.toString()); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Releases allocated resources. |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public void destroy() { |  | ||||||
|         this.posTagger = null; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException { |  | ||||||
|         String modelPath = String.format("/models/%s-pos-maxent.bin", languageCode); |  | ||||||
|         return AnalysisEngineFactory.createEngineDescription(PoStagger.class, UimaUtil.MODEL_PARAMETER, |  | ||||||
|                         ExternalResourceFactory.createExternalResourceDescription(POSModelResourceImpl.class, |  | ||||||
|                                         PoStagger.class.getResource(modelPath).toString()), |  | ||||||
|                         UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), UimaUtil.TOKEN_TYPE_PARAMETER, |  | ||||||
|                         Token.class.getName(), UimaUtil.POS_FEATURE_PARAMETER, "pos"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,52 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.annotator; |  | ||||||
| 
 |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngineDescription; |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngineProcessException; |  | ||||||
| import org.apache.uima.fit.factory.AnalysisEngineFactory; |  | ||||||
| import org.apache.uima.jcas.JCas; |  | ||||||
| import org.apache.uima.resource.ResourceInitializationException; |  | ||||||
| import org.cleartk.util.ParamUtil; |  | ||||||
| import org.datavec.nlp.movingwindow.Util; |  | ||||||
| 
 |  | ||||||
| public class SentenceAnnotator extends org.cleartk.opennlp.tools.SentenceAnnotator { |  | ||||||
| 
 |  | ||||||
|     static { |  | ||||||
|         //UIMA logging |  | ||||||
|         Util.disableLogging(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static AnalysisEngineDescription getDescription() throws ResourceInitializationException { |  | ||||||
|         return AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.class, PARAM_SENTENCE_MODEL_PATH, |  | ||||||
|                         ParamUtil.getParameterValue(PARAM_SENTENCE_MODEL_PATH, "/models/en-sent.bin"), |  | ||||||
|                         PARAM_WINDOW_CLASS_NAMES, ParamUtil.getParameterValue(PARAM_WINDOW_CLASS_NAMES, null)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public synchronized void process(JCas jCas) throws AnalysisEngineProcessException { |  | ||||||
|         super.process(jCas); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,58 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.annotator; |  | ||||||
| 
 |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngineDescription; |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngineProcessException; |  | ||||||
| import org.apache.uima.fit.factory.AnalysisEngineFactory; |  | ||||||
| import org.apache.uima.jcas.JCas; |  | ||||||
| import org.apache.uima.resource.ResourceInitializationException; |  | ||||||
| import org.cleartk.snowball.SnowballStemmer; |  | ||||||
| import org.cleartk.token.type.Token; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| public class StemmerAnnotator extends SnowballStemmer<Token> { |  | ||||||
| 
 |  | ||||||
|     public static AnalysisEngineDescription getDescription() throws ResourceInitializationException { |  | ||||||
|         return getDescription("English"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public static AnalysisEngineDescription getDescription(String language) throws ResourceInitializationException { |  | ||||||
|         return AnalysisEngineFactory.createEngineDescription(StemmerAnnotator.class, SnowballStemmer.PARAM_STEMMER_NAME, |  | ||||||
|                         language); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @SuppressWarnings("unchecked") |  | ||||||
|     @Override |  | ||||||
|     public synchronized void process(JCas jCas) throws AnalysisEngineProcessException { |  | ||||||
|         super.process(jCas); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setStem(Token token, String stem) { |  | ||||||
|         token.setStem(stem); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,70 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.annotator; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import opennlp.uima.tokenize.TokenizerModelResourceImpl; |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngineDescription; |  | ||||||
| import org.apache.uima.fit.factory.AnalysisEngineFactory; |  | ||||||
| import org.apache.uima.fit.factory.ExternalResourceFactory; |  | ||||||
| import org.apache.uima.resource.ResourceInitializationException; |  | ||||||
| import org.cleartk.opennlp.tools.Tokenizer; |  | ||||||
| import org.cleartk.token.type.Sentence; |  | ||||||
| import org.cleartk.token.type.Token; |  | ||||||
| import org.datavec.nlp.movingwindow.Util; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.ConcurrentTokenizer; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Overrides OpenNLP tokenizer to be thread safe |  | ||||||
|  */ |  | ||||||
| public class TokenizerAnnotator extends Tokenizer { |  | ||||||
| 
 |  | ||||||
|     static { |  | ||||||
|         //UIMA logging |  | ||||||
|         Util.disableLogging(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException { |  | ||||||
|         String modelPath = String.format("/models/%s-token.bin", languageCode); |  | ||||||
|         return AnalysisEngineFactory.createEngineDescription(ConcurrentTokenizer.class, |  | ||||||
|                         opennlp.uima.util.UimaUtil.MODEL_PARAMETER, |  | ||||||
|                         ExternalResourceFactory.createExternalResourceDescription(TokenizerModelResourceImpl.class, |  | ||||||
|                                         ConcurrentTokenizer.class.getResource(modelPath).toString()), |  | ||||||
|                         opennlp.uima.util.UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), |  | ||||||
|                         opennlp.uima.util.UimaUtil.TOKEN_TYPE_PARAMETER, Token.class.getName()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public static AnalysisEngineDescription getDescription() throws ResourceInitializationException { |  | ||||||
|         String modelPath = String.format("/models/%s-token.bin", "en"); |  | ||||||
|         return AnalysisEngineFactory.createEngineDescription(ConcurrentTokenizer.class, |  | ||||||
|                         opennlp.uima.util.UimaUtil.MODEL_PARAMETER, |  | ||||||
|                         ExternalResourceFactory.createExternalResourceDescription(TokenizerModelResourceImpl.class, |  | ||||||
|                                         ConcurrentTokenizer.class.getResource(modelPath).toString()), |  | ||||||
|                         opennlp.uima.util.UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), |  | ||||||
|                         opennlp.uima.util.UimaUtil.TOKEN_TYPE_PARAMETER, Token.class.getName()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,41 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.input; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.formats.input.BaseInputFormat; |  | ||||||
| import org.datavec.api.records.reader.RecordReader; |  | ||||||
| import org.datavec.api.split.InputSplit; |  | ||||||
| import org.datavec.nlp.reader.TfidfRecordReader; |  | ||||||
| 
 |  | ||||||
| import java.io.IOException; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  */ |  | ||||||
| public class TextInputFormat extends BaseInputFormat { |  | ||||||
|     @Override |  | ||||||
|     public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException { |  | ||||||
|         RecordReader reader = new TfidfRecordReader(); |  | ||||||
|         reader.initialize(conf, split); |  | ||||||
|         return reader; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,148 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.metadata; |  | ||||||
| 
 |  | ||||||
| import org.nd4j.common.primitives.Counter; |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.nlp.vectorizer.TextVectorizer; |  | ||||||
| import org.nd4j.common.util.MathUtils; |  | ||||||
| import org.nd4j.common.util.Index; |  | ||||||
| 
 |  | ||||||
| public class DefaultVocabCache implements VocabCache { |  | ||||||
| 
 |  | ||||||
|     private Counter<String> wordFrequencies = new Counter<>(); |  | ||||||
|     private Counter<String> docFrequencies = new Counter<>(); |  | ||||||
|     private int minWordFrequency; |  | ||||||
|     private Index vocabWords = new Index(); |  | ||||||
|     private double numDocs = 0; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Instantiate with a given min word frequency |  | ||||||
|      * @param minWordFrequency |  | ||||||
|      */ |  | ||||||
|     public DefaultVocabCache(int minWordFrequency) { |  | ||||||
|         this.minWordFrequency = minWordFrequency; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /* |  | ||||||
|      * Constructor for use with initialize() |  | ||||||
|      */ |  | ||||||
|     public DefaultVocabCache() { |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void incrementNumDocs(double by) { |  | ||||||
|         numDocs += by; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public double numDocs() { |  | ||||||
|         return numDocs; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String wordAt(int i) { |  | ||||||
|         return vocabWords.get(i).toString(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public int wordIndex(String word) { |  | ||||||
|         return vocabWords.indexOf(word); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void initialize(Configuration conf) { |  | ||||||
|         minWordFrequency = conf.getInt(TextVectorizer.MIN_WORD_FREQUENCY, 5); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public double wordFrequency(String word) { |  | ||||||
|         return wordFrequencies.getCount(word); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public int minWordFrequency() { |  | ||||||
|         return minWordFrequency; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Index vocabWords() { |  | ||||||
|         return vocabWords; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void incrementDocCount(String word) { |  | ||||||
|         incrementDocCount(word, 1.0); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void incrementDocCount(String word, double by) { |  | ||||||
|         docFrequencies.incrementCount(word, by); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void incrementCount(String word) { |  | ||||||
|         incrementCount(word, 1.0); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void incrementCount(String word, double by) { |  | ||||||
|         wordFrequencies.incrementCount(word, by); |  | ||||||
|         if (wordFrequencies.getCount(word) >= minWordFrequency && vocabWords.indexOf(word) < 0) |  | ||||||
|             vocabWords.add(word); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public double idf(String word) { |  | ||||||
|         return docFrequencies.getCount(word); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public double tfidf(String word, double frequency, boolean smoothIdf) { |  | ||||||
|         double tf = tf((int) frequency); |  | ||||||
|         double docFreq = docFrequencies.getCount(word); |  | ||||||
| 
 |  | ||||||
|         double idf = idf(numDocs, docFreq, smoothIdf); |  | ||||||
|         double tfidf = MathUtils.tfidf(tf, idf); |  | ||||||
|         return tfidf; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double idf(double totalDocs, double numTimesWordAppearedInADocument, boolean smooth) { |  | ||||||
|         if(smooth){ |  | ||||||
|             return Math.log((1 + totalDocs) / (1 + numTimesWordAppearedInADocument)) + 1.0; |  | ||||||
|         } else { |  | ||||||
|             return Math.log(totalDocs / numTimesWordAppearedInADocument) + 1.0; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static double tf(int count) { |  | ||||||
|         return count; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public int getMinWordFrequency() { |  | ||||||
|         return minWordFrequency; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setMinWordFrequency(int minWordFrequency) { |  | ||||||
|         this.minWordFrequency = minWordFrequency; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,121 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.metadata; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.nd4j.common.util.Index; |  | ||||||
| 
 |  | ||||||
/**
 * Vocabulary cache tracking word frequencies, document frequencies, and the
 * ordered set of words admitted to the vocabulary. Used by text vectorizers
 * to compute TF-IDF style statistics.
 */
public interface VocabCache {


    /**
     * Increment the number of documents
     * @param by the amount to increment the document count by
     */
    void incrementNumDocs(double by);

    /**
     * Number of documents
     * @return the number of documents
     */
    double numDocs();

    /**
     * Returns a word in the vocab at a particular index
     * @param i the index to get
     * @return the word at that index in the vocab
     */
    String wordAt(int i);

    /**
     * Look up the vocab index of a word.
     * @param word the word to look up
     * @return the index of the word in the vocab, or a negative value if absent
     */
    int wordIndex(String word);

    /**
     * Configuration for initializing
     * @param conf the configuration to initialize with
     */
    void initialize(Configuration conf);

    /**
     * Get the word frequency for a word
     * @param word the word to get frequency for
     * @return the frequency for a given word
     */
    double wordFrequency(String word);

    /**
     * The min word frequency
     * needed to be included in the vocab
     * (default 5)
     * @return the min word frequency to
     * be included in the vocab
     */
    int minWordFrequency();

    /**
     * All of the vocab words (ordered)
     * note that these are not all the possible tokens
     * @return the list of vocab words
     */
    Index vocabWords();


    /**
     * Increment the doc count for a word by 1
     * @param word the word to increment the count for
     */
    void incrementDocCount(String word);

    /**
     * Increment the document count for a particular word
     * @param word the word to increment the count for
     * @param by the amount to increment by
     */
    void incrementDocCount(String word, double by);

    /**
     * Increment a word count by 1
     * @param word the word to increment the count for
     */
    void incrementCount(String word);

    /**
     * Increment count for a word
     * @param word the word to increment the count for
     * @param by the amount to increment by
     */
    void incrementCount(String word, double by);

    /**
     * Number of documents word has occurred in
     * @param word the word to get the idf for
     * @return the number of documents the word occurred in
     */
    double idf(String word);

    /**
     * Calculate the tfidf of the word given the document frequency
     * @param word the word to get frequency for
     * @param frequency the frequency
     * @param smoothIdf whether to apply smoothing to the inverse document frequency
     * @return the tfidf for a word
     */
    double tfidf(String word, double frequency, boolean smoothIdf);

}
| @ -1,125 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.movingwindow; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.apache.commons.lang3.StringUtils; |  | ||||||
| import org.nd4j.common.base.Preconditions; |  | ||||||
| import org.nd4j.common.collection.MultiDimensionalMap; |  | ||||||
| import org.nd4j.common.primitives.Pair; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class ContextLabelRetriever { |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     private static String BEGIN_LABEL = "<([A-Za-z]+|\\d+)>"; |  | ||||||
|     private static String END_LABEL = "</([A-Za-z]+|\\d+)>"; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     private ContextLabelRetriever() {} |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Returns a stripped sentence with the indices of words |  | ||||||
|      * with certain kinds of labels. |  | ||||||
|      * |  | ||||||
|      * @param sentence the sentence to process |  | ||||||
|      * @return a pair of a post processed sentence |  | ||||||
|      * with labels stripped and the spans of |  | ||||||
|      * the labels |  | ||||||
|      */ |  | ||||||
|     public static Pair<String, MultiDimensionalMap<Integer, Integer, String>> stringWithLabels(String sentence, |  | ||||||
|                                                        TokenizerFactory tokenizerFactory) { |  | ||||||
|         MultiDimensionalMap<Integer, Integer, String> map = MultiDimensionalMap.newHashBackedMap(); |  | ||||||
|         Tokenizer t = tokenizerFactory.create(sentence); |  | ||||||
|         List<String> currTokens = new ArrayList<>(); |  | ||||||
|         String currLabel = null; |  | ||||||
|         String endLabel = null; |  | ||||||
|         List<Pair<String, List<String>>> tokensWithSameLabel = new ArrayList<>(); |  | ||||||
|         while (t.hasMoreTokens()) { |  | ||||||
|             String token = t.nextToken(); |  | ||||||
|             if (token.matches(BEGIN_LABEL)) { |  | ||||||
|                 currLabel = token; |  | ||||||
| 
 |  | ||||||
|                 //no labels; add these as NONE and begin the new label |  | ||||||
|                 if (!currTokens.isEmpty()) { |  | ||||||
|                     tokensWithSameLabel.add(new Pair<>("NONE", (List<String>) new ArrayList<>(currTokens))); |  | ||||||
|                     currTokens.clear(); |  | ||||||
| 
 |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|             } else if (token.matches(END_LABEL)) { |  | ||||||
|                 if (currLabel == null) |  | ||||||
|                     throw new IllegalStateException("Found an ending label with no matching begin label"); |  | ||||||
|                 endLabel = token; |  | ||||||
|             } else |  | ||||||
|                 currTokens.add(token); |  | ||||||
| 
 |  | ||||||
|             if (currLabel != null && endLabel != null) { |  | ||||||
|                 currLabel = currLabel.replaceAll("[<>/]", ""); |  | ||||||
|                 endLabel = endLabel.replaceAll("[<>/]", ""); |  | ||||||
|                 Preconditions.checkState(!currLabel.isEmpty(), "Current label is empty!"); |  | ||||||
|                 Preconditions.checkState(!endLabel.isEmpty(), "End label is empty!"); |  | ||||||
|                 Preconditions.checkState(currLabel.equals(endLabel), "Current label begin and end did not match for the parse. Was: %s ending with %s", |  | ||||||
|                         currLabel, endLabel); |  | ||||||
| 
 |  | ||||||
|                 tokensWithSameLabel.add(new Pair<>(currLabel, (List<String>) new ArrayList<>(currTokens))); |  | ||||||
|                 currTokens.clear(); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|                 //clear out the tokens |  | ||||||
|                 currLabel = null; |  | ||||||
|                 endLabel = null; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         //no labels; add these as NONE and begin the new label |  | ||||||
|         if (!currTokens.isEmpty()) { |  | ||||||
|             tokensWithSameLabel.add(new Pair<>("none", (List<String>) new ArrayList<>(currTokens))); |  | ||||||
|             currTokens.clear(); |  | ||||||
| 
 |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         //now join the output |  | ||||||
|         StringBuilder strippedSentence = new StringBuilder(); |  | ||||||
|         for (Pair<String, List<String>> tokensWithLabel : tokensWithSameLabel) { |  | ||||||
|             String joinedSentence = StringUtils.join(tokensWithLabel.getSecond(), " "); |  | ||||||
|             //spaces between separate parts of the sentence |  | ||||||
|             if (!(strippedSentence.length() < 1)) |  | ||||||
|                 strippedSentence.append(" "); |  | ||||||
|             strippedSentence.append(joinedSentence); |  | ||||||
|             int begin = strippedSentence.toString().indexOf(joinedSentence); |  | ||||||
|             int end = begin + joinedSentence.length(); |  | ||||||
|             map.put(begin, end, tokensWithLabel.getFirst()); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         return new Pair<>(strippedSentence.toString(), map); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,60 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.movingwindow; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.nd4j.common.primitives.Counter; |  | ||||||
| 
 |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.logging.Level; |  | ||||||
| import java.util.logging.Logger; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| public class Util { |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Returns a thread safe counter |  | ||||||
|      * |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     public static Counter<String> parallelCounter() { |  | ||||||
|         return new Counter<>(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static boolean matchesAnyStopWord(List<String> stopWords, String word) { |  | ||||||
|         for (String s : stopWords) |  | ||||||
|             if (s.equalsIgnoreCase(word)) |  | ||||||
|                 return true; |  | ||||||
|         return false; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static Level disableLogging() { |  | ||||||
|         Logger logger = Logger.getLogger("org.apache.uima"); |  | ||||||
|         while (logger.getLevel() == null) { |  | ||||||
|             logger = logger.getParent(); |  | ||||||
|         } |  | ||||||
|         Level level = logger.getLevel(); |  | ||||||
|         logger.setLevel(Level.OFF); |  | ||||||
|         return level; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,177 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.movingwindow; |  | ||||||
| 
 |  | ||||||
| import org.apache.commons.lang3.StringUtils; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Collection; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
/**
 * A window of tokens around a focus word, used for moving-window context
 * feature extraction. If any token matches inline label markup
 * ({@code <LABEL>} / {@code </LABEL>}, upper-case or numeric names), the
 * window's label is set from it; otherwise the label is "NONE".
 */
public class Window implements Serializable {

    private static final long serialVersionUID = 6359906393699230579L;

    // Inline label markup patterns, e.g. <PERSON> ... </PERSON>
    private static final String BEGIN_LABEL = "<([A-Z]+|\\d+)>";
    private static final String END_LABEL = "</([A-Z]+|\\d+)>";

    private List<String> words;
    private String label = "NONE";
    private boolean beginLabel;
    private boolean endLabel;
    private int median;
    private int begin, end;

    /**
     * Creates a window with the default window size of 5.
     *
     * @param words the tokens of the window
     * @param begin the begin index for the window
     * @param end the end index for the window
     */
    public Window(Collection<String> words, int begin, int end) {
        this(words, 5, begin, end);
    }

    /**
     * Joins this window's tokens with single spaces.
     *
     * @return the window rendered as a token string
     */
    public String asTokens() {
        // String.join replaces the former commons-lang StringUtils.join;
        // output is identical for non-null tokens.
        return String.join(" ", words);
    }


    /**
     * Initialize a window with the given size.
     *
     * @param words the words to use
     * @param windowSize nominal window size; currently unused — the effective
     *        size is {@code words.size()} (kept for interface compatibility)
     * @param begin the begin index for the window
     * @param end the end index for the window
     * @throws IllegalArgumentException if words is null
     */
    public Window(Collection<String> words, int windowSize, int begin, int end) {
        if (words == null)
            throw new IllegalArgumentException("Words must not be null");

        this.words = new ArrayList<>(words);
        this.begin = begin;
        this.end = end;
        initContext();
    }


    /**
     * Scans the tokens on both sides of the median for label markup and sets
     * label/beginLabel/endLabel accordingly.
     */
    private void initContext() {
        // Integer division already floors; no Math.floor needed
        int median = words.size() / 2;
        List<String> before = words.subList(0, median);
        List<String> after = words.subList(median + 1, words.size());

        for (String s : before) {
            if (s.matches(BEGIN_LABEL)) {
                this.label = s.replaceAll("(<|>)", "").replace("/", "");
                beginLabel = true;
            } else if (s.matches(END_LABEL)) {
                endLabel = true;
                this.label = s.replaceAll("(<|>|/)", "").replace("/", "");
            }
        }

        for (String s1 : after) {
            if (s1.matches(BEGIN_LABEL)) {
                this.label = s1.replaceAll("(<|>)", "").replace("/", "");
                beginLabel = true;
            }

            // NOTE(review): this branch leaves "/" in the label (stripped later
            // by getLabel()), unlike the pre-median branch — preserved as-is.
            if (s1.matches(END_LABEL)) {
                endLabel = true;
                this.label = s1.replaceAll("(<|>)", "");
            }
        }
        this.median = median;
    }



    @Override
    public String toString() {
        return words.toString();
    }

    public List<String> getWords() {
        return words;
    }

    public void setWords(List<String> words) {
        this.words = words;
    }

    public String getWord(int i) {
        return words.get(i);
    }

    /** @return the token at the window's median position */
    public String getFocusWord() {
        return words.get(median);
    }

    /** @return true if a begin-label tag was seen and a label is set */
    public boolean isBeginLabel() {
        return !label.equals("NONE") && beginLabel;
    }

    /** @return true if an end-label tag was seen and a label is set */
    public boolean isEndLabel() {
        return !label.equals("NONE") && endLabel;
    }

    /** @return the label with any residual "/" characters removed */
    public String getLabel() {
        return label.replace("/", "");
    }

    /** @return the effective window size (number of tokens) */
    public int getWindowSize() {
        return words.size();
    }

    public int getMedian() {
        return median;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    public int getBegin() {
        return begin;
    }

    public void setBegin(int begin) {
        this.begin = begin;
    }

    public int getEnd() {
        return end;
    }

    public void setEnd(int end) {
        this.end = end;
    }


}
| @ -1,188 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.movingwindow; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.apache.commons.lang3.StringUtils; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; |  | ||||||
| 
 |  | ||||||
| import java.io.InputStream; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.StringTokenizer; |  | ||||||
| 
 |  | ||||||
| public class Windows { |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructs a list of window of size windowSize. |  | ||||||
|      * Note that padding for each window is created as well. |  | ||||||
|      * @param words the words to tokenize and construct windows from |  | ||||||
|      * @param windowSize the window size to generate |  | ||||||
|      * @return the list of windows for the tokenized string |  | ||||||
|      */ |  | ||||||
|     public static List<Window> windows(InputStream words, int windowSize) { |  | ||||||
|         Tokenizer tokenizer = new DefaultStreamTokenizer(words); |  | ||||||
|         List<String> list = new ArrayList<>(); |  | ||||||
|         while (tokenizer.hasMoreTokens()) |  | ||||||
|             list.add(tokenizer.nextToken()); |  | ||||||
|         return windows(list, windowSize); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructs a list of window of size windowSize. |  | ||||||
|      * Note that padding for each window is created as well. |  | ||||||
|      * @param words the words to tokenize and construct windows from |  | ||||||
|      * @param tokenizerFactory tokenizer factory to use |  | ||||||
|      * @param windowSize the window size to generate |  | ||||||
|      * @return the list of windows for the tokenized string |  | ||||||
|      */ |  | ||||||
|     public static List<Window> windows(InputStream words, TokenizerFactory tokenizerFactory, int windowSize) { |  | ||||||
|         Tokenizer tokenizer = tokenizerFactory.create(words); |  | ||||||
|         List<String> list = new ArrayList<>(); |  | ||||||
|         while (tokenizer.hasMoreTokens()) |  | ||||||
|             list.add(tokenizer.nextToken()); |  | ||||||
| 
 |  | ||||||
|         if (list.isEmpty()) |  | ||||||
|             throw new IllegalStateException("No tokens found for windows"); |  | ||||||
| 
 |  | ||||||
|         return windows(list, windowSize); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructs a list of window of size windowSize. |  | ||||||
|      * Note that padding for each window is created as well. |  | ||||||
|      * @param words the words to tokenize and construct windows from |  | ||||||
|      * @param windowSize the window size to generate |  | ||||||
|      * @return the list of windows for the tokenized string |  | ||||||
|      */ |  | ||||||
|     public static List<Window> windows(String words, int windowSize) { |  | ||||||
|         StringTokenizer tokenizer = new StringTokenizer(words); |  | ||||||
|         List<String> list = new ArrayList<String>(); |  | ||||||
|         while (tokenizer.hasMoreTokens()) |  | ||||||
|             list.add(tokenizer.nextToken()); |  | ||||||
|         return windows(list, windowSize); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructs a list of window of size windowSize. |  | ||||||
|      * Note that padding for each window is created as well. |  | ||||||
|      * @param words the words to tokenize and construct windows from |  | ||||||
|      * @param tokenizerFactory tokenizer factory to use |  | ||||||
|      * @param windowSize the window size to generate |  | ||||||
|      * @return the list of windows for the tokenized string |  | ||||||
|      */ |  | ||||||
|     public static List<Window> windows(String words, TokenizerFactory tokenizerFactory, int windowSize) { |  | ||||||
|         Tokenizer tokenizer = tokenizerFactory.create(words); |  | ||||||
|         List<String> list = new ArrayList<>(); |  | ||||||
|         while (tokenizer.hasMoreTokens()) |  | ||||||
|             list.add(tokenizer.nextToken()); |  | ||||||
| 
 |  | ||||||
|         if (list.isEmpty()) |  | ||||||
|             throw new IllegalStateException("No tokens found for windows"); |  | ||||||
| 
 |  | ||||||
|         return windows(list, windowSize); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructs a list of window of size windowSize. |  | ||||||
|      * Note that padding for each window is created as well. |  | ||||||
|      * @param words the words to tokenize and construct windows from |  | ||||||
|      * @return the list of windows for the tokenized string |  | ||||||
|      */ |  | ||||||
|     public static List<Window> windows(String words) { |  | ||||||
|         StringTokenizer tokenizer = new StringTokenizer(words); |  | ||||||
|         List<String> list = new ArrayList<String>(); |  | ||||||
|         while (tokenizer.hasMoreTokens()) |  | ||||||
|             list.add(tokenizer.nextToken()); |  | ||||||
|         return windows(list, 5); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructs a list of window of size windowSize. |  | ||||||
|      * Note that padding for each window is created as well. |  | ||||||
|      * @param words the words to tokenize and construct windows from |  | ||||||
|      * @param tokenizerFactory tokenizer factory to use |  | ||||||
|      * @return the list of windows for the tokenized string |  | ||||||
|      */ |  | ||||||
|     public static List<Window> windows(String words, TokenizerFactory tokenizerFactory) { |  | ||||||
|         Tokenizer tokenizer = tokenizerFactory.create(words); |  | ||||||
|         List<String> list = new ArrayList<>(); |  | ||||||
|         while (tokenizer.hasMoreTokens()) |  | ||||||
|             list.add(tokenizer.nextToken()); |  | ||||||
|         return windows(list, 5); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Creates a sliding window from text |  | ||||||
|      * @param windowSize the window size to use |  | ||||||
|      * @param wordPos the position of the word to center |  | ||||||
|      * @param sentence the sentence to createComplex a window for |  | ||||||
|      * @return a window based on the given sentence |  | ||||||
|      */ |  | ||||||
|     public static Window windowForWordInPosition(int windowSize, int wordPos, List<String> sentence) { |  | ||||||
|         List<String> window = new ArrayList<>(); |  | ||||||
|         List<String> onlyTokens = new ArrayList<>(); |  | ||||||
|         int contextSize = (int) Math.floor((windowSize - 1) / 2); |  | ||||||
| 
 |  | ||||||
|         for (int i = wordPos - contextSize; i <= wordPos + contextSize; i++) { |  | ||||||
|             if (i < 0) |  | ||||||
|                 window.add("<s>"); |  | ||||||
|             else if (i >= sentence.size()) |  | ||||||
|                 window.add("</s>"); |  | ||||||
|             else { |  | ||||||
|                 onlyTokens.add(sentence.get(i)); |  | ||||||
|                 window.add(sentence.get(i)); |  | ||||||
| 
 |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         String wholeSentence = StringUtils.join(sentence); |  | ||||||
|         String window2 = StringUtils.join(onlyTokens); |  | ||||||
|         int begin = wholeSentence.indexOf(window2); |  | ||||||
|         int end = begin + window2.length(); |  | ||||||
|         return new Window(window, begin, end); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Constructs a list of window of size windowSize |  | ||||||
|      * @param words the words to  construct windows from |  | ||||||
|      * @return the list of windows for the tokenized string |  | ||||||
|      */ |  | ||||||
|     public static List<Window> windows(List<String> words, int windowSize) { |  | ||||||
| 
 |  | ||||||
|         List<Window> ret = new ArrayList<>(); |  | ||||||
| 
 |  | ||||||
|         for (int i = 0; i < words.size(); i++) |  | ||||||
|             ret.add(windowForWordInPosition(windowSize, i, words)); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,189 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.reader; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.records.Record; |  | ||||||
| import org.datavec.api.records.metadata.RecordMetaData; |  | ||||||
| import org.datavec.api.records.metadata.RecordMetaDataURI; |  | ||||||
| import org.datavec.api.records.reader.impl.FileRecordReader; |  | ||||||
| import org.datavec.api.split.InputSplit; |  | ||||||
| import org.datavec.api.vector.Vectorizer; |  | ||||||
| import org.datavec.api.writable.NDArrayWritable; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.datavec.nlp.vectorizer.TfidfVectorizer; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| 
 |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| public class TfidfRecordReader extends FileRecordReader { |  | ||||||
|     private TfidfVectorizer tfidfVectorizer; |  | ||||||
|     private List<Record> records = new ArrayList<>(); |  | ||||||
|     private Iterator<Record> recordIter; |  | ||||||
|     private int numFeatures; |  | ||||||
|     private boolean initialized = false; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void initialize(InputSplit split) throws IOException, InterruptedException { |  | ||||||
|         initialize(new Configuration(), split); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException { |  | ||||||
|         super.initialize(conf, split); |  | ||||||
|         //train  a new one since it hasn't been specified |  | ||||||
|         if (tfidfVectorizer == null) { |  | ||||||
|             tfidfVectorizer = new TfidfVectorizer(); |  | ||||||
|             tfidfVectorizer.initialize(conf); |  | ||||||
| 
 |  | ||||||
|             //clear out old strings |  | ||||||
|             records.clear(); |  | ||||||
| 
 |  | ||||||
|             INDArray ret = tfidfVectorizer.fitTransform(this, new Vectorizer.RecordCallBack() { |  | ||||||
|                 @Override |  | ||||||
|                 public void onRecord(Record fullRecord) { |  | ||||||
|                     records.add(fullRecord); |  | ||||||
|                 } |  | ||||||
|             }); |  | ||||||
| 
 |  | ||||||
|             //cache the number of features used for each document |  | ||||||
|             numFeatures = ret.columns(); |  | ||||||
|             recordIter = records.iterator(); |  | ||||||
|         } else { |  | ||||||
|             records = new ArrayList<>(); |  | ||||||
| 
 |  | ||||||
|             //the record reader has 2 phases, we are skipping the |  | ||||||
|             //document frequency phase and just using the super() to get the file contents |  | ||||||
|             //and pass it to the already existing vectorizer. |  | ||||||
|             while (super.hasNext()) { |  | ||||||
|                 Record fileContents = super.nextRecord(); |  | ||||||
|                 INDArray transform = tfidfVectorizer.transform(fileContents); |  | ||||||
| 
 |  | ||||||
|                 org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record( |  | ||||||
|                                 new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))), |  | ||||||
|                                 new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class)); |  | ||||||
| 
 |  | ||||||
|                 if (appendLabel) |  | ||||||
|                     record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1)); |  | ||||||
| 
 |  | ||||||
|                 records.add(record); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             recordIter = records.iterator(); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         this.initialized = true; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void reset() { |  | ||||||
|         if (inputSplit == null) |  | ||||||
|             throw new UnsupportedOperationException("Cannot reset without first initializing"); |  | ||||||
|         recordIter = records.iterator(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Record nextRecord() { |  | ||||||
|         if (recordIter == null) |  | ||||||
|             return super.nextRecord(); |  | ||||||
|         return recordIter.next(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<Writable> next() { |  | ||||||
|         return nextRecord().getRecord(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public boolean hasNext() { |  | ||||||
|         //we aren't done vectorizing yet |  | ||||||
|         if (recordIter == null) |  | ||||||
|             return super.hasNext(); |  | ||||||
|         return recordIter.hasNext(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void close() throws IOException { |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setConf(Configuration conf) { |  | ||||||
|         this.conf = conf; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Configuration getConf() { |  | ||||||
|         return conf; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public TfidfVectorizer getTfidfVectorizer() { |  | ||||||
|         return tfidfVectorizer; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setTfidfVectorizer(TfidfVectorizer tfidfVectorizer) { |  | ||||||
|         if (initialized) { |  | ||||||
|             throw new IllegalArgumentException( |  | ||||||
|                             "Setting TfidfVectorizer after TfidfRecordReader initialization doesn't have an effect"); |  | ||||||
|         } |  | ||||||
|         this.tfidfVectorizer = tfidfVectorizer; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public int getNumFeatures() { |  | ||||||
|         return numFeatures; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void shuffle() { |  | ||||||
|         this.shuffle(new Random()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void shuffle(Random random) { |  | ||||||
|         Collections.shuffle(this.records, random); |  | ||||||
|         this.reset(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException { |  | ||||||
|         return loadFromMetaData(Collections.singletonList(recordMetaData)).get(0); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException { |  | ||||||
|         List<Record> out = new ArrayList<>(); |  | ||||||
| 
 |  | ||||||
|         for (Record fileContents : super.loadFromMetaData(recordMetaDatas)) { |  | ||||||
|             INDArray transform = tfidfVectorizer.transform(fileContents); |  | ||||||
| 
 |  | ||||||
|             org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record( |  | ||||||
|                             new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))), |  | ||||||
|                             new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class)); |  | ||||||
| 
 |  | ||||||
|             if (appendLabel) |  | ||||||
|                 record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1)); |  | ||||||
|             out.add(record); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return out; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| @ -1,44 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.stopwords; |  | ||||||
| 
 |  | ||||||
| import org.apache.commons.io.IOUtils; |  | ||||||
| 
 |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class StopWords { |  | ||||||
| 
 |  | ||||||
|     private static List<String> stopWords; |  | ||||||
| 
 |  | ||||||
|     @SuppressWarnings("unchecked") |  | ||||||
|     public static List<String> getStopWords() { |  | ||||||
| 
 |  | ||||||
|         try { |  | ||||||
|             if (stopWords == null) |  | ||||||
|                 stopWords = IOUtils.readLines(StopWords.class.getResourceAsStream("/stopwords")); |  | ||||||
|         } catch (IOException e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
|         return stopWords; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,125 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizer; |  | ||||||
| 
 |  | ||||||
| import opennlp.tools.tokenize.TokenizerME; |  | ||||||
| import opennlp.tools.tokenize.TokenizerModel; |  | ||||||
| import opennlp.tools.util.Span; |  | ||||||
| import opennlp.uima.tokenize.AbstractTokenizer; |  | ||||||
| import opennlp.uima.tokenize.TokenizerModelResource; |  | ||||||
| import opennlp.uima.util.AnnotatorUtil; |  | ||||||
| import opennlp.uima.util.UimaUtil; |  | ||||||
| import org.apache.uima.UimaContext; |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngineProcessException; |  | ||||||
| import org.apache.uima.cas.CAS; |  | ||||||
| import org.apache.uima.cas.Feature; |  | ||||||
| import org.apache.uima.cas.TypeSystem; |  | ||||||
| import org.apache.uima.cas.text.AnnotationFS; |  | ||||||
| import org.apache.uima.resource.ResourceAccessException; |  | ||||||
| import org.apache.uima.resource.ResourceInitializationException; |  | ||||||
| 
 |  | ||||||
| public class ConcurrentTokenizer extends AbstractTokenizer { |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * The OpenNLP tokenizer. |  | ||||||
|      */ |  | ||||||
|     private TokenizerME tokenizer; |  | ||||||
| 
 |  | ||||||
|     private Feature probabilityFeature; |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public synchronized void process(CAS cas) throws AnalysisEngineProcessException { |  | ||||||
|         super.process(cas); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|        * Initializes a new instance. |  | ||||||
|        * |  | ||||||
|        * Note: Use {@link #initialize(UimaContext) } to initialize  |  | ||||||
|        * this instance. Not use the constructor. |  | ||||||
|        */ |  | ||||||
|     public ConcurrentTokenizer() { |  | ||||||
|         super("OpenNLP Tokenizer"); |  | ||||||
| 
 |  | ||||||
|         // must not be implemented ! |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Initializes the current instance with the given context. |  | ||||||
|      *  |  | ||||||
|      * Note: Do all initialization in this method, do not use the constructor. |  | ||||||
|      */ |  | ||||||
|     public void initialize(UimaContext context) throws ResourceInitializationException { |  | ||||||
| 
 |  | ||||||
|         super.initialize(context); |  | ||||||
| 
 |  | ||||||
|         TokenizerModel model; |  | ||||||
| 
 |  | ||||||
|         try { |  | ||||||
|             TokenizerModelResource modelResource = |  | ||||||
|                             (TokenizerModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER); |  | ||||||
| 
 |  | ||||||
|             model = modelResource.getModel(); |  | ||||||
|         } catch (ResourceAccessException e) { |  | ||||||
|             throw new ResourceInitializationException(e); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         tokenizer = new TokenizerME(model); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Initializes the type system. |  | ||||||
|      */ |  | ||||||
|     public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException { |  | ||||||
| 
 |  | ||||||
|         super.typeSystemInit(typeSystem); |  | ||||||
| 
 |  | ||||||
|         probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(context, tokenType, |  | ||||||
|                         UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected Span[] tokenize(CAS cas, AnnotationFS sentence) { |  | ||||||
|         return tokenizer.tokenizePos(sentence.getCoveredText()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected void postProcessAnnotations(Span[] tokens, AnnotationFS[] tokenAnnotations) { |  | ||||||
|         // if interest |  | ||||||
|         if (probabilityFeature != null) { |  | ||||||
|             double tokenProbabilties[] = tokenizer.getTokenProbabilities(); |  | ||||||
| 
 |  | ||||||
|             for (int i = 0; i < tokenAnnotations.length; i++) { |  | ||||||
|                 tokenAnnotations[i].setDoubleValue(probabilityFeature, tokenProbabilties[i]); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Releases allocated resources. |  | ||||||
|      */ |  | ||||||
|     public void destroy() { |  | ||||||
|         // dereference model to allow garbage collection  |  | ||||||
|         tokenizer = null; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| @ -1,107 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizer; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import java.io.*; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Tokenizer based on the {@link java.io.StreamTokenizer} |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  * |  | ||||||
|  */ |  | ||||||
public class DefaultStreamTokenizer implements Tokenizer {

    private StreamTokenizer streamTokenizer;
    private TokenPreProcess tokenPreProcess;


    /**
     * Wraps the given stream in a buffered reader and a {@link StreamTokenizer}.
     * NOTE(review): no charset is specified, so the platform default is used —
     * confirm callers only feed it ASCII-safe input.
     */
    public DefaultStreamTokenizer(InputStream is) {
        Reader r = new BufferedReader(new InputStreamReader(is));
        streamTokenizer = new StreamTokenizer(r);

    }

    /**
     * Returns whether another token is available.
     *
     * WARNING: this method is side-effecting — unless the tokenizer is already
     * at EOF it advances to the next token, so {@code ttype}/{@code sval} are
     * updated as a consequence of calling it. {@code nextToken()} below relies
     * on this ordering; do not reorder calls.
     */
    @Override
    public boolean hasMoreTokens() {
        if (streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
            try {
                streamTokenizer.nextToken();
            } catch (IOException e1) {
                throw new RuntimeException(e1);
            }
        }
        //TT_EOF is -1, so the second comparison is redundant but harmless
        return streamTokenizer.ttype != StreamTokenizer.TT_EOF && streamTokenizer.ttype != -1;
    }

    /**
     * Counts the remaining tokens by draining them; the tokenizer is consumed
     * afterwards (it cannot be rewound).
     */
    @Override
    public int countTokens() {
        return getTokens().size();
    }

    /**
     * Returns the token the tokenizer currently points at (a word or the
     * string form of a number), skipping end-of-line markers. Relies on
     * {@link #hasMoreTokens()} having advanced the tokenizer beforehand.
     * Applies the configured pre-processor, if any, before returning.
     */
    @Override
    public String nextToken() {
        StringBuilder sb = new StringBuilder();


        if (streamTokenizer.ttype == StreamTokenizer.TT_WORD) {
            sb.append(streamTokenizer.sval);
        } else if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
            //numbers are parsed as double and appended in their decimal form
            sb.append(streamTokenizer.nval);
        } else if (streamTokenizer.ttype == StreamTokenizer.TT_EOL) {
            //skip any run of end-of-line markers
            try {
                while (streamTokenizer.ttype == StreamTokenizer.TT_EOL)
                    streamTokenizer.nextToken();
            } catch (IOException e) {
                throw new RuntimeException(e);

            }
        }

        //other token types: advance (via side-effecting hasMoreTokens) and retry
        else if (hasMoreTokens())
            return nextToken();


        String ret = sb.toString();

        if (tokenPreProcess != null)
            ret = tokenPreProcess.preProcess(ret);
        return ret;

    }

    /**
     * Drains all remaining tokens into a list, consuming the stream.
     */
    @Override
    public List<String> getTokens() {
        List<String> tokens = new ArrayList<>();
        while (hasMoreTokens()) {
            tokens.add(nextToken());
        }
        return tokens;
    }

    /**
     * Sets the pre-processor applied to every token returned by
     * {@link #nextToken()}.
     */
    @Override
    public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
        this.tokenPreProcess = tokenPreProcessor;
    }

}
| @ -1,75 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizer; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.StringTokenizer; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Default tokenizer |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  */ |  | ||||||
| public class DefaultTokenizer implements Tokenizer { |  | ||||||
| 
 |  | ||||||
|     public DefaultTokenizer(String tokens) { |  | ||||||
|         tokenizer = new StringTokenizer(tokens); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private StringTokenizer tokenizer; |  | ||||||
|     private TokenPreProcess tokenPreProcess; |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public boolean hasMoreTokens() { |  | ||||||
|         return tokenizer.hasMoreTokens(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public int countTokens() { |  | ||||||
|         return tokenizer.countTokens(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String nextToken() { |  | ||||||
|         String base = tokenizer.nextToken(); |  | ||||||
|         if (tokenPreProcess != null) |  | ||||||
|             base = tokenPreProcess.preProcess(base); |  | ||||||
|         return base; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<String> getTokens() { |  | ||||||
|         List<String> tokens = new ArrayList<>(); |  | ||||||
|         while (hasMoreTokens()) { |  | ||||||
|             tokens.add(nextToken()); |  | ||||||
|         } |  | ||||||
|         return tokens; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) { |  | ||||||
|         this.tokenPreProcess = tokenPreProcessor; |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,136 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizer; |  | ||||||
| 
 |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngine; |  | ||||||
| import org.apache.uima.cas.CAS; |  | ||||||
| import org.apache.uima.fit.factory.AnalysisEngineFactory; |  | ||||||
| import org.apache.uima.fit.util.JCasUtil; |  | ||||||
| import org.cleartk.token.type.Sentence; |  | ||||||
| import org.cleartk.token.type.Token; |  | ||||||
| import org.datavec.nlp.annotator.PoStagger; |  | ||||||
| import org.datavec.nlp.annotator.SentenceAnnotator; |  | ||||||
| import org.datavec.nlp.annotator.StemmerAnnotator; |  | ||||||
| import org.datavec.nlp.annotator.TokenizerAnnotator; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Collection; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class PosUimaTokenizer implements Tokenizer { |  | ||||||
| 
 |  | ||||||
|     private static AnalysisEngine engine; |  | ||||||
|     private List<String> tokens; |  | ||||||
|     private Collection<String> allowedPosTags; |  | ||||||
|     private int index; |  | ||||||
|     private static CAS cas; |  | ||||||
| 
 |  | ||||||
|     public PosUimaTokenizer(String tokens, AnalysisEngine engine, Collection<String> allowedPosTags) { |  | ||||||
|         if (engine == null) |  | ||||||
|             PosUimaTokenizer.engine = engine; |  | ||||||
|         this.allowedPosTags = allowedPosTags; |  | ||||||
|         this.tokens = new ArrayList<>(); |  | ||||||
|         try { |  | ||||||
|             if (cas == null) |  | ||||||
|                 cas = engine.newCAS(); |  | ||||||
| 
 |  | ||||||
|             cas.reset(); |  | ||||||
|             cas.setDocumentText(tokens); |  | ||||||
|             PosUimaTokenizer.engine.process(cas); |  | ||||||
|             for (Sentence s : JCasUtil.select(cas.getJCas(), Sentence.class)) { |  | ||||||
|                 for (Token t : JCasUtil.selectCovered(Token.class, s)) { |  | ||||||
|                     //add NONE for each invalid token |  | ||||||
|                     if (valid(t)) |  | ||||||
|                         if (t.getLemma() != null) |  | ||||||
|                             this.tokens.add(t.getLemma()); |  | ||||||
|                         else if (t.getStem() != null) |  | ||||||
|                             this.tokens.add(t.getStem()); |  | ||||||
|                         else |  | ||||||
|                             this.tokens.add(t.getCoveredText()); |  | ||||||
|                     else |  | ||||||
|                         this.tokens.add("NONE"); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         } catch (Exception e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private boolean valid(Token token) { |  | ||||||
|         String check = token.getCoveredText(); |  | ||||||
|         if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>")) |  | ||||||
|             return false; |  | ||||||
|         else if (token.getPos() != null && !this.allowedPosTags.contains(token.getPos())) |  | ||||||
|             return false; |  | ||||||
|         return true; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public boolean hasMoreTokens() { |  | ||||||
|         return index < tokens.size(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public int countTokens() { |  | ||||||
|         return tokens.size(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String nextToken() { |  | ||||||
|         String ret = tokens.get(index); |  | ||||||
|         index++; |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<String> getTokens() { |  | ||||||
|         List<String> tokens = new ArrayList<String>(); |  | ||||||
|         while (hasMoreTokens()) { |  | ||||||
|             tokens.add(nextToken()); |  | ||||||
|         } |  | ||||||
|         return tokens; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static AnalysisEngine defaultAnalysisEngine() { |  | ||||||
|         try { |  | ||||||
|             return AnalysisEngineFactory.createEngine(AnalysisEngineFactory.createEngineDescription( |  | ||||||
|                             SentenceAnnotator.getDescription(), TokenizerAnnotator.getDescription(), |  | ||||||
|                             PoStagger.getDescription("en"), StemmerAnnotator.getDescription("English"))); |  | ||||||
|         } catch (Exception e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) { |  | ||||||
|         // TODO Auto-generated method stub |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,34 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizer; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
/**
 * A single-step transformation applied to each token as it is produced
 * (e.g. lower-casing or suffix stripping).
 *
 * <p>Marked {@code @FunctionalInterface} since it has exactly one abstract
 * method, allowing lambda implementations.</p>
 */
@FunctionalInterface
public interface TokenPreProcess {

    /**
     * Pre process a token
     * @param token the token to pre process
     * @return the preprocessed token
     */
    String preProcess(String token);

}
| @ -1,61 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizer; |  | ||||||
| 
 |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public interface Tokenizer { |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * An iterator for tracking whether |  | ||||||
|      * more tokens are left in the iterator not |  | ||||||
|      * @return whether there is anymore tokens |  | ||||||
|      * to iterate over |  | ||||||
|      */ |  | ||||||
|     boolean hasMoreTokens(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * The number of tokens in the tokenizer |  | ||||||
|      * @return the number of tokens |  | ||||||
|      */ |  | ||||||
|     int countTokens(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * The next token (word usually) in the string |  | ||||||
|      * @return the next token in the string if any |  | ||||||
|      */ |  | ||||||
|     String nextToken(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Returns a list of all the tokens |  | ||||||
|      * @return a list of all the tokens |  | ||||||
|      */ |  | ||||||
|     List<String> getTokens(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Set the token pre process |  | ||||||
|      * @param tokenPreProcessor the token pre processor to set |  | ||||||
|      */ |  | ||||||
|     void setTokenPreProcessor(TokenPreProcess tokenPreProcessor); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,123 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizer; |  | ||||||
| 
 |  | ||||||
| import org.apache.uima.cas.CAS; |  | ||||||
| import org.apache.uima.fit.util.JCasUtil; |  | ||||||
| import org.cleartk.token.type.Token; |  | ||||||
| import org.datavec.nlp.uima.UimaResource; |  | ||||||
| import org.slf4j.Logger; |  | ||||||
| import org.slf4j.LoggerFactory; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Collection; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Tokenizer based on the passed in analysis engine |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  * |  | ||||||
|  */ |  | ||||||
| public class UimaTokenizer implements Tokenizer { |  | ||||||
| 
 |  | ||||||
|     private List<String> tokens; |  | ||||||
|     private int index; |  | ||||||
|     private static Logger log = LoggerFactory.getLogger(UimaTokenizer.class); |  | ||||||
|     private boolean checkForLabel; |  | ||||||
|     private TokenPreProcess tokenPreProcessor; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public UimaTokenizer(String tokens, UimaResource resource, boolean checkForLabel) { |  | ||||||
| 
 |  | ||||||
|         this.checkForLabel = checkForLabel; |  | ||||||
|         this.tokens = new ArrayList<>(); |  | ||||||
|         try { |  | ||||||
|             CAS cas = resource.process(tokens); |  | ||||||
| 
 |  | ||||||
|             Collection<Token> tokenList = JCasUtil.select(cas.getJCas(), Token.class); |  | ||||||
| 
 |  | ||||||
|             for (Token t : tokenList) { |  | ||||||
| 
 |  | ||||||
|                 if (!checkForLabel || valid(t.getCoveredText())) |  | ||||||
|                     if (t.getLemma() != null) |  | ||||||
|                         this.tokens.add(t.getLemma()); |  | ||||||
|                     else if (t.getStem() != null) |  | ||||||
|                         this.tokens.add(t.getStem()); |  | ||||||
|                     else |  | ||||||
|                         this.tokens.add(t.getCoveredText()); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|             resource.release(cas); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         } catch (Exception e) { |  | ||||||
|             log.error("",e); |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private boolean valid(String check) { |  | ||||||
|         if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>")) |  | ||||||
|             return false; |  | ||||||
|         return true; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public boolean hasMoreTokens() { |  | ||||||
|         return index < tokens.size(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public int countTokens() { |  | ||||||
|         return tokens.size(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String nextToken() { |  | ||||||
|         String ret = tokens.get(index); |  | ||||||
|         index++; |  | ||||||
|         if (tokenPreProcessor != null) { |  | ||||||
|             ret = tokenPreProcessor.preProcess(ret); |  | ||||||
|         } |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<String> getTokens() { |  | ||||||
|         List<String> tokens = new ArrayList<>(); |  | ||||||
|         while (hasMoreTokens()) { |  | ||||||
|             tokens.add(nextToken()); |  | ||||||
|         } |  | ||||||
|         return tokens; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) { |  | ||||||
|         this.tokenPreProcessor = tokenPreProcessor; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,47 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizer.preprocessor; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Gets rid of endings: |  | ||||||
|  * |  | ||||||
|  *    ed,ing, ly, s, . |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  */ |  | ||||||
| public class EndingPreProcessor implements TokenPreProcess { |  | ||||||
|     @Override |  | ||||||
|     public String preProcess(String token) { |  | ||||||
|         if (token.endsWith("s") && !token.endsWith("ss")) |  | ||||||
|             token = token.substring(0, token.length() - 1); |  | ||||||
|         if (token.endsWith(".")) |  | ||||||
|             token = token.substring(0, token.length() - 1); |  | ||||||
|         if (token.endsWith("ed")) |  | ||||||
|             token = token.substring(0, token.length() - 2); |  | ||||||
|         if (token.endsWith("ing")) |  | ||||||
|             token = token.substring(0, token.length() - 3); |  | ||||||
|         if (token.endsWith("ly")) |  | ||||||
|             token = token.substring(0, token.length() - 2); |  | ||||||
|         return token; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,30 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizer.preprocessor; |  | ||||||
| 
 |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; |  | ||||||
| 
 |  | ||||||
/**
 * Token pre-processor that lower-cases each token.
 */
public class LowerCasePreProcessor implements TokenPreProcess {
    @Override
    public String preProcess(String token) {
        // NOTE(review): uses the JVM default locale; for locale-independent casing
        // (e.g. the Turkish dotted/dotless i) toLowerCase(Locale.ROOT) would be
        // safer — confirm whether locale-sensitive behavior is intended here.
        return token.toLowerCase();
    }
}
| @ -1,60 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizerfactory; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.DefaultTokenizer; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; |  | ||||||
| 
 |  | ||||||
| import java.io.InputStream; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Default tokenizer based on string tokenizer or stream tokenizer |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  */ |  | ||||||
| public class DefaultTokenizerFactory implements TokenizerFactory { |  | ||||||
| 
 |  | ||||||
|     private TokenPreProcess tokenPreProcess; |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Tokenizer create(String toTokenize) { |  | ||||||
|         DefaultTokenizer t = new DefaultTokenizer(toTokenize); |  | ||||||
|         t.setTokenPreProcessor(tokenPreProcess); |  | ||||||
|         return t; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Tokenizer create(InputStream toTokenize) { |  | ||||||
|         Tokenizer t = new DefaultStreamTokenizer(toTokenize); |  | ||||||
|         t.setTokenPreProcessor(tokenPreProcess); |  | ||||||
|         return t; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setTokenPreProcessor(TokenPreProcess preProcessor) { |  | ||||||
|         this.tokenPreProcess = preProcessor; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,85 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizerfactory; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngine; |  | ||||||
| import org.datavec.nlp.annotator.PoStagger; |  | ||||||
| import org.datavec.nlp.annotator.SentenceAnnotator; |  | ||||||
| import org.datavec.nlp.annotator.StemmerAnnotator; |  | ||||||
| import org.datavec.nlp.annotator.TokenizerAnnotator; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.PosUimaTokenizer; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; |  | ||||||
| 
 |  | ||||||
| import java.io.InputStream; |  | ||||||
| import java.util.Collection; |  | ||||||
| 
 |  | ||||||
| import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngine; |  | ||||||
| import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription; |  | ||||||
| 
 |  | ||||||
| public class PosUimaTokenizerFactory implements TokenizerFactory { |  | ||||||
| 
 |  | ||||||
|     private AnalysisEngine tokenizer; |  | ||||||
|     private Collection<String> allowedPoSTags; |  | ||||||
|     private TokenPreProcess tokenPreProcess; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public PosUimaTokenizerFactory(Collection<String> allowedPoSTags) { |  | ||||||
|         this(defaultAnalysisEngine(), allowedPoSTags); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public PosUimaTokenizerFactory(AnalysisEngine tokenizer, Collection<String> allowedPosTags) { |  | ||||||
|         this.tokenizer = tokenizer; |  | ||||||
|         this.allowedPoSTags = allowedPosTags; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public static AnalysisEngine defaultAnalysisEngine() { |  | ||||||
|         try { |  | ||||||
|             return createEngine(createEngineDescription(SentenceAnnotator.getDescription(), |  | ||||||
|                             TokenizerAnnotator.getDescription(), PoStagger.getDescription("en"), |  | ||||||
|                             StemmerAnnotator.getDescription("English"))); |  | ||||||
|         } catch (Exception e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Tokenizer create(String toTokenize) { |  | ||||||
|         PosUimaTokenizer t = new PosUimaTokenizer(toTokenize, tokenizer, allowedPoSTags); |  | ||||||
|         t.setTokenPreProcessor(tokenPreProcess); |  | ||||||
|         return t; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Tokenizer create(InputStream toTokenize) { |  | ||||||
|         throw new UnsupportedOperationException(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setTokenPreProcessor(TokenPreProcess preProcessor) { |  | ||||||
|         this.tokenPreProcess = preProcessor; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,62 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizerfactory; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonTypeInfo; |  | ||||||
| 
 |  | ||||||
| import java.io.InputStream; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Generates a tokenizer for a given string |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  * |  | ||||||
|  */ |  | ||||||
| @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") |  | ||||||
| public interface TokenizerFactory { |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * The tokenizer to createComplex |  | ||||||
|      * @param toTokenize the string to createComplex the tokenizer with |  | ||||||
|      * @return the new tokenizer |  | ||||||
|      */ |  | ||||||
|     Tokenizer create(String toTokenize); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Create a tokenizer based on an input stream |  | ||||||
|      * @param toTokenize |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     Tokenizer create(InputStream toTokenize); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Sets a token pre processor to be used |  | ||||||
|      * with every tokenizer |  | ||||||
|      * @param preProcessor the token pre processor to use |  | ||||||
|      */ |  | ||||||
|     void setTokenPreProcessor(TokenPreProcess preProcessor); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,138 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.tokenization.tokenizerfactory; |  | ||||||
| 
 |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngine; |  | ||||||
| import org.apache.uima.fit.factory.AnalysisEngineFactory; |  | ||||||
| import org.apache.uima.resource.ResourceInitializationException; |  | ||||||
| import org.datavec.nlp.annotator.SentenceAnnotator; |  | ||||||
| import org.datavec.nlp.annotator.TokenizerAnnotator; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.UimaTokenizer; |  | ||||||
| import org.datavec.nlp.uima.UimaResource; |  | ||||||
| 
 |  | ||||||
| import java.io.InputStream; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Uses a uima {@link AnalysisEngine} to  |  | ||||||
|  * tokenize text. |  | ||||||
|  * |  | ||||||
|  * |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  * |  | ||||||
|  */ |  | ||||||
| public class UimaTokenizerFactory implements TokenizerFactory { |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     private UimaResource uimaResource; |  | ||||||
|     private boolean checkForLabel; |  | ||||||
|     private static AnalysisEngine defaultAnalysisEngine; |  | ||||||
|     private TokenPreProcess preProcess; |  | ||||||
| 
 |  | ||||||
|     public UimaTokenizerFactory() throws ResourceInitializationException { |  | ||||||
|         this(defaultAnalysisEngine(), true); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public UimaTokenizerFactory(UimaResource resource) { |  | ||||||
|         this(resource, true); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public UimaTokenizerFactory(AnalysisEngine tokenizer) { |  | ||||||
|         this(tokenizer, true); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public UimaTokenizerFactory(UimaResource resource, boolean checkForLabel) { |  | ||||||
|         this.uimaResource = resource; |  | ||||||
|         this.checkForLabel = checkForLabel; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public UimaTokenizerFactory(boolean checkForLabel) throws ResourceInitializationException { |  | ||||||
|         this(defaultAnalysisEngine(), checkForLabel); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public UimaTokenizerFactory(AnalysisEngine tokenizer, boolean checkForLabel) { |  | ||||||
|         super(); |  | ||||||
|         this.checkForLabel = checkForLabel; |  | ||||||
|         try { |  | ||||||
|             this.uimaResource = new UimaResource(tokenizer); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         } catch (Exception e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Tokenizer create(String toTokenize) { |  | ||||||
|         if (toTokenize == null || toTokenize.isEmpty()) |  | ||||||
|             throw new IllegalArgumentException("Unable to proceed; on sentence to tokenize"); |  | ||||||
|         Tokenizer ret = new UimaTokenizer(toTokenize, uimaResource, checkForLabel); |  | ||||||
|         ret.setTokenPreProcessor(preProcess); |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public UimaResource getUimaResource() { |  | ||||||
|         return uimaResource; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Creates a tokenization,/stemming pipeline |  | ||||||
|      * @return a tokenization/stemming pipeline |  | ||||||
|      */ |  | ||||||
|     public static AnalysisEngine defaultAnalysisEngine() { |  | ||||||
|         try { |  | ||||||
|             if (defaultAnalysisEngine == null) |  | ||||||
| 
 |  | ||||||
|                 defaultAnalysisEngine = AnalysisEngineFactory.createEngine( |  | ||||||
|                                 AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.getDescription(), |  | ||||||
|                                                 TokenizerAnnotator.getDescription())); |  | ||||||
| 
 |  | ||||||
|             return defaultAnalysisEngine; |  | ||||||
|         } catch (Exception e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Tokenizer create(InputStream toTokenize) { |  | ||||||
|         throw new UnsupportedOperationException(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setTokenPreProcessor(TokenPreProcess preProcessor) { |  | ||||||
|         this.preProcess = preProcessor; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,69 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.transforms; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.transform.Transform; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonTypeInfo; |  | ||||||
| 
 |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface BagOfWordsTransform extends Transform {


    /**
     * The output shape of the transform (usually 1 x number of words)
     * @return the shape of the arrays produced by this transform
     */
    long[] outputShape();

    /**
     * The vocab words in the transform.
     * This is the words that were accumulated
     * when building a vocabulary.
     * (This is generally associated with some form of
     * minimum word frequency scanning to build a vocab
     * you then map on to a list of vocab words as a list)
     * @return the vocab words for the transform
     */
    List<String> vocabWords();

    /**
     * Transform for a list of tokens
     * that are objects. This is to allow loose
     * typing for tokens that are unique (non string)
     * @param tokens the token objects to transform
     * @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array)
     */
    INDArray transformFromObject(List<List<Object>> tokens);


    /**
     * Transform for a list of tokens
     * that are {@link Writable} (Generally {@link org.datavec.api.writable.Text}
     * @param tokens the token objects to transform
     * @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array)
     */
    INDArray transformFrom(List<List<Writable>> tokens);

}
| @ -1,24 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.transforms; |  | ||||||
| 
 |  | ||||||
// Empty placeholder class — declares no fields or behavior.
// NOTE(review): nothing in this file references it; presumably reserved as a
// base type for word-map transforms — confirm intended use before removing.
public class BaseWordMapTransform {
}
| @ -1,146 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.transforms; |  | ||||||
| 
 |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.EqualsAndHashCode; |  | ||||||
| import org.datavec.api.transform.metadata.ColumnMetaData; |  | ||||||
| import org.datavec.api.transform.metadata.NDArrayMetaData; |  | ||||||
| import org.datavec.api.transform.transform.BaseColumnTransform; |  | ||||||
| import org.datavec.api.writable.NDArrayWritable; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.nd4j.linalg.api.buffer.DataType; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonCreator; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonInclude; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonProperty; |  | ||||||
| 
 |  | ||||||
| import java.util.Collections; |  | ||||||
| import java.util.HashSet; |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.Set; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @EqualsAndHashCode(callSuper = true) |  | ||||||
| @JsonInclude(JsonInclude.Include.NON_NULL) |  | ||||||
| @JsonIgnoreProperties({"gazeteer"}) |  | ||||||
| public class GazeteerTransform extends BaseColumnTransform implements BagOfWordsTransform { |  | ||||||
| 
 |  | ||||||
|     private String newColumnName; |  | ||||||
|     private List<String> wordList; |  | ||||||
|     private Set<String> gazeteer; |  | ||||||
| 
 |  | ||||||
|     @JsonCreator |  | ||||||
|     public GazeteerTransform(@JsonProperty("columnName") String columnName, |  | ||||||
|                              @JsonProperty("newColumnName")String newColumnName, |  | ||||||
|                              @JsonProperty("wordList") List<String> wordList) { |  | ||||||
|         super(columnName); |  | ||||||
|         this.newColumnName = newColumnName; |  | ||||||
|         this.wordList = wordList; |  | ||||||
|         this.gazeteer = new HashSet<>(wordList); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) { |  | ||||||
|         return new NDArrayMetaData(newName,new long[]{wordList.size()}); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Writable map(Writable columnWritable) { |  | ||||||
|        throw new UnsupportedOperationException(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Object mapSequence(Object sequence) { |  | ||||||
|         List<List<Object>> sequenceInput = (List<List<Object>>) sequence; |  | ||||||
|         INDArray ret = Nd4j.create(DataType.FLOAT, wordList.size()); |  | ||||||
| 
 |  | ||||||
|         for(List<Object> list : sequenceInput) { |  | ||||||
|             for(Object token : list) { |  | ||||||
|                 String s = token.toString(); |  | ||||||
|                 if(gazeteer.contains(s)) { |  | ||||||
|                     ret.putScalar(wordList.indexOf(s),1); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<List<Writable>> mapSequence(List<List<Writable>> sequence) { |  | ||||||
|         INDArray arr = (INDArray) mapSequence((Object) sequence); |  | ||||||
|         return Collections.singletonList(Collections.<Writable>singletonList(new NDArrayWritable(arr))); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String toString() { |  | ||||||
|         return newColumnName; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Object map(Object input) { |  | ||||||
|         return gazeteer.contains(input.toString()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String outputColumnName() { |  | ||||||
|         return newColumnName; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String[] outputColumnNames() { |  | ||||||
|         return new String[]{newColumnName}; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String[] columnNames() { |  | ||||||
|         return new String[]{columnName()}; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String columnName() { |  | ||||||
|         return columnName; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public long[] outputShape() { |  | ||||||
|         return new long[]{wordList.size()}; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<String> vocabWords() { |  | ||||||
|         return wordList; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public INDArray transformFromObject(List<List<Object>> tokens) { |  | ||||||
|         return (INDArray) mapSequence(tokens); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public INDArray transformFrom(List<List<Writable>> tokens) { |  | ||||||
|         return (INDArray) mapSequence((Object) tokens); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,150 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.transforms; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.transform.metadata.ColumnMetaData; |  | ||||||
| import org.datavec.api.transform.metadata.NDArrayMetaData; |  | ||||||
| import org.datavec.api.transform.transform.BaseColumnTransform; |  | ||||||
| import org.datavec.api.writable.NDArrayWritable; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.list.NDArrayList; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonCreator; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonProperty; |  | ||||||
| 
 |  | ||||||
| import java.util.Collections; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class MultiNlpTransform extends BaseColumnTransform implements BagOfWordsTransform { |  | ||||||
| 
 |  | ||||||
|     private BagOfWordsTransform[] transforms; |  | ||||||
|     private String newColumnName; |  | ||||||
|     private List<String> vocabWords; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * |  | ||||||
|      * @param columnName |  | ||||||
|      * @param transforms |  | ||||||
|      * @param newColumnName |  | ||||||
|      */ |  | ||||||
|     @JsonCreator |  | ||||||
|     public MultiNlpTransform(@JsonProperty("columnName") String columnName, |  | ||||||
|                              @JsonProperty("transforms") BagOfWordsTransform[] transforms, |  | ||||||
|                              @JsonProperty("newColumnName") String newColumnName) { |  | ||||||
|         super(columnName); |  | ||||||
|         this.transforms = transforms; |  | ||||||
|         this.vocabWords = transforms[0].vocabWords(); |  | ||||||
|         if(transforms.length > 1) { |  | ||||||
|             for(int i = 1; i < transforms.length; i++) { |  | ||||||
|                 if(!transforms[i].vocabWords().equals(vocabWords)) { |  | ||||||
|                     throw new IllegalArgumentException("Vocab words not consistent across transforms!"); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         this.newColumnName = newColumnName; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Object mapSequence(Object sequence) { |  | ||||||
|         NDArrayList ndArrayList = new NDArrayList(); |  | ||||||
|         for(BagOfWordsTransform bagofWordsTransform : transforms) { |  | ||||||
|             ndArrayList.addAll(new NDArrayList(bagofWordsTransform.transformFromObject((List<List<Object>>) sequence))); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return ndArrayList.array(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<List<Writable>> mapSequence(List<List<Writable>> sequence) { |  | ||||||
|      return Collections.singletonList(Collections.<Writable>singletonList(new NDArrayWritable(transformFrom(sequence)))); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) { |  | ||||||
|         return new NDArrayMetaData(newName,outputShape()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Writable map(Writable columnWritable) { |  | ||||||
|         throw new UnsupportedOperationException("Only able to add for time series"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String toString() { |  | ||||||
|         return newColumnName; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Object map(Object input) { |  | ||||||
|         throw new UnsupportedOperationException("Only able to add for time series"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public long[] outputShape() { |  | ||||||
|         long[] ret = new long[transforms[0].outputShape().length]; |  | ||||||
|         int validatedRank = transforms[0].outputShape().length; |  | ||||||
|         for(int i = 1; i < transforms.length; i++) { |  | ||||||
|             if(transforms[i].outputShape().length != validatedRank) { |  | ||||||
|                 throw new IllegalArgumentException("Inconsistent shape length at transform " + i + " , should have been: " + validatedRank); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         for(int i = 0; i < transforms.length; i++) { |  | ||||||
|             for(int j = 0; j < validatedRank; j++) |  | ||||||
|             ret[j] += transforms[i].outputShape()[j]; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<String> vocabWords() { |  | ||||||
|         return vocabWords; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public INDArray transformFromObject(List<List<Object>> tokens) { |  | ||||||
|         NDArrayList ndArrayList = new NDArrayList(); |  | ||||||
|         for(BagOfWordsTransform bagofWordsTransform : transforms) { |  | ||||||
|             INDArray arr2 = bagofWordsTransform.transformFromObject(tokens); |  | ||||||
|             arr2 = arr2.reshape(arr2.length()); |  | ||||||
|             NDArrayList newList = new NDArrayList(arr2,(int) arr2.length()); |  | ||||||
|             ndArrayList.addAll(newList);        } |  | ||||||
| 
 |  | ||||||
|         return ndArrayList.array(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public INDArray transformFrom(List<List<Writable>> tokens) { |  | ||||||
|         NDArrayList ndArrayList = new NDArrayList(); |  | ||||||
|         for(BagOfWordsTransform bagofWordsTransform : transforms) { |  | ||||||
|             INDArray arr2 = bagofWordsTransform.transformFrom(tokens); |  | ||||||
|             arr2 = arr2.reshape(arr2.length()); |  | ||||||
|             NDArrayList newList = new NDArrayList(arr2,(int) arr2.length()); |  | ||||||
|             ndArrayList.addAll(newList); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return ndArrayList.array(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,226 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.transforms; |  | ||||||
| 
 |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.EqualsAndHashCode; |  | ||||||
| import org.datavec.api.transform.metadata.ColumnMetaData; |  | ||||||
| import org.datavec.api.transform.metadata.NDArrayMetaData; |  | ||||||
| import org.datavec.api.transform.schema.Schema; |  | ||||||
| import org.datavec.api.transform.transform.BaseColumnTransform; |  | ||||||
| import org.datavec.api.writable.NDArrayWritable; |  | ||||||
| import org.datavec.api.writable.Text; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| import org.nd4j.common.primitives.Counter; |  | ||||||
| import org.nd4j.common.util.MathUtils; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonCreator; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonInclude; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonProperty; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.Map; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @EqualsAndHashCode(callSuper = true, exclude = {"tokenizerFactory"}) |  | ||||||
| @JsonInclude(JsonInclude.Include.NON_NULL) |  | ||||||
| public class TokenizerBagOfWordsTermSequenceIndexTransform extends BaseColumnTransform { |  | ||||||
| 
 |  | ||||||
|     private String newColumName; |  | ||||||
|     private  Map<String,Integer> wordIndexMap; |  | ||||||
|     private Map<String,Double> weightMap; |  | ||||||
|     private boolean exceptionOnUnknown; |  | ||||||
|     private String tokenizerFactoryClass; |  | ||||||
|     private String preprocessorClass; |  | ||||||
|     private TokenizerFactory tokenizerFactory; |  | ||||||
| 
 |  | ||||||
|     @JsonCreator |  | ||||||
|     public TokenizerBagOfWordsTermSequenceIndexTransform(@JsonProperty("columnName") String columnName, |  | ||||||
|                                                          @JsonProperty("newColumnName") String newColumnName, |  | ||||||
|                                                          @JsonProperty("wordIndexMap") Map<String,Integer> wordIndexMap, |  | ||||||
|                                                          @JsonProperty("idfMap") Map<String,Double> idfMap, |  | ||||||
|                                                          @JsonProperty("exceptionOnUnknown") boolean exceptionOnUnknown, |  | ||||||
|                                                          @JsonProperty("tokenizerFactoryClass") String tokenizerFactoryClass, |  | ||||||
|                                                          @JsonProperty("preprocessorClass") String preprocessorClass) { |  | ||||||
|         super(columnName); |  | ||||||
|         this.newColumName = newColumnName; |  | ||||||
|         this.wordIndexMap = wordIndexMap; |  | ||||||
|         this.exceptionOnUnknown = exceptionOnUnknown; |  | ||||||
|         this.weightMap = idfMap; |  | ||||||
|         this.tokenizerFactoryClass = tokenizerFactoryClass; |  | ||||||
|         this.preprocessorClass = preprocessorClass; |  | ||||||
|         if(this.tokenizerFactoryClass == null) { |  | ||||||
|             this.tokenizerFactoryClass = DefaultTokenizerFactory.class.getName(); |  | ||||||
|         } |  | ||||||
|         try { |  | ||||||
|             tokenizerFactory = (TokenizerFactory) Class.forName(this.tokenizerFactoryClass).newInstance(); |  | ||||||
|         } catch (Exception e) { |  | ||||||
|             throw new IllegalStateException("Unable to instantiate tokenizer factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?"); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         if(preprocessorClass != null){ |  | ||||||
|             try { |  | ||||||
|                 TokenPreProcess tpp = (TokenPreProcess) Class.forName(this.preprocessorClass).newInstance(); |  | ||||||
|                 tokenizerFactory.setTokenPreProcessor(tpp); |  | ||||||
|             } catch (Exception e){ |  | ||||||
|                 throw new IllegalStateException("Unable to instantiate preprocessor factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?"); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public List<Writable> map(List<Writable> writables) { |  | ||||||
|         Text text = (Text) writables.get(inputSchema.getIndexOfColumn(columnName)); |  | ||||||
|         List<Writable> ret = new ArrayList<>(writables); |  | ||||||
|         ret.set(inputSchema.getIndexOfColumn(columnName),new NDArrayWritable(convert(text.toString()))); |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Object map(Object input) { |  | ||||||
|         return convert(input.toString()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Object mapSequence(Object sequence) { |  | ||||||
|         return convert(sequence.toString()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Schema transform(Schema inputSchema) { |  | ||||||
|         Schema.Builder newSchema = new Schema.Builder(); |  | ||||||
|         for(int i = 0; i < inputSchema.numColumns(); i++) { |  | ||||||
|             if(inputSchema.getName(i).equals(this.columnName)) { |  | ||||||
|                 newSchema.addColumnNDArray(newColumName,new long[]{1,wordIndexMap.size()}); |  | ||||||
|             } |  | ||||||
|             else { |  | ||||||
|                 newSchema.addColumn(inputSchema.getMetaData(i)); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return newSchema.build(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Convert the given text |  | ||||||
|      * in to an {@link INDArray} |  | ||||||
|      * using the {@link TokenizerFactory} |  | ||||||
|      * specified in the constructor. |  | ||||||
|      * @param text the text to transform |  | ||||||
|      * @return the created {@link INDArray} |  | ||||||
|      * based on the {@link #wordIndexMap} for the column indices |  | ||||||
|      * of the word. |  | ||||||
|      */ |  | ||||||
|     public INDArray convert(String text) { |  | ||||||
|         Tokenizer tokenizer = tokenizerFactory.create(text); |  | ||||||
|         List<String> tokens = tokenizer.getTokens(); |  | ||||||
|         INDArray create = Nd4j.create(1,wordIndexMap.size()); |  | ||||||
|         Counter<String> tokenizedCounter = new Counter<>(); |  | ||||||
| 
 |  | ||||||
|         for(int i = 0; i < tokens.size(); i++) { |  | ||||||
|             tokenizedCounter.incrementCount(tokens.get(i),1.0); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         for(int i = 0; i < tokens.size(); i++) { |  | ||||||
|             if(wordIndexMap.containsKey(tokens.get(i))) { |  | ||||||
|                 int idx = wordIndexMap.get(tokens.get(i)); |  | ||||||
|                 int count = (int) tokenizedCounter.getCount(tokens.get(i)); |  | ||||||
|                 double weight = tfidfWord(tokens.get(i),count,tokens.size()); |  | ||||||
|                 create.putScalar(idx,weight); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return create; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Calculate the tifdf for a word |  | ||||||
|      * given the word, word count, and document length |  | ||||||
|      * @param word the word to calculate |  | ||||||
|      * @param wordCount the word frequency |  | ||||||
|      * @param documentLength the number of words in the document |  | ||||||
|      * @return the tfidf weight for a given word |  | ||||||
|      */ |  | ||||||
|     public double tfidfWord(String word, long wordCount, long documentLength) { |  | ||||||
|         double tf = tfForWord(wordCount, documentLength); |  | ||||||
|         double idf = idfForWord(word); |  | ||||||
|         return MathUtils.tfidf(tf, idf); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Calculate the weight term frequency for a given |  | ||||||
|      * word normalized by the dcoument length |  | ||||||
|      * @param wordCount the word frequency |  | ||||||
|      * @param documentLength the number of words in the edocument |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     private double tfForWord(long wordCount, long documentLength) { |  | ||||||
|         return wordCount; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private double idfForWord(String word) { |  | ||||||
|         if(weightMap.containsKey(word)) |  | ||||||
|             return weightMap.get(word); |  | ||||||
|         return 0; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) { |  | ||||||
|         return new NDArrayMetaData(outputColumnName(),new long[]{1,wordIndexMap.size()}); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
    /** @return the name of the newly created NDArray column */
    @Override
    public String outputColumnName() {
        return newColumName;
    }
| 
 |  | ||||||
    /** @return the single output column name, wrapped in an array */
    @Override
    public String[] outputColumnNames() {
        return new String[]{newColumName};
    }
| 
 |  | ||||||
    /** @return the single input column name this transform reads, wrapped in an array */
    @Override
    public String[] columnNames() {
        return new String[]{columnName()};
    }
| 
 |  | ||||||
    /** @return the name of the input text column this transform operates on */
    @Override
    public String columnName() {
        return columnName;
    }
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Writable map(Writable columnWritable) { |  | ||||||
|         return new NDArrayWritable(convert(columnWritable.toString())); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,107 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.uima; |  | ||||||
| 
 |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngine; |  | ||||||
| import org.apache.uima.analysis_engine.AnalysisEngineProcessException; |  | ||||||
| import org.apache.uima.cas.CAS; |  | ||||||
| import org.apache.uima.resource.ResourceInitializationException; |  | ||||||
| import org.apache.uima.util.CasPool; |  | ||||||
| 
 |  | ||||||
| public class UimaResource { |  | ||||||
| 
 |  | ||||||
|     private AnalysisEngine analysisEngine; |  | ||||||
|     private CasPool casPool; |  | ||||||
| 
 |  | ||||||
|     public UimaResource(AnalysisEngine analysisEngine) throws ResourceInitializationException { |  | ||||||
|         this.analysisEngine = analysisEngine; |  | ||||||
|         this.casPool = new CasPool(Runtime.getRuntime().availableProcessors() * 10, analysisEngine); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public UimaResource(AnalysisEngine analysisEngine, CasPool casPool) { |  | ||||||
|         this.analysisEngine = analysisEngine; |  | ||||||
|         this.casPool = casPool; |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public AnalysisEngine getAnalysisEngine() { |  | ||||||
|         return analysisEngine; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public void setAnalysisEngine(AnalysisEngine analysisEngine) { |  | ||||||
|         this.analysisEngine = analysisEngine; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public CasPool getCasPool() { |  | ||||||
|         return casPool; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public void setCasPool(CasPool casPool) { |  | ||||||
|         this.casPool = casPool; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Use the given analysis engine and process the given text |  | ||||||
|      * You must release the return cas yourself |  | ||||||
|      * @param text the text to rpocess |  | ||||||
|      * @return the processed cas |  | ||||||
|      */ |  | ||||||
|     public CAS process(String text) { |  | ||||||
|         CAS cas = retrieve(); |  | ||||||
| 
 |  | ||||||
|         cas.setDocumentText(text); |  | ||||||
|         try { |  | ||||||
|             analysisEngine.process(cas); |  | ||||||
|         } catch (AnalysisEngineProcessException e) { |  | ||||||
|             if (text != null && !text.isEmpty()) |  | ||||||
|                 return process(text); |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return cas; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public CAS retrieve() { |  | ||||||
|         CAS ret = casPool.getCas(); |  | ||||||
|         try { |  | ||||||
|             return ret == null ? analysisEngine.newCAS() : ret; |  | ||||||
|         } catch (ResourceInitializationException e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public void release(CAS cas) { |  | ||||||
|         casPool.releaseCas(cas); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,77 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.vectorizer; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.records.Record; |  | ||||||
| import org.datavec.api.records.reader.RecordReader; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; |  | ||||||
| 
 |  | ||||||
| import java.util.HashSet; |  | ||||||
| import java.util.Set; |  | ||||||
| 
 |  | ||||||
| public abstract class AbstractTfidfVectorizer<VECTOR_TYPE> extends TextVectorizer<VECTOR_TYPE> { |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void doWithTokens(Tokenizer tokenizer) { |  | ||||||
|         Set<String> seen = new HashSet<>(); |  | ||||||
|         while (tokenizer.hasMoreTokens()) { |  | ||||||
|             String token = tokenizer.nextToken(); |  | ||||||
|             if (!stopWords.contains(token)) { |  | ||||||
|                 cache.incrementCount(token); |  | ||||||
|                 if (!seen.contains(token)) { |  | ||||||
|                     cache.incrementDocCount(token); |  | ||||||
|                 } |  | ||||||
|                 seen.add(token); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public TokenizerFactory createTokenizerFactory(Configuration conf) { |  | ||||||
|         String clazz = conf.get(TOKENIZER, DefaultTokenizerFactory.class.getName()); |  | ||||||
|         try { |  | ||||||
|             Class<? extends TokenizerFactory> tokenizerFactoryClazz = |  | ||||||
|                             (Class<? extends TokenizerFactory>) Class.forName(clazz); |  | ||||||
|             TokenizerFactory tf = tokenizerFactoryClazz.newInstance(); |  | ||||||
|             String preproc = conf.get(PREPROCESSOR, null); |  | ||||||
|             if(preproc != null){ |  | ||||||
|                 TokenPreProcess tpp = (TokenPreProcess) Class.forName(preproc).newInstance(); |  | ||||||
|                 tf.setTokenPreProcessor(tpp); |  | ||||||
|             } |  | ||||||
|             return tf; |  | ||||||
|         } catch (Exception e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public abstract VECTOR_TYPE createVector(Object[] args); |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public abstract VECTOR_TYPE fitTransform(RecordReader reader); |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public abstract VECTOR_TYPE transform(Record record); |  | ||||||
| } |  | ||||||
| @ -1,121 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.vectorizer; |  | ||||||
| 
 |  | ||||||
| import lombok.Getter; |  | ||||||
| import org.nd4j.common.primitives.Counter; |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.records.Record; |  | ||||||
| import org.datavec.api.records.reader.RecordReader; |  | ||||||
| import org.datavec.api.vector.Vectorizer; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.datavec.nlp.metadata.DefaultVocabCache; |  | ||||||
| import org.datavec.nlp.metadata.VocabCache; |  | ||||||
| import org.datavec.nlp.stopwords.StopWords; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.Tokenizer; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; |  | ||||||
| 
 |  | ||||||
| import java.util.Collection; |  | ||||||
| 
 |  | ||||||
| public abstract class TextVectorizer<VECTOR_TYPE> implements Vectorizer<VECTOR_TYPE> { |  | ||||||
| 
 |  | ||||||
|     protected TokenizerFactory tokenizerFactory; |  | ||||||
|     protected int minWordFrequency = 0; |  | ||||||
|     public final static String MIN_WORD_FREQUENCY = "org.nd4j.nlp.minwordfrequency"; |  | ||||||
|     public final static String STOP_WORDS = "org.nd4j.nlp.stopwords"; |  | ||||||
|     public final static String TOKENIZER = "org.datavec.nlp.tokenizerfactory"; |  | ||||||
|     public static final String PREPROCESSOR = "org.datavec.nlp.preprocessor"; |  | ||||||
|     public final static String VOCAB_CACHE = "org.datavec.nlp.vocabcache"; |  | ||||||
|     protected Collection<String> stopWords; |  | ||||||
|     @Getter |  | ||||||
|     protected VocabCache cache; |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void initialize(Configuration conf) { |  | ||||||
|         tokenizerFactory = createTokenizerFactory(conf); |  | ||||||
|         minWordFrequency = conf.getInt(MIN_WORD_FREQUENCY, 5); |  | ||||||
|         if(conf.get(STOP_WORDS) != null) |  | ||||||
|             stopWords = conf.getStringCollection(STOP_WORDS); |  | ||||||
|         if (stopWords == null) |  | ||||||
|             stopWords = StopWords.getStopWords(); |  | ||||||
| 
 |  | ||||||
|         String clazz = conf.get(VOCAB_CACHE, DefaultVocabCache.class.getName()); |  | ||||||
|         try { |  | ||||||
|             Class<? extends VocabCache> tokenizerFactoryClazz = (Class<? extends VocabCache>) Class.forName(clazz); |  | ||||||
|             cache = tokenizerFactoryClazz.newInstance(); |  | ||||||
|             cache.initialize(conf); |  | ||||||
|         } catch (Exception e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void fit(RecordReader reader) { |  | ||||||
|         fit(reader, null); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void fit(RecordReader reader, RecordCallBack callBack) { |  | ||||||
|         while (reader.hasNext()) { |  | ||||||
|             Record record = reader.nextRecord(); |  | ||||||
|             String s = toString(record.getRecord()); |  | ||||||
|             Tokenizer tokenizer = tokenizerFactory.create(s); |  | ||||||
|             doWithTokens(tokenizer); |  | ||||||
|             if (callBack != null) |  | ||||||
|                 callBack.onRecord(record); |  | ||||||
|             cache.incrementNumDocs(1); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     protected Counter<String> wordFrequenciesForRecord(Collection<Writable> record) { |  | ||||||
|         String s = toString(record); |  | ||||||
|         Tokenizer tokenizer = tokenizerFactory.create(s); |  | ||||||
|         Counter<String> ret = new Counter<>(); |  | ||||||
|         while (tokenizer.hasMoreTokens()) |  | ||||||
|             ret.incrementCount(tokenizer.nextToken(), 1.0); |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     protected String toString(Collection<Writable> record) { |  | ||||||
|         StringBuilder sb = new StringBuilder(); |  | ||||||
|         for(Writable w : record){ |  | ||||||
|             sb.append(w.toString()); |  | ||||||
|         } |  | ||||||
|         return sb.toString(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Increment counts, add to collection,... |  | ||||||
|      * @param tokenizer |  | ||||||
|      */ |  | ||||||
|     public abstract void doWithTokens(Tokenizer tokenizer); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Create tokenizer factory based on the configuration |  | ||||||
|      * @param conf the configuration to use |  | ||||||
|      * @return the tokenizer factory based on the configuration |  | ||||||
|      */ |  | ||||||
|     public abstract TokenizerFactory createTokenizerFactory(Configuration conf); |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,105 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.vectorizer; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.nd4j.common.primitives.Counter; |  | ||||||
| import org.datavec.api.records.Record; |  | ||||||
| import org.datavec.api.records.metadata.RecordMetaDataURI; |  | ||||||
| import org.datavec.api.records.reader.RecordReader; |  | ||||||
| import org.datavec.api.writable.NDArrayWritable; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Arrays; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class TfidfVectorizer extends AbstractTfidfVectorizer<INDArray> { |  | ||||||
|     /** |  | ||||||
|      * Default: True.<br> |  | ||||||
|      * If true: use idf(d, t) = log [ (1 + n) / (1 + df(d, t)) ] + 1<br> |  | ||||||
|      * If false: use idf(t) = log [ n / df(t) ] + 1<br> |  | ||||||
|      */ |  | ||||||
|     public static final String SMOOTH_IDF = "org.datavec.nlp.TfidfVectorizer.smooth_idf"; |  | ||||||
| 
 |  | ||||||
|     protected boolean smooth_idf; |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public INDArray createVector(Object[] args) { |  | ||||||
|         Counter<String> docFrequencies = (Counter<String>) args[0]; |  | ||||||
|         double[] vector = new double[cache.vocabWords().size()]; |  | ||||||
|         for (int i = 0; i < cache.vocabWords().size(); i++) { |  | ||||||
|             String word = cache.wordAt(i); |  | ||||||
|             double freq = docFrequencies.getCount(word); |  | ||||||
|             vector[i] = cache.tfidf(word, freq, smooth_idf); |  | ||||||
|         } |  | ||||||
|         return Nd4j.create(vector); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public INDArray fitTransform(RecordReader reader) { |  | ||||||
|         return fitTransform(reader, null); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public INDArray fitTransform(final RecordReader reader, RecordCallBack callBack) { |  | ||||||
|         final List<Record> records = new ArrayList<>(); |  | ||||||
|         fit(reader, new RecordCallBack() { |  | ||||||
|             @Override |  | ||||||
|             public void onRecord(Record record) { |  | ||||||
|                 records.add(record); |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         if (records.isEmpty()) |  | ||||||
|             throw new IllegalStateException("No records found!"); |  | ||||||
|         INDArray ret = Nd4j.create(records.size(), cache.vocabWords().size()); |  | ||||||
|         int i = 0; |  | ||||||
|         for (Record record : records) { |  | ||||||
|             INDArray transformed = transform(record); |  | ||||||
|             org.datavec.api.records.impl.Record transformedRecord = new org.datavec.api.records.impl.Record( |  | ||||||
|                             Arrays.asList(new NDArrayWritable(transformed), |  | ||||||
|                                             record.getRecord().get(record.getRecord().size() - 1)), |  | ||||||
|                             new RecordMetaDataURI(record.getMetaData().getURI(), reader.getClass())); |  | ||||||
|             ret.putRow(i++, transformed); |  | ||||||
|             if (callBack != null) { |  | ||||||
|                 callBack.onRecord(transformedRecord); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public INDArray transform(Record record) { |  | ||||||
|         Counter<String> wordFrequencies = wordFrequenciesForRecord(record.getRecord()); |  | ||||||
|         return createVector(new Object[] {wordFrequencies}); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void initialize(Configuration conf){ |  | ||||||
|         super.initialize(conf); |  | ||||||
|         this.smooth_idf = conf.getBoolean(SMOOTH_IDF, true); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,194 +0,0 @@ | |||||||
| a |  | ||||||
| ----s |  | ||||||
| act |  | ||||||
| "the |  | ||||||
| "The |  | ||||||
| about |  | ||||||
| above |  | ||||||
| after |  | ||||||
| again |  | ||||||
| against |  | ||||||
| all |  | ||||||
| am |  | ||||||
| an |  | ||||||
| and |  | ||||||
| any |  | ||||||
| are |  | ||||||
| aren't |  | ||||||
| as |  | ||||||
| at |  | ||||||
| be |  | ||||||
| because |  | ||||||
| been |  | ||||||
| before |  | ||||||
| being |  | ||||||
| below |  | ||||||
| between |  | ||||||
| both |  | ||||||
| but |  | ||||||
| by |  | ||||||
| can't |  | ||||||
| cannot |  | ||||||
| could |  | ||||||
| couldn't |  | ||||||
| did |  | ||||||
| didn't |  | ||||||
| do |  | ||||||
| does |  | ||||||
| doesn't |  | ||||||
| doing |  | ||||||
| don't |  | ||||||
| down |  | ||||||
| during |  | ||||||
| each |  | ||||||
| few |  | ||||||
| for |  | ||||||
| from |  | ||||||
| further |  | ||||||
| had |  | ||||||
| hadn't |  | ||||||
| has |  | ||||||
| hasn't |  | ||||||
| have |  | ||||||
| haven't |  | ||||||
| having |  | ||||||
| he |  | ||||||
| he'd |  | ||||||
| he'll |  | ||||||
| he's |  | ||||||
| her |  | ||||||
| here |  | ||||||
| here's |  | ||||||
| hers |  | ||||||
| herself |  | ||||||
| him |  | ||||||
| himself |  | ||||||
| his |  | ||||||
| how |  | ||||||
| how's |  | ||||||
| i |  | ||||||
| i'd |  | ||||||
| i'll |  | ||||||
| i'm |  | ||||||
| i've |  | ||||||
| if |  | ||||||
| in |  | ||||||
| into |  | ||||||
| is |  | ||||||
| isn't |  | ||||||
| it |  | ||||||
| it's |  | ||||||
| its |  | ||||||
| itself |  | ||||||
| let's |  | ||||||
| me |  | ||||||
| more |  | ||||||
| most |  | ||||||
| mustn't |  | ||||||
| my |  | ||||||
| myself |  | ||||||
| no |  | ||||||
| nor |  | ||||||
| not |  | ||||||
| of |  | ||||||
| off |  | ||||||
| on |  | ||||||
| once |  | ||||||
| only |  | ||||||
| or |  | ||||||
| other |  | ||||||
| ought |  | ||||||
| our |  | ||||||
| ours  |  | ||||||
| ourselves |  | ||||||
| out |  | ||||||
| over |  | ||||||
| own |  | ||||||
| put |  | ||||||
| same |  | ||||||
| shan't |  | ||||||
| she |  | ||||||
| she'd |  | ||||||
| she'll |  | ||||||
| she's |  | ||||||
| should |  | ||||||
| somebody |  | ||||||
| something |  | ||||||
| shouldn't |  | ||||||
| so |  | ||||||
| some |  | ||||||
| such |  | ||||||
| take |  | ||||||
| than |  | ||||||
| that |  | ||||||
| that's |  | ||||||
| the |  | ||||||
| their |  | ||||||
| theirs |  | ||||||
| them |  | ||||||
| themselves |  | ||||||
| then |  | ||||||
| there |  | ||||||
| there's |  | ||||||
| these |  | ||||||
| they |  | ||||||
| they'd |  | ||||||
| they'll |  | ||||||
| they're |  | ||||||
| they've |  | ||||||
| this |  | ||||||
| those |  | ||||||
| through |  | ||||||
| to |  | ||||||
| too |  | ||||||
| under |  | ||||||
| until |  | ||||||
| up |  | ||||||
| very |  | ||||||
| was |  | ||||||
| wasn't |  | ||||||
| we |  | ||||||
| we'd |  | ||||||
| we'll |  | ||||||
| we're |  | ||||||
| we've |  | ||||||
| were |  | ||||||
| weren't |  | ||||||
| what |  | ||||||
| what's |  | ||||||
| when |  | ||||||
| when's |  | ||||||
| where |  | ||||||
| where's |  | ||||||
| which |  | ||||||
| while |  | ||||||
| who |  | ||||||
| who's |  | ||||||
| whom |  | ||||||
| why |  | ||||||
| why's |  | ||||||
| will |  | ||||||
| with |  | ||||||
| without |  | ||||||
| won't |  | ||||||
| would |  | ||||||
| wouldn't |  | ||||||
| you |  | ||||||
| you'd |  | ||||||
| you'll |  | ||||||
| you're |  | ||||||
| you've |  | ||||||
| your |  | ||||||
| yours |  | ||||||
| yourself |  | ||||||
| yourselves |  | ||||||
| . |  | ||||||
| ? |  | ||||||
| ! |  | ||||||
| , |  | ||||||
| + |  | ||||||
| = |  | ||||||
| also |  | ||||||
| - |  | ||||||
| ; |  | ||||||
| : |  | ||||||
| @ -1,46 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| package org.datavec.nlp; |  | ||||||
| 
 |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.nd4j.common.tests.AbstractAssertTestsClass; |  | ||||||
| import org.nd4j.common.tests.BaseND4JTest; |  | ||||||
| 
 |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { |  | ||||||
| 
 |  | ||||||
| 	@Override |  | ||||||
| 	protected Set<Class<?>> getExclusions() { |  | ||||||
| 		//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) |  | ||||||
| 		return new HashSet<>(); |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	@Override |  | ||||||
| 	protected String getPackageName() { |  | ||||||
| 		return "org.datavec.nlp"; |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	@Override |  | ||||||
| 	protected Class<?> getBaseClass() { |  | ||||||
| 		return BaseND4JTest.class; |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @ -1,130 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.reader; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.records.Record; |  | ||||||
| import org.datavec.api.records.reader.RecordReader; |  | ||||||
| import org.datavec.api.split.CollectionInputSplit; |  | ||||||
| import org.datavec.api.split.FileSplit; |  | ||||||
| import org.datavec.api.writable.NDArrayWritable; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.datavec.nlp.vectorizer.TfidfVectorizer; |  | ||||||
| import org.junit.Rule; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.junit.rules.TemporaryFolder; |  | ||||||
| import org.nd4j.common.io.ClassPathResource; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.net.URI; |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.*; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  */ |  | ||||||
| public class TfidfRecordReaderTest { |  | ||||||
| 
 |  | ||||||
|     @Rule |  | ||||||
|     public TemporaryFolder testDir = new TemporaryFolder(); |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testReader() throws Exception { |  | ||||||
|         TfidfVectorizer vectorizer = new TfidfVectorizer(); |  | ||||||
|         Configuration conf = new Configuration(); |  | ||||||
|         conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1); |  | ||||||
|         conf.setBoolean(RecordReader.APPEND_LABEL, true); |  | ||||||
|         vectorizer.initialize(conf); |  | ||||||
|         TfidfRecordReader reader = new TfidfRecordReader(); |  | ||||||
|         File f = testDir.newFolder(); |  | ||||||
|         new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f); |  | ||||||
|         List<URI> u = new ArrayList<>(); |  | ||||||
|         for(File f2 : f.listFiles()){ |  | ||||||
|             if(f2.isDirectory()){ |  | ||||||
|                 for(File f3 : f2.listFiles()){ |  | ||||||
|                     u.add(f3.toURI()); |  | ||||||
|                 } |  | ||||||
|             } else { |  | ||||||
|                 u.add(f2.toURI()); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         Collections.sort(u); |  | ||||||
|         CollectionInputSplit c = new CollectionInputSplit(u); |  | ||||||
|         reader.initialize(conf, c); |  | ||||||
|         int count = 0; |  | ||||||
|         int[] labelAssertions = new int[3]; |  | ||||||
|         while (reader.hasNext()) { |  | ||||||
|             Collection<Writable> record = reader.next(); |  | ||||||
|             Iterator<Writable> recordIter = record.iterator(); |  | ||||||
|             NDArrayWritable writable = (NDArrayWritable) recordIter.next(); |  | ||||||
|             labelAssertions[count] = recordIter.next().toInt(); |  | ||||||
|             count++; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         assertArrayEquals(new int[] {0, 1, 2}, labelAssertions); |  | ||||||
|         assertEquals(3, reader.getLabels().size()); |  | ||||||
|         assertEquals(3, count); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testRecordMetaData() throws Exception { |  | ||||||
|         TfidfVectorizer vectorizer = new TfidfVectorizer(); |  | ||||||
|         Configuration conf = new Configuration(); |  | ||||||
|         conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1); |  | ||||||
|         conf.setBoolean(RecordReader.APPEND_LABEL, true); |  | ||||||
|         vectorizer.initialize(conf); |  | ||||||
|         TfidfRecordReader reader = new TfidfRecordReader(); |  | ||||||
|         File f = testDir.newFolder(); |  | ||||||
|         new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f); |  | ||||||
|         reader.initialize(conf, new FileSplit(f)); |  | ||||||
| 
 |  | ||||||
|         while (reader.hasNext()) { |  | ||||||
|             Record record = reader.nextRecord(); |  | ||||||
|             assertNotNull(record.getMetaData().getURI()); |  | ||||||
|             assertEquals(record.getMetaData().getReaderClass(), TfidfRecordReader.class); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testReadRecordFromMetaData() throws Exception { |  | ||||||
|         TfidfVectorizer vectorizer = new TfidfVectorizer(); |  | ||||||
|         Configuration conf = new Configuration(); |  | ||||||
|         conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1); |  | ||||||
|         conf.setBoolean(RecordReader.APPEND_LABEL, true); |  | ||||||
|         vectorizer.initialize(conf); |  | ||||||
|         TfidfRecordReader reader = new TfidfRecordReader(); |  | ||||||
|         File f = testDir.newFolder(); |  | ||||||
|         new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f); |  | ||||||
|         reader.initialize(conf, new FileSplit(f)); |  | ||||||
| 
 |  | ||||||
|         Record record = reader.nextRecord(); |  | ||||||
| 
 |  | ||||||
|         Record reread = reader.loadFromMetaData(record.getMetaData()); |  | ||||||
| 
 |  | ||||||
|         assertEquals(record.getRecord().size(), 2); |  | ||||||
|         assertEquals(reread.getRecord().size(), 2); |  | ||||||
|         assertEquals(record.getRecord().get(0), reread.getRecord().get(0)); |  | ||||||
|         assertEquals(record.getRecord().get(1), reread.getRecord().get(1)); |  | ||||||
|         assertEquals(record.getMetaData(), reread.getMetaData()); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,96 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.transforms; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.transform.TransformProcess; |  | ||||||
| import org.datavec.api.transform.schema.SequenceSchema; |  | ||||||
| import org.datavec.api.writable.NDArrayWritable; |  | ||||||
| import org.datavec.api.writable.Text; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.datavec.local.transforms.LocalTransformExecutor; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Arrays; |  | ||||||
| import java.util.Collections; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| 
 |  | ||||||
| public class TestGazeteerTransform { |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testGazeteerTransform(){ |  | ||||||
| 
 |  | ||||||
|         String[] corpus = { |  | ||||||
|                 "hello I like apple".toLowerCase(), |  | ||||||
|                 "cherry date eggplant potato".toLowerCase() |  | ||||||
|         }; |  | ||||||
| 
 |  | ||||||
|         //Gazeteer transform: basically 0/1 if word is present. Assumes already tokenized input |  | ||||||
|         List<String> words = Arrays.asList("apple", "banana", "cherry", "date", "eggplant"); |  | ||||||
| 
 |  | ||||||
|         GazeteerTransform t = new GazeteerTransform("words", "out", words); |  | ||||||
| 
 |  | ||||||
|         SequenceSchema schema = (SequenceSchema) new SequenceSchema.Builder() |  | ||||||
|                 .addColumnString("words").build(); |  | ||||||
| 
 |  | ||||||
|         TransformProcess tp = new TransformProcess.Builder(schema) |  | ||||||
|                 .transform(t) |  | ||||||
|                 .build(); |  | ||||||
| 
 |  | ||||||
|         List<List<List<Writable>>> input = new ArrayList<>(); |  | ||||||
|         for(String s : corpus){ |  | ||||||
|             String[] split = s.split(" "); |  | ||||||
|             List<List<Writable>> seq = new ArrayList<>(); |  | ||||||
|             for(String s2 : split){ |  | ||||||
|                 seq.add(Collections.<Writable>singletonList(new Text(s2))); |  | ||||||
|             } |  | ||||||
|             input.add(seq); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, tp); |  | ||||||
| 
 |  | ||||||
|         INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); |  | ||||||
|         INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); |  | ||||||
| 
 |  | ||||||
|         INDArray exp0 = Nd4j.create(new float[]{1, 0, 0, 0, 0}); |  | ||||||
|         INDArray exp1 = Nd4j.create(new float[]{0, 0, 1, 1, 1}); |  | ||||||
| 
 |  | ||||||
|         assertEquals(exp0, arr0); |  | ||||||
|         assertEquals(exp1, arr1); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         String json = tp.toJson(); |  | ||||||
|         TransformProcess tp2 = TransformProcess.fromJson(json); |  | ||||||
|         assertEquals(tp, tp2); |  | ||||||
| 
 |  | ||||||
|         List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, tp); |  | ||||||
|         INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get(); |  | ||||||
|         INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get(); |  | ||||||
| 
 |  | ||||||
|         assertEquals(exp0, arr0a); |  | ||||||
|         assertEquals(exp1, arr1a); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,96 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.transforms; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.transform.TransformProcess; |  | ||||||
| import org.datavec.api.transform.schema.SequenceSchema; |  | ||||||
| import org.datavec.api.writable.NDArrayWritable; |  | ||||||
| import org.datavec.api.writable.Text; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.datavec.local.transforms.LocalTransformExecutor; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| 
 |  | ||||||
| public class TestMultiNLPTransform { |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void test(){ |  | ||||||
| 
 |  | ||||||
|         List<String> words = Arrays.asList("apple", "banana", "cherry", "date", "eggplant"); |  | ||||||
|         GazeteerTransform t1 = new GazeteerTransform("words", "out", words); |  | ||||||
|         GazeteerTransform t2 = new GazeteerTransform("out", "out", words); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         MultiNlpTransform multi = new MultiNlpTransform("text", new BagOfWordsTransform[]{t1, t2}, "out"); |  | ||||||
| 
 |  | ||||||
|         String[] corpus = { |  | ||||||
|                 "hello I like apple".toLowerCase(), |  | ||||||
|                 "date eggplant potato".toLowerCase() |  | ||||||
|         }; |  | ||||||
| 
 |  | ||||||
|         List<List<List<Writable>>> input = new ArrayList<>(); |  | ||||||
|         for(String s : corpus){ |  | ||||||
|             String[] split = s.split(" "); |  | ||||||
|             List<List<Writable>> seq = new ArrayList<>(); |  | ||||||
|             for(String s2 : split){ |  | ||||||
|                 seq.add(Collections.<Writable>singletonList(new Text(s2))); |  | ||||||
|             } |  | ||||||
|             input.add(seq); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         SequenceSchema schema = (SequenceSchema) new SequenceSchema.Builder() |  | ||||||
|                 .addColumnString("text").build(); |  | ||||||
| 
 |  | ||||||
|         TransformProcess tp = new TransformProcess.Builder(schema) |  | ||||||
|                 .transform(multi) |  | ||||||
|                 .build(); |  | ||||||
| 
 |  | ||||||
|         List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, tp); |  | ||||||
| 
 |  | ||||||
|         INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); |  | ||||||
|         INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); |  | ||||||
| 
 |  | ||||||
|         INDArray exp0 = Nd4j.create(new float[]{1, 0, 0, 0, 0, 1, 0, 0, 0, 0}); |  | ||||||
|         INDArray exp1 = Nd4j.create(new float[]{0, 0, 0, 1, 1, 0, 0, 0, 1, 1}); |  | ||||||
| 
 |  | ||||||
|         assertEquals(exp0, arr0); |  | ||||||
|         assertEquals(exp1, arr1); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         String json = tp.toJson(); |  | ||||||
|         TransformProcess tp2 = TransformProcess.fromJson(json); |  | ||||||
|         assertEquals(tp, tp2); |  | ||||||
| 
 |  | ||||||
|         List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, tp); |  | ||||||
|         INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get(); |  | ||||||
|         INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get(); |  | ||||||
| 
 |  | ||||||
|         assertEquals(exp0, arr0a); |  | ||||||
|         assertEquals(exp1, arr1a); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,414 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.nlp.transforms; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.conf.Configuration; |  | ||||||
| import org.datavec.api.records.reader.impl.collection.CollectionRecordReader; |  | ||||||
| import org.datavec.api.transform.TransformProcess; |  | ||||||
| import org.datavec.api.transform.schema.SequenceSchema; |  | ||||||
| import org.datavec.api.writable.NDArrayWritable; |  | ||||||
| import org.datavec.api.writable.Text; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.datavec.local.transforms.LocalTransformExecutor; |  | ||||||
| import org.datavec.nlp.metadata.VocabCache; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizer.preprocessor.LowerCasePreProcessor; |  | ||||||
| import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory; |  | ||||||
| import org.datavec.nlp.vectorizer.TfidfVectorizer; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.linalg.api.buffer.DataType; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| import org.nd4j.common.primitives.Triple; |  | ||||||
| 
 |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| import static org.datavec.nlp.vectorizer.TextVectorizer.*; |  | ||||||
| import static org.junit.Assert.assertArrayEquals; |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| 
 |  | ||||||
| public class TokenizerBagOfWordsTermSequenceIndexTransformTest { |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testSequenceExecution() { |  | ||||||
|         //credit: https://stackoverflow.com/questions/23792781/tf-idf-feature-weights-using-sklearn-feature-extraction-text-tfidfvectorizer |  | ||||||
|         String[] corpus = { |  | ||||||
|                 "This is very strange".toLowerCase(), |  | ||||||
|                 "This is very nice".toLowerCase() |  | ||||||
|         }; |  | ||||||
|         //{'is': 1.0, 'nice': 1.4054651081081644, 'strange': 1.4054651081081644, 'this': 1.0, 'very': 1.0} |  | ||||||
| 
 |  | ||||||
|         /* |  | ||||||
|          ## Reproduce with: |  | ||||||
|          from sklearn.feature_extraction.text import TfidfVectorizer |  | ||||||
|          corpus = ["This is very strange", "This is very nice"] |  | ||||||
| 
 |  | ||||||
|          ## SMOOTH = FALSE case: |  | ||||||
|          vectorizer = TfidfVectorizer(min_df=0, norm=None, smooth_idf=False) |  | ||||||
|          X = vectorizer.fit_transform(corpus) |  | ||||||
|          idf = vectorizer.idf_ |  | ||||||
|          print(dict(zip(vectorizer.get_feature_names(), idf))) |  | ||||||
| 
 |  | ||||||
|          newText = ["This is very strange", "This is very nice"] |  | ||||||
|          out = vectorizer.transform(newText) |  | ||||||
|          print(out) |  | ||||||
| 
 |  | ||||||
|          {'is': 1.0, 'nice': 1.6931471805599454, 'strange': 1.6931471805599454, 'this': 1.0, 'very': 1.0} |  | ||||||
|          (0, 4)	1.0 |  | ||||||
|          (0, 3)	1.0 |  | ||||||
|          (0, 2)	1.6931471805599454 |  | ||||||
|          (0, 0)	1.0 |  | ||||||
|          (1, 4)	1.0 |  | ||||||
|          (1, 3)	1.0 |  | ||||||
|          (1, 1)	1.6931471805599454 |  | ||||||
|          (1, 0)	1.0 |  | ||||||
| 
 |  | ||||||
|          ## SMOOTH + TRUE case: |  | ||||||
|          {'is': 1.0, 'nice': 1.4054651081081644, 'strange': 1.4054651081081644, 'this': 1.0, 'very': 1.0} |  | ||||||
|           (0, 4)	1.0 |  | ||||||
|           (0, 3)	1.0 |  | ||||||
|           (0, 2)	1.4054651081081644 |  | ||||||
|           (0, 0)	1.0 |  | ||||||
|           (1, 4)	1.0 |  | ||||||
|           (1, 3)	1.0 |  | ||||||
|           (1, 1)	1.4054651081081644 |  | ||||||
|           (1, 0)	1.0 |  | ||||||
|          */ |  | ||||||
| 
 |  | ||||||
|         List<List<List<Writable>>> input = new ArrayList<>(); |  | ||||||
|         input.add(Arrays.asList(Arrays.<Writable>asList(new Text(corpus[0])),Arrays.<Writable>asList(new Text(corpus[1])))); |  | ||||||
| 
 |  | ||||||
|         // First: Check TfidfVectorizer vs. scikit: |  | ||||||
| 
 |  | ||||||
|         Map<String,Double> idfMapNoSmooth = new HashMap<>(); |  | ||||||
|         idfMapNoSmooth.put("is",1.0); |  | ||||||
|         idfMapNoSmooth.put("nice",1.6931471805599454); |  | ||||||
|         idfMapNoSmooth.put("strange",1.6931471805599454); |  | ||||||
|         idfMapNoSmooth.put("this",1.0); |  | ||||||
|         idfMapNoSmooth.put("very",1.0); |  | ||||||
| 
 |  | ||||||
|         Map<String,Double> idfMapSmooth = new HashMap<>(); |  | ||||||
|         idfMapSmooth.put("is",1.0); |  | ||||||
|         idfMapSmooth.put("nice",1.4054651081081644); |  | ||||||
|         idfMapSmooth.put("strange",1.4054651081081644); |  | ||||||
|         idfMapSmooth.put("this",1.0); |  | ||||||
|         idfMapSmooth.put("very",1.0); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         TfidfVectorizer tfidfVectorizer = new TfidfVectorizer(); |  | ||||||
|         Configuration configuration = new Configuration(); |  | ||||||
|         configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName()); |  | ||||||
|         configuration.set(MIN_WORD_FREQUENCY,"1"); |  | ||||||
|         configuration.set(STOP_WORDS,""); |  | ||||||
|         configuration.set(TfidfVectorizer.SMOOTH_IDF, "false"); |  | ||||||
| 
 |  | ||||||
|         tfidfVectorizer.initialize(configuration); |  | ||||||
| 
 |  | ||||||
|         CollectionRecordReader collectionRecordReader = new CollectionRecordReader(input.get(0)); |  | ||||||
|         INDArray array = tfidfVectorizer.fitTransform(collectionRecordReader); |  | ||||||
| 
 |  | ||||||
|         INDArray expNoSmooth = Nd4j.create(DataType.FLOAT, 2, 5); |  | ||||||
|         VocabCache vc = tfidfVectorizer.getCache(); |  | ||||||
|         expNoSmooth.putScalar(0, vc.wordIndex("very"), 1.0); |  | ||||||
|         expNoSmooth.putScalar(0, vc.wordIndex("this"), 1.0); |  | ||||||
|         expNoSmooth.putScalar(0, vc.wordIndex("strange"), 1.6931471805599454); |  | ||||||
|         expNoSmooth.putScalar(0, vc.wordIndex("is"), 1.0); |  | ||||||
| 
 |  | ||||||
|         expNoSmooth.putScalar(1, vc.wordIndex("very"), 1.0); |  | ||||||
|         expNoSmooth.putScalar(1, vc.wordIndex("this"), 1.0); |  | ||||||
|         expNoSmooth.putScalar(1, vc.wordIndex("nice"), 1.6931471805599454); |  | ||||||
|         expNoSmooth.putScalar(1, vc.wordIndex("is"), 1.0); |  | ||||||
| 
 |  | ||||||
|         assertEquals(expNoSmooth, array); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         //------------------------------------------------------------ |  | ||||||
|         //Smooth version: |  | ||||||
|         tfidfVectorizer = new TfidfVectorizer(); |  | ||||||
|         configuration = new Configuration(); |  | ||||||
|         configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName()); |  | ||||||
|         configuration.set(MIN_WORD_FREQUENCY,"1"); |  | ||||||
|         configuration.set(STOP_WORDS,""); |  | ||||||
|         configuration.set(TfidfVectorizer.SMOOTH_IDF, "true"); |  | ||||||
| 
 |  | ||||||
|         tfidfVectorizer.initialize(configuration); |  | ||||||
| 
 |  | ||||||
|         collectionRecordReader.reset(); |  | ||||||
|         array = tfidfVectorizer.fitTransform(collectionRecordReader); |  | ||||||
| 
 |  | ||||||
|         INDArray expSmooth = Nd4j.create(DataType.FLOAT, 2, 5); |  | ||||||
|         expSmooth.putScalar(0, vc.wordIndex("very"), 1.0); |  | ||||||
|         expSmooth.putScalar(0, vc.wordIndex("this"), 1.0); |  | ||||||
|         expSmooth.putScalar(0, vc.wordIndex("strange"), 1.4054651081081644); |  | ||||||
|         expSmooth.putScalar(0, vc.wordIndex("is"), 1.0); |  | ||||||
| 
 |  | ||||||
|         expSmooth.putScalar(1, vc.wordIndex("very"), 1.0); |  | ||||||
|         expSmooth.putScalar(1, vc.wordIndex("this"), 1.0); |  | ||||||
|         expSmooth.putScalar(1, vc.wordIndex("nice"), 1.4054651081081644); |  | ||||||
|         expSmooth.putScalar(1, vc.wordIndex("is"), 1.0); |  | ||||||
| 
 |  | ||||||
|         assertEquals(expSmooth, array); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         ////////////////////////////////////////////////////////// |  | ||||||
| 
 |  | ||||||
|         //Second: Check transform vs scikit/TfidfVectorizer |  | ||||||
| 
 |  | ||||||
|         List<String> vocab = new ArrayList<>(5);    //Arrays.asList("is","nice","strange","this","very"); |  | ||||||
|         for( int i=0; i<5; i++ ){ |  | ||||||
|             vocab.add(vc.wordAt(i)); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         String inputColumnName = "input"; |  | ||||||
|         String outputColumnName = "output"; |  | ||||||
|         Map<String,Integer> wordIndexMap = new HashMap<>(); |  | ||||||
|         for(int i = 0; i < vocab.size(); i++) { |  | ||||||
|             wordIndexMap.put(vocab.get(i),i); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         TokenizerBagOfWordsTermSequenceIndexTransform tokenizerBagOfWordsTermSequenceIndexTransform = new TokenizerBagOfWordsTermSequenceIndexTransform( |  | ||||||
|                 inputColumnName, |  | ||||||
|                 outputColumnName, |  | ||||||
|                 wordIndexMap, |  | ||||||
|                 idfMapNoSmooth, |  | ||||||
|                 false, |  | ||||||
|                 null, null); |  | ||||||
| 
 |  | ||||||
|         SequenceSchema.Builder sequenceSchemaBuilder = new SequenceSchema.Builder(); |  | ||||||
|         sequenceSchemaBuilder.addColumnString("input"); |  | ||||||
|         SequenceSchema schema = sequenceSchemaBuilder.build(); |  | ||||||
|         assertEquals("input",schema.getName(0)); |  | ||||||
| 
 |  | ||||||
|         TransformProcess transformProcess = new TransformProcess.Builder(schema) |  | ||||||
|                 .transform(tokenizerBagOfWordsTermSequenceIndexTransform) |  | ||||||
|                 .build(); |  | ||||||
| 
 |  | ||||||
|         List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         //System.out.println(execute); |  | ||||||
|         INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); |  | ||||||
|         INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); |  | ||||||
| 
 |  | ||||||
|         assertEquals(expNoSmooth.getRow(0, true), arr0); |  | ||||||
|         assertEquals(expNoSmooth.getRow(1, true), arr1); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         //-------------------------------- |  | ||||||
|         //Check smooth: |  | ||||||
| 
 |  | ||||||
|         tokenizerBagOfWordsTermSequenceIndexTransform = new TokenizerBagOfWordsTermSequenceIndexTransform( |  | ||||||
|                 inputColumnName, |  | ||||||
|                 outputColumnName, |  | ||||||
|                 wordIndexMap, |  | ||||||
|                 idfMapSmooth, |  | ||||||
|                 false, |  | ||||||
|                 null, null); |  | ||||||
| 
 |  | ||||||
|         schema = (SequenceSchema) new SequenceSchema.Builder().addColumnString("input").build(); |  | ||||||
| 
 |  | ||||||
|         transformProcess = new TransformProcess.Builder(schema) |  | ||||||
|                 .transform(tokenizerBagOfWordsTermSequenceIndexTransform) |  | ||||||
|                 .build(); |  | ||||||
| 
 |  | ||||||
|         execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess); |  | ||||||
| 
 |  | ||||||
|         arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); |  | ||||||
|         arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); |  | ||||||
| 
 |  | ||||||
|         assertEquals(expSmooth.getRow(0, true), arr0); |  | ||||||
|         assertEquals(expSmooth.getRow(1, true), arr1); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         //Test JSON serialization: |  | ||||||
| 
 |  | ||||||
|         String json = transformProcess.toJson(); |  | ||||||
|         TransformProcess fromJson = TransformProcess.fromJson(json); |  | ||||||
|         assertEquals(transformProcess, fromJson); |  | ||||||
|         List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, fromJson); |  | ||||||
| 
 |  | ||||||
|         INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get(); |  | ||||||
|         INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get(); |  | ||||||
| 
 |  | ||||||
|         assertEquals(expSmooth.getRow(0, true), arr0a); |  | ||||||
|         assertEquals(expSmooth.getRow(1, true), arr1a); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void additionalTest(){ |  | ||||||
|         /* |  | ||||||
|         ## To reproduce: |  | ||||||
|         from sklearn.feature_extraction.text import TfidfVectorizer |  | ||||||
|         corpus = [ |  | ||||||
|            'This is the first document', |  | ||||||
|            'This document is the second document', |  | ||||||
|            'And this is the third one', |  | ||||||
|            'Is this the first document', |  | ||||||
|         ] |  | ||||||
|         vectorizer = TfidfVectorizer(min_df=0, norm=None, smooth_idf=False) |  | ||||||
|         X = vectorizer.fit_transform(corpus) |  | ||||||
|         print(vectorizer.get_feature_names()) |  | ||||||
| 
 |  | ||||||
|         out = vectorizer.transform(corpus) |  | ||||||
|         print(out) |  | ||||||
| 
 |  | ||||||
|         ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this'] |  | ||||||
|           (0, 8)	1.0 |  | ||||||
|           (0, 6)	1.0 |  | ||||||
|           (0, 3)	1.0 |  | ||||||
|           (0, 2)	1.6931471805599454 |  | ||||||
|           (0, 1)	1.2876820724517808 |  | ||||||
|           (1, 8)	1.0 |  | ||||||
|           (1, 6)	1.0 |  | ||||||
|           (1, 5)	2.386294361119891 |  | ||||||
|           (1, 3)	1.0 |  | ||||||
|           (1, 1)	2.5753641449035616 |  | ||||||
|           (2, 8)	1.0 |  | ||||||
|           (2, 7)	2.386294361119891 |  | ||||||
|           (2, 6)	1.0 |  | ||||||
|           (2, 4)	2.386294361119891 |  | ||||||
|           (2, 3)	1.0 |  | ||||||
|           (2, 0)	2.386294361119891 |  | ||||||
|           (3, 8)	1.0 |  | ||||||
|           (3, 6)	1.0 |  | ||||||
|           (3, 3)	1.0 |  | ||||||
|           (3, 2)	1.6931471805599454 |  | ||||||
|           (3, 1)	1.2876820724517808 |  | ||||||
|           {'and': 2.386294361119891, 'document': 1.2876820724517808, 'first': 1.6931471805599454, 'is': 1.0, 'one': 2.386294361119891, 'second': 2.386294361119891, 'the': 1.0, 'third': 2.386294361119891, 'this': 1.0} |  | ||||||
|          */ |  | ||||||
| 
 |  | ||||||
|         String[] corpus = { |  | ||||||
|                 "This is the first document", |  | ||||||
|                 "This document is the second document", |  | ||||||
|                 "And this is the third one", |  | ||||||
|                 "Is this the first document"}; |  | ||||||
| 
 |  | ||||||
|         TfidfVectorizer tfidfVectorizer = new TfidfVectorizer(); |  | ||||||
|         Configuration configuration = new Configuration(); |  | ||||||
|         configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName()); |  | ||||||
|         configuration.set(MIN_WORD_FREQUENCY,"1"); |  | ||||||
|         configuration.set(STOP_WORDS,""); |  | ||||||
|         configuration.set(TfidfVectorizer.SMOOTH_IDF, "false"); |  | ||||||
|         configuration.set(PREPROCESSOR, LowerCasePreProcessor.class.getName()); |  | ||||||
| 
 |  | ||||||
|         tfidfVectorizer.initialize(configuration); |  | ||||||
| 
 |  | ||||||
|         List<List<List<Writable>>> input = new ArrayList<>(); |  | ||||||
|         //input.add(Arrays.asList(Arrays.<Writable>asList(new Text(corpus[0])),Arrays.<Writable>asList(new Text(corpus[1])))); |  | ||||||
|         List<List<Writable>> seq = new ArrayList<>(); |  | ||||||
|         for(String s : corpus){ |  | ||||||
|             seq.add(Collections.<Writable>singletonList(new Text(s))); |  | ||||||
|         } |  | ||||||
|         input.add(seq); |  | ||||||
| 
 |  | ||||||
|         CollectionRecordReader crr = new CollectionRecordReader(seq); |  | ||||||
|         INDArray arr = tfidfVectorizer.fitTransform(crr); |  | ||||||
| 
 |  | ||||||
|         //System.out.println(arr); |  | ||||||
|         assertArrayEquals(new long[]{4, 9}, arr.shape()); |  | ||||||
| 
 |  | ||||||
|         List<String> pyVocab = Arrays.asList("and", "document", "first", "is", "one", "second", "the", "third", "this"); |  | ||||||
|         List<Triple<Integer,Integer,Double>> l = new ArrayList<>(); |  | ||||||
|         l.add(new Triple<>(0, 8, 1.0)); |  | ||||||
|         l.add(new Triple<>(0, 6, 1.0)); |  | ||||||
|         l.add(new Triple<>(0, 3, 1.0)); |  | ||||||
|         l.add(new Triple<>(0, 2, 1.6931471805599454)); |  | ||||||
|         l.add(new Triple<>(0, 1, 1.2876820724517808)); |  | ||||||
|         l.add(new Triple<>(1, 8, 1.0)); |  | ||||||
|         l.add(new Triple<>(1, 6, 1.0)); |  | ||||||
|         l.add(new Triple<>(1, 5, 2.386294361119891)); |  | ||||||
|         l.add(new Triple<>(1, 3, 1.0)); |  | ||||||
|         l.add(new Triple<>(1, 1, 2.5753641449035616)); |  | ||||||
|         l.add(new Triple<>(2, 8, 1.0)); |  | ||||||
|         l.add(new Triple<>(2, 7, 2.386294361119891)); |  | ||||||
|         l.add(new Triple<>(2, 6, 1.0)); |  | ||||||
|         l.add(new Triple<>(2, 4, 2.386294361119891)); |  | ||||||
|         l.add(new Triple<>(2, 3, 1.0)); |  | ||||||
|         l.add(new Triple<>(2, 0, 2.386294361119891)); |  | ||||||
|         l.add(new Triple<>(3, 8, 1.0)); |  | ||||||
|         l.add(new Triple<>(3, 6, 1.0)); |  | ||||||
|         l.add(new Triple<>(3, 3, 1.0)); |  | ||||||
|         l.add(new Triple<>(3, 2, 1.6931471805599454)); |  | ||||||
|         l.add(new Triple<>(3, 1, 1.2876820724517808)); |  | ||||||
| 
 |  | ||||||
|         INDArray exp = Nd4j.create(DataType.FLOAT, 4, 9); |  | ||||||
|         for(Triple<Integer,Integer,Double> t : l){ |  | ||||||
|             //Work out work index, accounting for different vocab/word orders: |  | ||||||
|             int wIdx = tfidfVectorizer.getCache().wordIndex(pyVocab.get(t.getSecond())); |  | ||||||
|             exp.putScalar(t.getFirst(), wIdx, t.getThird()); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         assertEquals(exp, arr); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         Map<String,Double> idfWeights = new HashMap<>(); |  | ||||||
|         idfWeights.put("and", 2.386294361119891); |  | ||||||
|         idfWeights.put("document", 1.2876820724517808); |  | ||||||
|         idfWeights.put("first", 1.6931471805599454); |  | ||||||
|         idfWeights.put("is", 1.0); |  | ||||||
|         idfWeights.put("one", 2.386294361119891); |  | ||||||
|         idfWeights.put("second", 2.386294361119891); |  | ||||||
|         idfWeights.put("the", 1.0); |  | ||||||
|         idfWeights.put("third", 2.386294361119891); |  | ||||||
|         idfWeights.put("this", 1.0); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         List<String> vocab = new ArrayList<>(9);    //Arrays.asList("is","nice","strange","this","very"); |  | ||||||
|         for( int i=0; i<9; i++ ){ |  | ||||||
|             vocab.add(tfidfVectorizer.getCache().wordAt(i)); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         String inputColumnName = "input"; |  | ||||||
|         String outputColumnName = "output"; |  | ||||||
|         Map<String,Integer> wordIndexMap = new HashMap<>(); |  | ||||||
|         for(int i = 0; i < vocab.size(); i++) { |  | ||||||
|             wordIndexMap.put(vocab.get(i),i); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         TokenizerBagOfWordsTermSequenceIndexTransform transform = new TokenizerBagOfWordsTermSequenceIndexTransform( |  | ||||||
|                 inputColumnName, |  | ||||||
|                 outputColumnName, |  | ||||||
|                 wordIndexMap, |  | ||||||
|                 idfWeights, |  | ||||||
|                 false, |  | ||||||
|                 null, LowerCasePreProcessor.class.getName()); |  | ||||||
| 
 |  | ||||||
|         SequenceSchema.Builder sequenceSchemaBuilder = new SequenceSchema.Builder(); |  | ||||||
|         sequenceSchemaBuilder.addColumnString("input"); |  | ||||||
|         SequenceSchema schema = sequenceSchemaBuilder.build(); |  | ||||||
|         assertEquals("input",schema.getName(0)); |  | ||||||
| 
 |  | ||||||
|         TransformProcess transformProcess = new TransformProcess.Builder(schema) |  | ||||||
|                 .transform(transform) |  | ||||||
|                 .build(); |  | ||||||
| 
 |  | ||||||
|         List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess); |  | ||||||
| 
 |  | ||||||
|         INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); |  | ||||||
|         INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); |  | ||||||
| 
 |  | ||||||
|         assertEquals(exp.getRow(0, true), arr0); |  | ||||||
|         assertEquals(exp.getRow(1, true), arr1); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,53 +0,0 @@ | |||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <configuration> |  | ||||||
| 
 |  | ||||||
|     <appender name="FILE" class="ch.qos.logback.core.FileAppender"> |  | ||||||
|         <file>logs/application.log</file> |  | ||||||
|         <encoder> |  | ||||||
|             <pattern>%date - [%level] - from %logger in %thread |  | ||||||
|                 %n%message%n%xException%n</pattern> |  | ||||||
|         </encoder> |  | ||||||
|     </appender> |  | ||||||
| 
 |  | ||||||
|     <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender"> |  | ||||||
|         <encoder> |  | ||||||
|             <pattern> %logger{15} - %message%n%xException{5} |  | ||||||
|             </pattern> |  | ||||||
|         </encoder> |  | ||||||
|     </appender> |  | ||||||
| 
 |  | ||||||
|     <logger name="org.apache.catalina.core" level="DEBUG" /> |  | ||||||
|     <logger name="org.springframework" level="DEBUG" /> |  | ||||||
|     <logger name="org.datavec" level="DEBUG" /> |  | ||||||
|     <logger name="org.nd4j" level="INFO" /> |  | ||||||
|     <logger name="opennlp.uima.util" level="OFF" /> |  | ||||||
|     <logger name="org.apache.uima" level="OFF" /> |  | ||||||
|     <logger name="org.cleartk" level="OFF" /> |  | ||||||
|     <logger name="org.apache.spark" level="WARN" /> |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     <root level="ERROR"> |  | ||||||
|         <appender-ref ref="STDOUT" /> |  | ||||||
|         <appender-ref ref="FILE" /> |  | ||||||
|     </root> |  | ||||||
| 
 |  | ||||||
| </configuration> |  | ||||||
| @ -1,56 +0,0 @@ | |||||||
| <?xml version="1.0" encoding="UTF-8"?> |  | ||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" |  | ||||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |  | ||||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> |  | ||||||
| 
 |  | ||||||
|     <modelVersion>4.0.0</modelVersion> |  | ||||||
| 
 |  | ||||||
|     <parent> |  | ||||||
|         <groupId>org.datavec</groupId> |  | ||||||
|         <artifactId>datavec-data</artifactId> |  | ||||||
|         <version>1.0.0-SNAPSHOT</version> |  | ||||||
|     </parent> |  | ||||||
| 
 |  | ||||||
|     <artifactId>datavec-geo</artifactId> |  | ||||||
| 
 |  | ||||||
|     <dependencies> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.datavec</groupId> |  | ||||||
|             <artifactId>datavec-api</artifactId> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.maxmind.geoip2</groupId> |  | ||||||
|             <artifactId>geoip2</artifactId> |  | ||||||
|             <version>${geoip2.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|     </dependencies> |  | ||||||
| 
 |  | ||||||
|     <profiles> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-native</id> |  | ||||||
|         </profile> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-cuda-11.0</id> |  | ||||||
|         </profile> |  | ||||||
|     </profiles> |  | ||||||
| </project> |  | ||||||
| @ -1,25 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.api.transform.geo; |  | ||||||
| 
 |  | ||||||
/**
 * The kind of location information to look up for an IP address in the GeoIP database.
 * The {@code *_ID} variants presumably select the numeric identifier form of the
 * corresponding named value — TODO confirm against the transforms that consume this enum.
 * NOTE: do not reorder these constants; their ordinals may be relied upon by serialized data.
 */
public enum LocationType {
    CITY, CITY_ID, CONTINENT, CONTINENT_ID, COUNTRY, COUNTRY_ID, COORDINATES, POSTAL_CODE, SUBDIVISIONS, SUBDIVISIONS_ID
}
| @ -1,194 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.api.transform.reduce.geo; |  | ||||||
| 
 |  | ||||||
| import lombok.Getter; |  | ||||||
| import org.datavec.api.transform.ReduceOp; |  | ||||||
| import org.datavec.api.transform.metadata.ColumnMetaData; |  | ||||||
| import org.datavec.api.transform.metadata.StringMetaData; |  | ||||||
| import org.datavec.api.transform.ops.IAggregableReduceOp; |  | ||||||
| import org.datavec.api.transform.reduce.AggregableColumnReduction; |  | ||||||
| import org.datavec.api.transform.reduce.AggregableReductionUtils; |  | ||||||
| import org.datavec.api.transform.schema.Schema; |  | ||||||
| import org.datavec.api.writable.DoubleWritable; |  | ||||||
| import org.datavec.api.writable.Text; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.nd4j.common.function.Supplier; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Collections; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
/**
 * Column reduction that applies one or more {@link ReduceOp}s independently to each coordinate
 * of a delimiter-separated coordinate string column (e.g. {@code "12.3:45.6"}). Each input value
 * is split on the delimiter, each component is parsed as a double and fed to a per-coordinate
 * reduction; the results are re-joined with the same delimiter, one output column per op.
 * <p>
 * Schema-related methods ({@code transform}, {@code setInputSchema}, {@code getInputSchema},
 * {@code outputColumnName(s)}, {@code columnName(s)}) are unsupported: this class is used only
 * through {@link #reduceOp()} and the output-name/metadata methods.
 */
public class CoordinatesReduction implements AggregableColumnReduction {
    // Default output column name; not referenced within this class (kept for external callers).
    public static final String DEFAULT_COLUMN_NAME = "CoordinatesReduction";

    // Separator between coordinate components inside a single string value.
    public final static String DEFAULT_DELIMITER = ":";
    protected String delimiter = DEFAULT_DELIMITER;

    // Output column names returned by getColumnsOutputName(), one per ReduceOp.
    private final List<String> columnNamesPostReduce;

    /**
     * Supplier producing a fresh multi-op double-column reducer for one coordinate position.
     * A Supplier (rather than a single instance) is needed because accept() lazily creates one
     * independent reducer per coordinate position encountered in the data.
     */
    private final Supplier<IAggregableReduceOp<Writable, List<Writable>>> multiOp(final List<ReduceOp> ops) {
        return new Supplier<IAggregableReduceOp<Writable, List<Writable>>>() {
            @Override
            public IAggregableReduceOp<Writable, List<Writable>> get() {
                return AggregableReductionUtils.reduceDoubleColumn(ops, false, null);
            }
        };
    }

    /**
     * @param columnNamePostReduce name of the single output column
     * @param op                   reduction to apply to each coordinate
     */
    public CoordinatesReduction(String columnNamePostReduce, ReduceOp op) {
        this(columnNamePostReduce, op, DEFAULT_DELIMITER);
    }

    /**
     * @param columnNamePostReduce output column names, one per reduction op
     * @param op                   reductions to apply to each coordinate
     */
    public CoordinatesReduction(List<String> columnNamePostReduce, List<ReduceOp> op) {
        this(columnNamePostReduce, op, DEFAULT_DELIMITER);
    }

    /**
     * @param columnNamePostReduce name of the single output column
     * @param op                   reduction to apply to each coordinate
     * @param delimiter            separator between coordinate components
     */
    public CoordinatesReduction(String columnNamePostReduce, ReduceOp op, String delimiter) {
        this(Collections.singletonList(columnNamePostReduce), Collections.singletonList(op), delimiter);
    }

    /**
     * Primary constructor: all other constructors delegate here.
     *
     * @param columnNamesPostReduce output column names, one per reduction op
     * @param ops                   reductions to apply to each coordinate
     * @param delimiter             separator between coordinate components
     */
    public CoordinatesReduction(List<String> columnNamesPostReduce, List<ReduceOp> ops, String delimiter) {
        this.columnNamesPostReduce = columnNamesPostReduce;
        this.reducer = new CoordinateAggregableReduceOp(ops.size(), multiOp(ops), delimiter);
    }

    @Override
    public List<String> getColumnsOutputName(String columnInputName) {
        // Output names are fixed at construction time; the input column name is ignored.
        return columnNamesPostReduce;
    }

    @Override
    public List<ColumnMetaData> getColumnOutputMetaData(List<String> newColumnName, ColumnMetaData columnInputMeta) {
        // Output is always a string column (the delimiter-joined reduced coordinates).
        List<ColumnMetaData> res = new ArrayList<>(newColumnName.size());
        for (String cn : newColumnName)
            res.add(new StringMetaData((cn)));
        return res;
    }

    /** Not supported: this reduction does not operate at the schema level. */
    @Override
    public Schema transform(Schema inputSchema) {
        throw new UnsupportedOperationException();
    }

    /** Not supported: this reduction does not operate at the schema level. */
    @Override
    public void setInputSchema(Schema inputSchema) {
        throw new UnsupportedOperationException();
    }

    /** Not supported: this reduction does not operate at the schema level. */
    @Override
    public Schema getInputSchema() {
        throw new UnsupportedOperationException();
    }

    /** Not supported: use {@link #getColumnsOutputName(String)} instead. */
    @Override
    public String outputColumnName() {
        throw new UnsupportedOperationException();
    }

    /** Not supported: use {@link #getColumnsOutputName(String)} instead. */
    @Override
    public String[] outputColumnNames() {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public String[] columnNames() {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public String columnName() {
        throw new UnsupportedOperationException();
    }

    // The stateful per-coordinate reduce op built in the primary constructor.
    private IAggregableReduceOp<Writable, List<Writable>> reducer;

    @Override
    public IAggregableReduceOp<Writable, List<Writable>> reduceOp() {
        return reducer;
    }


    /**
     * Aggregable reduce op that splits each incoming writable on the delimiter and routes the
     * i-th component (parsed as a double) into an independently-allocated reducer for coordinate
     * position i. Per-coordinate reducers are created lazily as wider inputs are seen, so inputs
     * with differing numbers of coordinates are tolerated.
     */
    public static class CoordinateAggregableReduceOp implements IAggregableReduceOp<Writable, List<Writable>> {


        // Number of reduction ops applied per coordinate (i.e. number of output writables).
        private int nOps;
        // Factory for a fresh per-coordinate reducer; invoked once per coordinate position.
        private Supplier<IAggregableReduceOp<Writable, List<Writable>>> initialOpValue;
        @Getter
        private ArrayList<IAggregableReduceOp<Writable, List<Writable>>> perCoordinateOps; // of size coords()
        // Separator used both to split inputs and to join the final per-op outputs.
        private String delimiter;

        public CoordinateAggregableReduceOp(int n, Supplier<IAggregableReduceOp<Writable, List<Writable>>> initialOp,
                        String delim) {
            this.nOps = n;
            this.perCoordinateOps = new ArrayList<>();
            this.initialOpValue = initialOp;
            this.delimiter = delim;
        }

        @Override
        public <W extends IAggregableReduceOp<Writable, List<Writable>>> void combine(W accu) {
            // Merge only with accumulators of the same shape; coordinate positions beyond the
            // shorter list are left untouched. Silently ignores other accumulator types.
            if (accu instanceof CoordinateAggregableReduceOp) {
                CoordinateAggregableReduceOp accumulator = (CoordinateAggregableReduceOp) accu;
                for (int i = 0; i < Math.min(perCoordinateOps.size(), accumulator.getPerCoordinateOps().size()); i++) {
                    perCoordinateOps.get(i).combine(accumulator.getPerCoordinateOps().get(i));
                } // the rest is assumed identical
            }
        }

        @Override
        public void accept(Writable writable) {
            // Split "x:y:z" into components; grow the per-coordinate reducer list on demand.
            String[] coordinates = writable.toString().split(delimiter);
            for (int i = 0; i < coordinates.length; i++) {
                String coordinate = coordinates[i];
                while (perCoordinateOps.size() < i + 1) {
                    perCoordinateOps.add(initialOpValue.get());
                }
                // NumberFormatException propagates if a component is not a valid double.
                perCoordinateOps.get(i).accept(new DoubleWritable(Double.parseDouble(coordinate)));
            }
        }

        @Override
        public List<Writable> get() {
            // One string builder per op; each collects that op's result across all coordinates,
            // joined by the delimiter (e.g. op "min" over 2D coords yields "minX:minY").
            List<StringBuilder> res = new ArrayList<>(nOps);
            for (int i = 0; i < nOps; i++) {
                res.add(new StringBuilder());
            }

            for (int i = 0; i < perCoordinateOps.size(); i++) {
                List<Writable> resThisCoord = perCoordinateOps.get(i).get();
                for (int j = 0; j < nOps; j++) {
                    res.get(j).append(resThisCoord.get(j).toString());
                    if (i < perCoordinateOps.size() - 1) {
                        res.get(j).append(delimiter);
                    }
                }
            }

            // Wrap each joined string as a Text writable, one per op.
            List<Writable> finalRes = new ArrayList<>(nOps);
            for (StringBuilder sb : res) {
                finalRes.add(new Text(sb.toString()));
            }
            return finalRes;
        }
    }

}
| @ -1,117 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.api.transform.transform.geo; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.transform.MathOp; |  | ||||||
| import org.datavec.api.transform.metadata.ColumnMetaData; |  | ||||||
| import org.datavec.api.transform.metadata.DoubleMetaData; |  | ||||||
| import org.datavec.api.transform.schema.Schema; |  | ||||||
| import org.datavec.api.transform.transform.BaseColumnsMathOpTransform; |  | ||||||
| import org.datavec.api.writable.DoubleWritable; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonProperty; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Arrays; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class CoordinatesDistanceTransform extends BaseColumnsMathOpTransform { |  | ||||||
| 
 |  | ||||||
|     public final static String DEFAULT_DELIMITER = ":"; |  | ||||||
|     protected String delimiter = DEFAULT_DELIMITER; |  | ||||||
| 
 |  | ||||||
|     public CoordinatesDistanceTransform(String newColumnName, String firstColumn, String secondColumn, |  | ||||||
|                     String stdevColumn) { |  | ||||||
|         this(newColumnName, firstColumn, secondColumn, stdevColumn, DEFAULT_DELIMITER); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public CoordinatesDistanceTransform(@JsonProperty("newColumnName") String newColumnName, |  | ||||||
|                     @JsonProperty("firstColumn") String firstColumn, @JsonProperty("secondColumn") String secondColumn, |  | ||||||
|                     @JsonProperty("stdevColumn") String stdevColumn, @JsonProperty("delimiter") String delimiter) { |  | ||||||
|         super(newColumnName, MathOp.Add /* dummy op */, |  | ||||||
|                         stdevColumn != null ? new String[] {firstColumn, secondColumn, stdevColumn} |  | ||||||
|                                         : new String[] {firstColumn, secondColumn}); |  | ||||||
|         this.delimiter = delimiter; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected ColumnMetaData derivedColumnMetaData(String newColumnName, Schema inputSchema) { |  | ||||||
|         return new DoubleMetaData(newColumnName); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected Writable doOp(Writable... input) { |  | ||||||
|         String[] first = input[0].toString().split(delimiter); |  | ||||||
|         String[] second = input[1].toString().split(delimiter); |  | ||||||
|         String[] stdev = columns.length > 2 ? input[2].toString().split(delimiter) : null; |  | ||||||
| 
 |  | ||||||
|         double dist = 0; |  | ||||||
|         for (int i = 0; i < first.length; i++) { |  | ||||||
|             double d = Double.parseDouble(first[i]) - Double.parseDouble(second[i]); |  | ||||||
|             double s = stdev != null ? Double.parseDouble(stdev[i]) : 1; |  | ||||||
|             dist += (d * d) / (s * s); |  | ||||||
|         } |  | ||||||
|         return new DoubleWritable(Math.sqrt(dist)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String toString() { |  | ||||||
|         return "CoordinatesDistanceTransform(newColumnName=\"" + newColumnName + "\",columns=" |  | ||||||
|                         + Arrays.toString(columns) + ",delimiter=" + delimiter + ")"; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Transform an object |  | ||||||
|      * in to another object |  | ||||||
|      * |  | ||||||
|      * @param input the record to transform |  | ||||||
|      * @return the transformed writable |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public Object map(Object input) { |  | ||||||
|         List row = (List) input; |  | ||||||
|         String[] first = row.get(0).toString().split(delimiter); |  | ||||||
|         String[] second = row.get(1).toString().split(delimiter); |  | ||||||
|         String[] stdev = columns.length > 2 ? row.get(2).toString().split(delimiter) : null; |  | ||||||
| 
 |  | ||||||
|         double dist = 0; |  | ||||||
|         for (int i = 0; i < first.length; i++) { |  | ||||||
|             double d = Double.parseDouble(first[i]) - Double.parseDouble(second[i]); |  | ||||||
|             double s = stdev != null ? Double.parseDouble(stdev[i]) : 1; |  | ||||||
|             dist += (d * d) / (s * s); |  | ||||||
|         } |  | ||||||
|         return Math.sqrt(dist); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Transform a sequence |  | ||||||
|      * |  | ||||||
|      * @param sequence |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public Object mapSequence(Object sequence) { |  | ||||||
|         List<List> seq = (List<List>) sequence; |  | ||||||
|         List<Double> ret = new ArrayList<>(); |  | ||||||
|         for (Object step : seq) |  | ||||||
|             ret.add((Double) map(step)); |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,70 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.api.transform.transform.geo; |  | ||||||
| 
 |  | ||||||
| import org.apache.commons.io.FileUtils; |  | ||||||
| import org.nd4j.common.base.Preconditions; |  | ||||||
| import org.nd4j.common.util.ArchiveUtils; |  | ||||||
| import org.slf4j.Logger; |  | ||||||
| import org.slf4j.LoggerFactory; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.net.URL; |  | ||||||
| 
 |  | ||||||
/**
 * Locates a MaxMind GeoIP2/GeoLite2 city database on the local machine, downloading and
 * extracting the free GeoLite2 city database into {@code ~/.datavec-geoip} if none is found.
 */
public class GeoIPFetcher {
    protected static final Logger log = LoggerFactory.getLogger(GeoIPFetcher.class);

    /** Default directory for <a href="http://dev.maxmind.com/geoip/geoipupdate/">http://dev.maxmind.com/geoip/geoipupdate/</a> */
    public static final String GEOIP_DIR = "/usr/local/share/GeoIP/";
    // Fallback directory in the user's home, used when no system-wide database is installed.
    public static final String GEOIP_DIR2 = System.getProperty("user.home") + "/.datavec-geoip";

    // Commercial (paid) city database file name.
    public static final String CITY_DB = "GeoIP2-City.mmdb";
    // Free GeoLite2 city database file name.
    public static final String CITY_LITE_DB = "GeoLite2-City.mmdb";

    // NOTE(review): this is MaxMind's old public download endpoint; MaxMind now requires a
    // license key for GeoLite2 downloads, so this URL may no longer work — verify.
    public static final String CITY_LITE_URL =
                    "http://geolite.maxmind.com/download/geoip/database/GeoLite2-City.mmdb.gz";

    /**
     * Returns a local city database file, checking (in order) the system-wide commercial DB,
     * the system-wide GeoLite2 DB, and the per-user GeoLite2 DB; if none exists, downloads and
     * extracts the GeoLite2 database into {@link #GEOIP_DIR2}.
     * <p>
     * Synchronized so concurrent callers do not download/extract the archive simultaneously.
     *
     * @return the existing or freshly downloaded {@code .mmdb} file
     * @throws IOException if the download or extraction fails
     * @throws IllegalStateException if extraction completes but the expected file is absent
     */
    public static synchronized File fetchCityDB() throws IOException {
        File cityFile = new File(GEOIP_DIR, CITY_DB);
        if (cityFile.isFile()) {
            return cityFile;
        }
        cityFile = new File(GEOIP_DIR, CITY_LITE_DB);
        if (cityFile.isFile()) {
            return cityFile;
        }
        cityFile = new File(GEOIP_DIR2, CITY_LITE_DB);
        if (cityFile.isFile()) {
            return cityFile;
        }

        log.info("Downloading GeoLite2 City database...");
        File archive = new File(GEOIP_DIR2, CITY_LITE_DB + ".gz");
        File dir = new File(GEOIP_DIR2);
        dir.mkdirs();
        FileUtils.copyURLToFile(new URL(CITY_LITE_URL), archive);
        // Extract the .gz archive next to it; cityFile still points at GEOIP_DIR2/CITY_LITE_DB.
        ArchiveUtils.unzipFileTo(archive.getAbsolutePath(), dir.getAbsolutePath());
        Preconditions.checkState(cityFile.isFile(), "Error extracting files: expected city file does not exist after extraction");

        return cityFile;
    }
}
| @ -1,43 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.api.transform.transform.geo; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.transform.geo.LocationType; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonProperty; |  | ||||||
| 
 |  | ||||||
| import java.io.IOException; |  | ||||||
| 
 |  | ||||||
| public class IPAddressToCoordinatesTransform extends IPAddressToLocationTransform { |  | ||||||
| 
 |  | ||||||
|     public IPAddressToCoordinatesTransform(@JsonProperty("columnName") String columnName) throws IOException { |  | ||||||
|         this(columnName, DEFAULT_DELIMITER); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public IPAddressToCoordinatesTransform(@JsonProperty("columnName") String columnName, |  | ||||||
|                     @JsonProperty("delimiter") String delimiter) throws IOException { |  | ||||||
|         super(columnName, LocationType.COORDINATES, delimiter); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String toString() { |  | ||||||
|         return "IPAddressToCoordinatesTransform"; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,184 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.api.transform.transform.geo; |  | ||||||
| 
 |  | ||||||
| import com.maxmind.geoip2.DatabaseReader; |  | ||||||
| import com.maxmind.geoip2.exception.GeoIp2Exception; |  | ||||||
| import com.maxmind.geoip2.model.CityResponse; |  | ||||||
| import com.maxmind.geoip2.record.Location; |  | ||||||
| import com.maxmind.geoip2.record.Subdivision; |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.datavec.api.transform.geo.LocationType; |  | ||||||
| import org.datavec.api.transform.metadata.ColumnMetaData; |  | ||||||
| import org.datavec.api.transform.metadata.StringMetaData; |  | ||||||
| import org.datavec.api.transform.transform.BaseColumnTransform; |  | ||||||
| import org.datavec.api.writable.Text; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.nd4j.shade.jackson.annotation.JsonProperty; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.io.ObjectInputStream; |  | ||||||
| import java.io.ObjectOutputStream; |  | ||||||
| import java.net.InetAddress; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class IPAddressToLocationTransform extends BaseColumnTransform { |  | ||||||
|     /** |  | ||||||
|      * Name of the system property to use when configuring the GeoIP database file.<br> |  | ||||||
|      * Most users don't need to set this - typically used for testing purposes.<br> |  | ||||||
|      * Set with the full local path, like: "C:/datavec-geo/GeoIP2-City-Test.mmdb" |  | ||||||
|      */ |  | ||||||
|     public static final String GEOIP_FILE_PROPERTY = "org.datavec.geoip.file"; |  | ||||||
| 
 |  | ||||||
|     private static File database; |  | ||||||
|     private static DatabaseReader reader; |  | ||||||
| 
 |  | ||||||
|     public final static String DEFAULT_DELIMITER = ":"; |  | ||||||
|     protected String delimiter = DEFAULT_DELIMITER; |  | ||||||
|     protected LocationType locationType; |  | ||||||
| 
 |  | ||||||
|     private static synchronized void init() throws IOException { |  | ||||||
|         // A File object pointing to your GeoIP2 or GeoLite2 database: |  | ||||||
|         // http://dev.maxmind.com/geoip/geoip2/geolite2/ |  | ||||||
|         if (database == null) { |  | ||||||
|             String s = System.getProperty(GEOIP_FILE_PROPERTY); |  | ||||||
|             if(s != null && !s.isEmpty()){ |  | ||||||
|                 //Use user-specified GEOIP file - mainly for testing purposes |  | ||||||
|                 File f = new File(s); |  | ||||||
|                 if(f.exists() && f.isFile()){ |  | ||||||
|                     database = f; |  | ||||||
|                 } else { |  | ||||||
|                     log.warn("GeoIP file (system property {}) is set to \"{}\" but this is not a valid file, using default database", GEOIP_FILE_PROPERTY, s); |  | ||||||
|                     database = GeoIPFetcher.fetchCityDB(); |  | ||||||
|                 } |  | ||||||
|             } else { |  | ||||||
|                 database = GeoIPFetcher.fetchCityDB(); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         // This creates the DatabaseReader object, which should be reused across lookups. |  | ||||||
|         if (reader == null) { |  | ||||||
|             reader = new DatabaseReader.Builder(database).build(); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public IPAddressToLocationTransform(String columnName) throws IOException { |  | ||||||
|         this(columnName, LocationType.CITY); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public IPAddressToLocationTransform(String columnName, LocationType locationType) throws IOException { |  | ||||||
|         this(columnName, locationType, DEFAULT_DELIMITER); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public IPAddressToLocationTransform(@JsonProperty("columnName") String columnName, |  | ||||||
|                     @JsonProperty("delimiter") LocationType locationType, @JsonProperty("delimiter") String delimiter) |  | ||||||
|                     throws IOException { |  | ||||||
|         super(columnName); |  | ||||||
|         this.delimiter = delimiter; |  | ||||||
|         this.locationType = locationType; |  | ||||||
|         init(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) { |  | ||||||
|         return new StringMetaData(newName); //Output after transform: String (Text) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Writable map(Writable columnWritable) { |  | ||||||
|         try { |  | ||||||
|             InetAddress ipAddress = InetAddress.getByName(columnWritable.toString()); |  | ||||||
|             CityResponse response = reader.city(ipAddress); |  | ||||||
|             String text = ""; |  | ||||||
|             switch (locationType) { |  | ||||||
|                 case CITY: |  | ||||||
|                     text = response.getCity().getName(); |  | ||||||
|                     break; |  | ||||||
|                 case CITY_ID: |  | ||||||
|                     text = response.getCity().getGeoNameId().toString(); |  | ||||||
|                     break; |  | ||||||
|                 case CONTINENT: |  | ||||||
|                     text = response.getContinent().getName(); |  | ||||||
|                     break; |  | ||||||
|                 case CONTINENT_ID: |  | ||||||
|                     text = response.getContinent().getGeoNameId().toString(); |  | ||||||
|                     break; |  | ||||||
|                 case COUNTRY: |  | ||||||
|                     text = response.getCountry().getName(); |  | ||||||
|                     break; |  | ||||||
|                 case COUNTRY_ID: |  | ||||||
|                     text = response.getCountry().getGeoNameId().toString(); |  | ||||||
|                     break; |  | ||||||
|                 case COORDINATES: |  | ||||||
|                     Location location = response.getLocation(); |  | ||||||
|                     text = location.getLatitude() + delimiter + location.getLongitude(); |  | ||||||
|                     break; |  | ||||||
|                 case POSTAL_CODE: |  | ||||||
|                     text = response.getPostal().getCode(); |  | ||||||
|                     break; |  | ||||||
|                 case SUBDIVISIONS: |  | ||||||
|                     for (Subdivision s : response.getSubdivisions()) { |  | ||||||
|                         if (text.length() > 0) { |  | ||||||
|                             text += delimiter; |  | ||||||
|                         } |  | ||||||
|                         text += s.getName(); |  | ||||||
|                     } |  | ||||||
|                     break; |  | ||||||
|                 case SUBDIVISIONS_ID: |  | ||||||
|                     for (Subdivision s : response.getSubdivisions()) { |  | ||||||
|                         if (text.length() > 0) { |  | ||||||
|                             text += delimiter; |  | ||||||
|                         } |  | ||||||
|                         text += s.getGeoNameId().toString(); |  | ||||||
|                     } |  | ||||||
|                     break; |  | ||||||
|                 default: |  | ||||||
|                     assert false; |  | ||||||
|             } |  | ||||||
|             if(text == null) |  | ||||||
|                 text = ""; |  | ||||||
|             return new Text(text); |  | ||||||
|         } catch (GeoIp2Exception | IOException e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String toString() { |  | ||||||
|         return "IPAddressToLocationTransform"; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     //Custom serialization methods, because GeoIP2 doesn't allow DatabaseReader objects to be serialized :( |  | ||||||
|     private void writeObject(ObjectOutputStream out) throws IOException { |  | ||||||
|         out.defaultWriteObject(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { |  | ||||||
|         in.defaultReadObject(); |  | ||||||
|         init(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Object map(Object input) { |  | ||||||
|         return null; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,46 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| package org.datavec.api.transform; |  | ||||||
| 
 |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.nd4j.common.tests.AbstractAssertTestsClass; |  | ||||||
| import org.nd4j.common.tests.BaseND4JTest; |  | ||||||
| 
 |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected Set<Class<?>> getExclusions() { |  | ||||||
| 	    //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) |  | ||||||
| 	    return new HashSet<>(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
| 	protected String getPackageName() { |  | ||||||
|     	return "org.datavec.api.transform"; |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	@Override |  | ||||||
| 	protected Class<?> getBaseClass() { |  | ||||||
|     	return BaseND4JTest.class; |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @ -1,80 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.api.transform.reduce; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.transform.ColumnType; |  | ||||||
| import org.datavec.api.transform.ReduceOp; |  | ||||||
| import org.datavec.api.transform.ops.IAggregableReduceOp; |  | ||||||
| import org.datavec.api.transform.reduce.geo.CoordinatesReduction; |  | ||||||
| import org.datavec.api.transform.schema.Schema; |  | ||||||
| import org.datavec.api.writable.Text; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.junit.Test; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Arrays; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * @author saudet |  | ||||||
|  */ |  | ||||||
| public class TestGeoReduction { |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testCustomReductions() { |  | ||||||
| 
 |  | ||||||
|         List<List<Writable>> inputs = new ArrayList<>(); |  | ||||||
|         inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("1#5"))); |  | ||||||
|         inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("2#6"))); |  | ||||||
|         inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("3#7"))); |  | ||||||
|         inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("4#8"))); |  | ||||||
| 
 |  | ||||||
|         List<Writable> expected = Arrays.asList((Writable) new Text("someKey"), new Text("10.0#26.0")); |  | ||||||
| 
 |  | ||||||
|         Schema schema = new Schema.Builder().addColumnString("key").addColumnString("coord").build(); |  | ||||||
| 
 |  | ||||||
|         Reducer reducer = new Reducer.Builder(ReduceOp.Count).keyColumns("key") |  | ||||||
|                         .customReduction("coord", new CoordinatesReduction("coordSum", ReduceOp.Sum, "#")).build(); |  | ||||||
| 
 |  | ||||||
|         reducer.setInputSchema(schema); |  | ||||||
| 
 |  | ||||||
|         IAggregableReduceOp<List<Writable>, List<Writable>> aggregableReduceOp = reducer.aggregableReducer(); |  | ||||||
|         for (List<Writable> l : inputs) |  | ||||||
|             aggregableReduceOp.accept(l); |  | ||||||
|         List<Writable> out = aggregableReduceOp.get(); |  | ||||||
| 
 |  | ||||||
|         assertEquals(2, out.size()); |  | ||||||
|         assertEquals(expected, out); |  | ||||||
| 
 |  | ||||||
|         //Check schema: |  | ||||||
|         String[] expNames = new String[] {"key", "coordSum"}; |  | ||||||
|         ColumnType[] expTypes = new ColumnType[] {ColumnType.String, ColumnType.String}; |  | ||||||
|         Schema outSchema = reducer.transform(schema); |  | ||||||
| 
 |  | ||||||
|         assertEquals(2, outSchema.numColumns()); |  | ||||||
|         for (int i = 0; i < 2; i++) { |  | ||||||
|             assertEquals(expNames[i], outSchema.getName(i)); |  | ||||||
|             assertEquals(expTypes[i], outSchema.getType(i)); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,153 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.api.transform.transform; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.transform.ColumnType; |  | ||||||
| import org.datavec.api.transform.Transform; |  | ||||||
| import org.datavec.api.transform.geo.LocationType; |  | ||||||
| import org.datavec.api.transform.schema.Schema; |  | ||||||
| import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform; |  | ||||||
| import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform; |  | ||||||
| import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform; |  | ||||||
| import org.datavec.api.writable.DoubleWritable; |  | ||||||
| import org.datavec.api.writable.Text; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.junit.AfterClass; |  | ||||||
| import org.junit.BeforeClass; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.common.io.ClassPathResource; |  | ||||||
| 
 |  | ||||||
| import java.io.*; |  | ||||||
| import java.util.Arrays; |  | ||||||
| import java.util.Collections; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * @author saudet |  | ||||||
|  */ |  | ||||||
| public class TestGeoTransforms { |  | ||||||
| 
 |  | ||||||
|     @BeforeClass |  | ||||||
|     public static void beforeClass() throws Exception { |  | ||||||
|         //Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing |  | ||||||
|         File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile(); |  | ||||||
|         System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @AfterClass |  | ||||||
|     public static void afterClass(){ |  | ||||||
|         System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, ""); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testCoordinatesDistanceTransform() throws Exception { |  | ||||||
|         Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev") |  | ||||||
|                         .build(); |  | ||||||
| 
 |  | ||||||
|         Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|"); |  | ||||||
|         transform.setInputSchema(schema); |  | ||||||
| 
 |  | ||||||
|         Schema out = transform.transform(schema); |  | ||||||
|         assertEquals(4, out.numColumns()); |  | ||||||
|         assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames()); |  | ||||||
|         assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double), |  | ||||||
|                         out.getColumnTypes()); |  | ||||||
| 
 |  | ||||||
|         assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)), |  | ||||||
|                         transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10")))); |  | ||||||
|         assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"), |  | ||||||
|                         new DoubleWritable(Math.sqrt(160))), |  | ||||||
|                         transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), |  | ||||||
|                                         new Text("10|5")))); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testIPAddressToCoordinatesTransform() throws Exception { |  | ||||||
|         Schema schema = new Schema.Builder().addColumnString("column").build(); |  | ||||||
| 
 |  | ||||||
|         Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER"); |  | ||||||
|         transform.setInputSchema(schema); |  | ||||||
| 
 |  | ||||||
|         Schema out = transform.transform(schema); |  | ||||||
| 
 |  | ||||||
|         assertEquals(1, out.getColumnMetaData().size()); |  | ||||||
|         assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); |  | ||||||
| 
 |  | ||||||
|         String in = "81.2.69.160"; |  | ||||||
|         double latitude = 51.5142; |  | ||||||
|         double longitude = -0.0931; |  | ||||||
| 
 |  | ||||||
|         List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in))); |  | ||||||
|         assertEquals(1, writables.size()); |  | ||||||
|         String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER"); |  | ||||||
|         assertEquals(2, coordinates.length); |  | ||||||
|         assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1); |  | ||||||
|         assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1); |  | ||||||
| 
 |  | ||||||
|         //Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/ |  | ||||||
|         ByteArrayOutputStream baos = new ByteArrayOutputStream(); |  | ||||||
|         ObjectOutputStream oos = new ObjectOutputStream(baos); |  | ||||||
|         oos.writeObject(transform); |  | ||||||
| 
 |  | ||||||
|         byte[] bytes = baos.toByteArray(); |  | ||||||
| 
 |  | ||||||
|         ByteArrayInputStream bais = new ByteArrayInputStream(bytes); |  | ||||||
|         ObjectInputStream ois = new ObjectInputStream(bais); |  | ||||||
| 
 |  | ||||||
|         Transform deserialized = (Transform) ois.readObject(); |  | ||||||
|         writables = deserialized.map(Collections.singletonList((Writable) new Text(in))); |  | ||||||
|         assertEquals(1, writables.size()); |  | ||||||
|         coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER"); |  | ||||||
|         //System.out.println(Arrays.toString(coordinates)); |  | ||||||
|         assertEquals(2, coordinates.length); |  | ||||||
|         assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1); |  | ||||||
|         assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testIPAddressToLocationTransform() throws Exception { |  | ||||||
|         Schema schema = new Schema.Builder().addColumnString("column").build(); |  | ||||||
|         LocationType[] locationTypes = LocationType.values(); |  | ||||||
|         String in = "81.2.69.160"; |  | ||||||
|         String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167", |  | ||||||
|                         "51.5142:-0.0931", "", "England", "6269131"};    //Note: no postcode in this test DB for this record |  | ||||||
| 
 |  | ||||||
|         for (int i = 0; i < locationTypes.length; i++) { |  | ||||||
|             LocationType locationType = locationTypes[i]; |  | ||||||
|             String location = locations[i]; |  | ||||||
| 
 |  | ||||||
|             Transform transform = new IPAddressToLocationTransform("column", locationType); |  | ||||||
|             transform.setInputSchema(schema); |  | ||||||
| 
 |  | ||||||
|             Schema out = transform.transform(schema); |  | ||||||
| 
 |  | ||||||
|             assertEquals(1, out.getColumnMetaData().size()); |  | ||||||
|             assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); |  | ||||||
| 
 |  | ||||||
|             List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in))); |  | ||||||
|             assertEquals(1, writables.size()); |  | ||||||
|             assertEquals(location, writables.get(0).toString()); |  | ||||||
|             //System.out.println(location); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -37,11 +37,7 @@ | |||||||
|     <name>datavec-data</name> |     <name>datavec-data</name> | ||||||
| 
 | 
 | ||||||
|     <modules> |     <modules> | ||||||
|         <module>datavec-data-audio</module> |  | ||||||
|         <module>datavec-data-codec</module> |  | ||||||
|         <module>datavec-data-image</module> |         <module>datavec-data-image</module> | ||||||
|         <module>datavec-data-nlp</module> |  | ||||||
|         <module>datavec-geo</module> |  | ||||||
|     </modules> |     </modules> | ||||||
| 
 | 
 | ||||||
|     <dependencyManagement> |     <dependencyManagement> | ||||||
|  | |||||||
| @ -1,64 +0,0 @@ | |||||||
| <?xml version="1.0" encoding="UTF-8"?> |  | ||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-spark-inference-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-spark-inference-client</artifactId>

    <name>datavec-spark-inference-client</name>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark-inference-server_2.11</artifactId>
            <!-- Use the parent-version property rather than a hard-coded version,
                 consistent with the datavec-spark-inference-model dependency below.
                 Resolves to the same 1.0.0-SNAPSHOT value today, but tracks the
                 parent automatically on release. -->
            <version>${project.parent.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark-inference-model</artifactId>
            <version>${project.parent.version}</version>
        </dependency>
        <dependency>
            <groupId>com.mashape.unirest</groupId>
            <artifactId>unirest-java</artifactId>
        </dependency>
    </dependencies>

    <profiles>
        <!-- ND4J backend selection profiles; their configuration is inherited
             from the parent POM. -->
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
| @ -1,292 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.inference.client; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import com.mashape.unirest.http.ObjectMapper; |  | ||||||
| import com.mashape.unirest.http.Unirest; |  | ||||||
| import com.mashape.unirest.http.exceptions.UnirestException; |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.datavec.api.transform.TransformProcess; |  | ||||||
| import org.datavec.image.transform.ImageTransformProcess; |  | ||||||
| import org.datavec.spark.inference.model.model.*; |  | ||||||
| import org.datavec.spark.inference.model.service.DataVecTransformService; |  | ||||||
| import org.nd4j.shade.jackson.core.JsonProcessingException; |  | ||||||
| 
 |  | ||||||
| import java.io.IOException; |  | ||||||
| 
 |  | ||||||
| @AllArgsConstructor |  | ||||||
| @Slf4j |  | ||||||
| public class DataVecTransformClient implements DataVecTransformService { |  | ||||||
|     private String url; |  | ||||||
| 
 |  | ||||||
|     static { |  | ||||||
|         // Only one time |  | ||||||
|         Unirest.setObjectMapper(new ObjectMapper() { |  | ||||||
|             private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = |  | ||||||
|                     new org.nd4j.shade.jackson.databind.ObjectMapper(); |  | ||||||
| 
 |  | ||||||
|             public <T> T readValue(String value, Class<T> valueType) { |  | ||||||
|                 try { |  | ||||||
|                     return jacksonObjectMapper.readValue(value, valueType); |  | ||||||
|                 } catch (IOException e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             public String writeValue(Object value) { |  | ||||||
|                 try { |  | ||||||
|                     return jacksonObjectMapper.writeValueAsString(value); |  | ||||||
|                 } catch (JsonProcessingException e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param transformProcess |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public void setCSVTransformProcess(TransformProcess transformProcess) { |  | ||||||
|         try { |  | ||||||
|             String s = transformProcess.toJson(); |  | ||||||
|             Unirest.post(url + "/transformprocess").header("accept", "application/json") |  | ||||||
|                     .header("Content-Type", "application/json").body(s).asJson(); |  | ||||||
| 
 |  | ||||||
|         } catch (UnirestException e) { |  | ||||||
|             log.error("Error in setCSVTransformProcess()", e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public TransformProcess getCSVTransformProcess() { |  | ||||||
|         try { |  | ||||||
|             String s = Unirest.get(url + "/transformprocess").header("accept", "application/json") |  | ||||||
|                     .header("Content-Type", "application/json").asString().getBody(); |  | ||||||
|             return TransformProcess.fromJson(s); |  | ||||||
|         } catch (UnirestException e) { |  | ||||||
|             log.error("Error in getCSVTransformProcess()",e); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return null; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public ImageTransformProcess getImageTransformProcess() { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param transform |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public SingleCSVRecord transformIncremental(SingleCSVRecord transform) { |  | ||||||
|         try { |  | ||||||
|             SingleCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental") |  | ||||||
|                     .header("accept", "application/json") |  | ||||||
|                     .header("Content-Type", "application/json") |  | ||||||
|                     .body(transform).asObject(SingleCSVRecord.class).getBody(); |  | ||||||
|             return singleCsvRecord; |  | ||||||
|         } catch (UnirestException e) { |  | ||||||
|             log.error("Error in transformIncremental(SingleCSVRecord)",e); |  | ||||||
|         } |  | ||||||
|         return null; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param batchCSVRecord |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) { |  | ||||||
|         try { |  | ||||||
|             SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json") |  | ||||||
|                     .header("Content-Type", "application/json") |  | ||||||
|                     .header(SEQUENCE_OR_NOT_HEADER,"TRUE") |  | ||||||
|                     .body(batchCSVRecord) |  | ||||||
|                     .asObject(SequenceBatchCSVRecord.class) |  | ||||||
|                     .getBody(); |  | ||||||
|             return batchCSVRecord1; |  | ||||||
|         } catch (UnirestException e) { |  | ||||||
|             log.error("",e); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return null; |  | ||||||
|     } |  | ||||||
|     /** |  | ||||||
|      * @param batchCSVRecord |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) { |  | ||||||
|         try { |  | ||||||
|             BatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json") |  | ||||||
|                     .header("Content-Type", "application/json") |  | ||||||
|                     .header(SEQUENCE_OR_NOT_HEADER,"FALSE") |  | ||||||
|                     .body(batchCSVRecord) |  | ||||||
|                     .asObject(BatchCSVRecord.class) |  | ||||||
|                     .getBody(); |  | ||||||
|             return batchCSVRecord1; |  | ||||||
|         } catch (UnirestException e) { |  | ||||||
|             log.error("Error in transform(BatchCSVRecord)", e); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return null; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param batchCSVRecord |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) { |  | ||||||
|         try { |  | ||||||
|             Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json") |  | ||||||
|                     .header("Content-Type", "application/json").body(batchCSVRecord) |  | ||||||
|                     .asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
|             return batchArray1; |  | ||||||
|         } catch (UnirestException e) { |  | ||||||
|             log.error("Error in transformArray(BatchCSVRecord)",e); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return null; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param singleCsvRecord |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) { |  | ||||||
|         try { |  | ||||||
|             Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray") |  | ||||||
|                     .header("accept", "application/json").header("Content-Type", "application/json") |  | ||||||
|                     .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
|             return array; |  | ||||||
|         } catch (UnirestException e) { |  | ||||||
|             log.error("Error in transformArrayIncremental(SingleCSVRecord)",e); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return null; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param singleCsvRecord |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { |  | ||||||
|         try { |  | ||||||
|             Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray") |  | ||||||
|                     .header("accept", "application/json") |  | ||||||
|                     .header("Content-Type", "application/json") |  | ||||||
|                     .header(SEQUENCE_OR_NOT_HEADER,"true") |  | ||||||
|                     .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
|             return array; |  | ||||||
|         } catch (UnirestException e) { |  | ||||||
|             log.error("Error in transformSequenceArrayIncremental",e); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return null; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param batchCSVRecord |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) { |  | ||||||
|         try { |  | ||||||
|             Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json") |  | ||||||
|                     .header("Content-Type", "application/json") |  | ||||||
|                     .header(SEQUENCE_OR_NOT_HEADER,"true") |  | ||||||
|                     .body(batchCSVRecord) |  | ||||||
|                     .asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
|             return batchArray1; |  | ||||||
|         } catch (UnirestException e) { |  | ||||||
|             log.error("Error in transformSequenceArray",e); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return null; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param batchCSVRecord |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) { |  | ||||||
|         try { |  | ||||||
|             SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform") |  | ||||||
|                     .header("accept", "application/json") |  | ||||||
|                     .header("Content-Type", "application/json") |  | ||||||
|                     .header(SEQUENCE_OR_NOT_HEADER,"true") |  | ||||||
|                     .body(batchCSVRecord) |  | ||||||
|                     .asObject(SequenceBatchCSVRecord.class).getBody(); |  | ||||||
|             return batchCSVRecord1; |  | ||||||
|         } catch (UnirestException e) { |  | ||||||
|             log.error("Error in transformSequence"); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return null; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param transform |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { |  | ||||||
|         try { |  | ||||||
|             SequenceBatchCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental") |  | ||||||
|                     .header("accept", "application/json") |  | ||||||
|                     .header("Content-Type", "application/json") |  | ||||||
|                     .header(SEQUENCE_OR_NOT_HEADER,"true") |  | ||||||
|                     .body(transform).asObject(SequenceBatchCSVRecord.class).getBody(); |  | ||||||
|             return singleCsvRecord; |  | ||||||
|         } catch (UnirestException e) { |  | ||||||
|             log.error("Error in transformSequenceIncremental"); |  | ||||||
|         } |  | ||||||
|         return null; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user